浏览代码

fix: Don't create instance in InitConfig `__init__` (#236)

cachho 2 年之前
父节点
当前提交
6fbf45498a
共有 2 个文件被更改,包括 6 次插入21 次删除
  1. 4 21
      embedchain/config/InitConfig.py
  2. 2 0
      embedchain/vectordb/chroma_db.py

+ 4 - 21
embedchain/config/InitConfig.py

@@ -20,27 +20,11 @@ class InitConfig(BaseConfig):
         """
         """
         self._setup_logging(log_level)
         self._setup_logging(log_level)
 
 
-        # Embedding Function
-        if ef is None:
-            from chromadb.utils import embedding_functions
-
-            self.ef = embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002",
-            )
-        else:
-            self.ef = ef
-
-        if db is None:
-            from embedchain.vectordb.chroma_db import ChromaDB
-
-            self.db = ChromaDB(ef=self.ef, host=host, port=port)
-        else:
-            self.db = db
-
         self.ef = ef
         self.ef = ef
         self.db = db
         self.db = db
+
+        self.host = host
+        self.port = port
         return
         return
 
 
     def _set_embedding_function(self, ef):
     def _set_embedding_function(self, ef):
@@ -78,8 +62,7 @@ class InitConfig(BaseConfig):
         Sets database to default (`ChromaDb`).
         Sets database to default (`ChromaDb`).
         """
         """
         from embedchain.vectordb.chroma_db import ChromaDB
         from embedchain.vectordb.chroma_db import ChromaDB
-
-        self.db = ChromaDB(ef=self.ef)
+        self.db = ChromaDB(ef=self.ef, host=self.host, port=self.port)
 
 
     def _setup_logging(self, debug_level):
     def _setup_logging(self, debug_level):
         level = logging.WARNING  # Default level
         level = logging.WARNING  # Default level

+ 2 - 0
embedchain/vectordb/chroma_db.py

@@ -1,4 +1,5 @@
 import os
 import os
+import logging
 
 
 import chromadb
 import chromadb
 from chromadb.utils import embedding_functions
 from chromadb.utils import embedding_functions
@@ -20,6 +21,7 @@ class ChromaDB(BaseVectorDB):
             )
             )
 
 
         if host and port:
         if host and port:
+            logging.info(f"Connecting to ChromaDB server: {host}:{port}")
             self.client_settings = chromadb.config.Settings(
             self.client_settings = chromadb.config.Settings(
                 chroma_api_impl="rest",
                 chroma_api_impl="rest",
                 chroma_server_host=host,
                 chroma_server_host=host,