ソースを参照

fix: initialize embedding function manually (#205)

cachho 2 年 前
コミット
35b6e14328
2 ファイル変更48 行追加10 行削除
  1. 31 0
      embedchain/config/InitConfig.py
  2. 17 10
      embedchain/embedchain.py

+ 31 - 0
embedchain/config/InitConfig.py

@@ -1,6 +1,8 @@
 import logging
 import os
 
+from chromadb.utils import embedding_functions
+
 from embedchain.config.BaseConfig import BaseConfig
 
 
@@ -37,11 +39,40 @@ class InitConfig(BaseConfig):
         else:
             self.db = db
 
+        self.ef = ef
+        self.db = db
         return
 
     def _set_embedding_function(self, ef):
         self.ef = ef
         return
+    
+    def _set_embedding_function_to_default(self):
+        """
+        Sets embedding function to default (`text-embedding-ada-002`).
+
+        :raises ValueError: If the template is not valid as template should contain $context and $query
+        """
+        if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_ORGANIZATION") is None:
+            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")
+        self.ef = embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            )
+        return
+    
+    def _set_db(self, db):
+        if db:
+            self.db = db            
+        return
+
+    def _set_db_to_default(self):
+        """
+        Sets database to default (`ChromaDb`).
+        """
+        from embedchain.vectordb.chroma_db import ChromaDB
+        self.db = ChromaDB(ef=self.ef)
 
     def _setup_logging(self, debug_level):
         level = logging.WARNING  # Default level

+ 17 - 10
embedchain/embedchain.py

@@ -297,6 +297,13 @@ class App(EmbedChain):
         """
         if config is None:
             config = InitConfig()
+        
+        if not config.ef:
+            config._set_embedding_function_to_default()
+
+        if not config.db:
+            config._set_db_to_default()
+        
         super().__init__(config)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
@@ -345,17 +352,17 @@ class OpenSourceApp(EmbedChain):
             "Loading open source embedding model. This may take some time..."
         )  # noqa:E501
         if not config:
-            config = InitConfig(
-                ef=embedding_functions.SentenceTransformerEmbeddingFunction(
-                    model_name="all-MiniLM-L6-v2"
-                )
-            )
-        elif not config.ef:
+            config = InitConfig()
+        
+        if not config.ef:
             config._set_embedding_function(
-                embedding_functions.SentenceTransformerEmbeddingFunction(
-                    model_name="all-MiniLM-L6-v2"
-                )
-            )
+                    embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name="all-MiniLM-L6-v2"
+            ))
+
+        if not config.db:
+            config._set_db_to_default()
+
         print("Successfully loaded open source embedding model.")
         super().__init__(config)