ソースを参照

Raise import error if Ollama and Google not found (#1432)

Dev Khant 1 年間 前
コミット
e3e107b31d
2 ファイル変更9 行追加11 行削除
  1. 4 10
      embedchain/llm/google.py
  2. 5 1
      embedchain/llm/ollama.py

+ 4 - 10
embedchain/llm/google.py

@@ -1,10 +1,12 @@
-import importlib
 import logging
 import os
 from collections.abc import Generator
 from typing import Any, Optional, Union
 
-import google.generativeai as genai
+try:
+    import google.generativeai as genai
+except ImportError:
+    raise ImportError("GoogleLlm requires extra dependencies. Install with `pip install google-generativeai`") from None
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -16,14 +18,6 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class GoogleLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        try:
-            importlib.import_module("google.generativeai")
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "The required dependencies for GoogleLlm are not installed."
-                'Please install with `pip install --upgrade "embedchain[google]"`'
-            ) from None
-
         super().__init__(config)
         if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
             raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")

+ 5 - 1
embedchain/llm/ollama.py

@@ -6,7 +6,11 @@ from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_community.llms.ollama import Ollama
-from ollama import Client
+
+try:
+    from ollama import Client
+except ImportError:
+    raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable