cachho před 2 roky
rodič
revize
9409e5605a
3 změnil soubory, kde provedl 25 přidání a 23 odebrání
  1. 17 10
      embedchain/config/InitConfig.py
  2. 7 6
      embedchain/embedchain.py
  3. 1 7
      setup.py

+ 17 - 10
embedchain/config/InitConfig.py

@@ -46,25 +46,31 @@ class InitConfig(BaseConfig):
     def _set_embedding_function(self, ef):
         self.ef = ef
         return
-    
+
     def _set_embedding_function_to_default(self):
         """
         Sets embedding function to default (`text-embedding-ada-002`).
 
-        :raises ValueError: If the template is not valid as template should contain $context and $query
+        :raises ValueError: If the template is not valid as template should contain
+        $context and $query
         """
-        if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_ORGANIZATION") is None:
-            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")
-        self.ef = embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
+        if (
+            os.getenv("OPENAI_API_KEY") is None
+            or os.getenv("OPENAI_ORGANIZATION") is None
+        ):
+            raise ValueError(
+                "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"  # noqa:E501
             )
+        self.ef = embedding_functions.OpenAIEmbeddingFunction(
+            api_key=os.getenv("OPENAI_API_KEY"),
+            organization_id=os.getenv("OPENAI_ORGANIZATION"),
+            model_name="text-embedding-ada-002",
+        )
         return
-    
+
     def _set_db(self, db):
         if db:
-            self.db = db            
+            self.db = db
         return
 
     def _set_db_to_default(self):
@@ -72,6 +78,7 @@ class InitConfig(BaseConfig):
         Sets database to default (`ChromaDb`).
         """
         from embedchain.vectordb.chroma_db import ChromaDB
+
         self.db = ChromaDB(ef=self.ef)
 
     def _setup_logging(self, debug_level):

+ 7 - 6
embedchain/embedchain.py

@@ -301,13 +301,13 @@ class App(EmbedChain):
         """
         if config is None:
             config = InitConfig()
-        
+
         if not config.ef:
             config._set_embedding_function_to_default()
 
         if not config.db:
             config._set_db_to_default()
-        
+
         super().__init__(config)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
@@ -357,12 +357,13 @@ class OpenSourceApp(EmbedChain):
         )  # noqa:E501
         if not config:
             config = InitConfig()
-        
+
         if not config.ef:
             config._set_embedding_function(
-                    embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name="all-MiniLM-L6-v2"
-            ))
+                embedding_functions.SentenceTransformerEmbeddingFunction(
+                    model_name="all-MiniLM-L6-v2"
+                )
+            )
 
         if not config.db:
             config._set_db_to_default()

+ 1 - 7
setup.py

@@ -34,11 +34,5 @@ setuptools.setup(
         "docx2txt",
         "pydantic==1.10.8",
     ],
-    extras_require={
-        "dev": [
-            "black",
-            "ruff",
-            "isort"
-        ]
-    }
+    extras_require={"dev": ["black", "ruff", "isort"]},
 )