浏览代码

[Bug fix] import App shouldn't throw other llm deps errors (#837)

Sidharth Mohanty 1 年之前
父节点
当前提交
a27eeb3255
共有 2 个文件被更改,包括 14 次插入16 次删除
  1. 7 8
      embedchain/llm/llama2.py
  2. 7 8
      embedchain/llm/vertex_ai.py

+ 7 - 8
embedchain/llm/llama2.py

@@ -8,18 +8,17 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
-try:
-    importlib.import_module("replicate")
-except ModuleNotFoundError:
-    raise ModuleNotFoundError(
-        "The required dependencies for Llama2 are not installed."
-        'Please install with `pip install --upgrade "embedchain[llama2]"`'
-    ) from None
-
 
 @register_deserializable
 class Llama2Llm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        try:
+            importlib.import_module("replicate")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for Llama2 are not installed."
+                'Please install with `pip install --upgrade "embedchain[llama2]"`'
+            ) from None
         if "REPLICATE_API_TOKEN" not in os.environ:
             raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
 

+ 7 - 8
embedchain/llm/vertex_ai.py

@@ -6,18 +6,17 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
-try:
-    importlib.import_module("vertexai")
-except ModuleNotFoundError:
-    raise ModuleNotFoundError(
-        "The required dependencies for VertexAI are not installed."
-        'Please install with `pip install --upgrade "embedchain[vertexai]"`'
-    ) from None
-
 
 @register_deserializable
 class VertexAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        try:
+            importlib.import_module("vertexai")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for VertexAI are not installed."
+                'Please install with `pip install --upgrade "embedchain[vertexai]"`'
+            ) from None
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt):