Преглед на файлове

fix: Fix dependency of openai env variables for OpenSourceApp (#144)

This commit fixes dependency of initializing openai env variables
for OpenSourceApp.
Taranjeet Singh преди 2 години
родител
ревизия
200f11a0e0
променени са 2 файла, в които са добавени 16 реда и са изтрити 14 реда
  1. 8 8
      embedchain/embedchain.py
  2. 8 6
      embedchain/vectordb/chroma_db.py

+ 8 - 8
embedchain/embedchain.py

@@ -19,12 +19,6 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.vectordb.chroma_db import ChromaDB
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    organization_id=os.getenv("OPENAI_ORGANIZATION"),
-    model_name="text-embedding-ada-002"
-)
-sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
 
 gpt4all_model = None
 
@@ -238,7 +232,11 @@ class App(EmbedChain):
 
     def __int__(self, db=None, ef=None):
         if ef is None:
-            ef = openai_ef
+            ef = embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            )
         super().__init__(db, ef)
 
     def get_llm_model_answer(self, prompt):
@@ -270,7 +268,9 @@ class OpenSourceApp(EmbedChain):
     def __init__(self, db=None, ef=None):
         print("Loading open source embedding model. This may take some time...")
         if ef is None:
-            ef = sentence_transformer_ef
+            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name="all-MiniLM-L6-v2"
+            )
         print("Successfully loaded open source embedding model.")
         super().__init__(db, ef)
 

+ 8 - 6
embedchain/vectordb/chroma_db.py

@@ -5,15 +5,17 @@ from chromadb.utils import embedding_functions
 
 from embedchain.vectordb.base_vector_db import BaseVectorDB
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    organization_id=os.getenv("OPENAI_ORGANIZATION"),
-    model_name="text-embedding-ada-002"
-)
 
 class ChromaDB(BaseVectorDB):
     def __init__(self, db_dir=None, ef=None):
-        self.ef = ef if ef is not None else openai_ef
+        if ef:
+            self.ef = ef
+        else:
+            self.ef = embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            )
         if db_dir is None:
             db_dir = "db"
         self.client_settings = chromadb.config.Settings(