Przeglądaj źródła

Open source embedding and LLM models (#133)

* Add open source LLM model: gpt4all
* Add open source embedding model: sentence transformers
Taranjeet Singh 2 lat temu
rodzic
commit
cf1e000fb3
4 zmienionych plików z 71 dodań i 24 usunięć
  1. 1 1
      embedchain/__init__.py
  2. 65 21
      embedchain/embedchain.py
  3. 3 2
      embedchain/vectordb/chroma_db.py
  4. 2 0
      setup.py

+ 1 - 1
embedchain/__init__.py

@@ -1 +1 @@
-from .embedchain import App
+from .embedchain import App, OpenSourceApp

+ 65 - 21
embedchain/embedchain.py

@@ -1,7 +1,9 @@
 import openai
 import os
 
+from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
+from gpt4all import GPT4All
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 
@@ -17,16 +19,23 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.vectordb.chroma_db import ChromaDB
 
-load_dotenv()
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    organization_id=os.getenv("OPENAI_ORGANIZATION"),
+    model_name="text-embedding-ada-002"
+)
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
+
+gpt4all_model = None
 
-embeddings = OpenAIEmbeddings()
+load_dotenv()
 
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
 
 class EmbedChain:
-    def __init__(self, db=None):
+    def __init__(self, db=None, ef=None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
@@ -34,7 +43,7 @@ class EmbedChain:
         :param db: The instance of the VectorDB subclass.
         """
         if db is None:
-            db = ChromaDB()
+            db = ChromaDB(ef=ef)
         self.db_client = db.client
         self.collection = db.collection
         self.user_asks = []
@@ -154,20 +163,9 @@ class EmbedChain:
             )
         ]
 
-    def get_openai_answer(self, prompt):
-        messages = []
-        messages.append({
-            "role": "user", "content": prompt
-        })
-        response = openai.ChatCompletion.create(
-            model="gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=0,
-            max_tokens=1000,
-            top_p=1,
-        )
-        return response["choices"][0]["message"]["content"]
-    
+    def get_llm_model_answer(self, prompt):
+        raise NotImplementedError
+
     def retrieve_from_database(self, input_query):
         """
         Queries the vector database based on the given input query.
@@ -186,7 +184,7 @@ class EmbedChain:
         else:
             content = ""
         return content
-    
+
     def generate_prompt(self, input_query, context):
         """
         Generates a prompt based on the given query and context, ready to be passed to an LLM
@@ -211,7 +209,7 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_openai_answer(prompt)
+        answer = self.get_llm_model_answer(prompt)
         return answer
 
     def query(self, input_query):
@@ -237,4 +235,50 @@ class App(EmbedChain):
     adds(data_type, url): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
     """
-    pass
+
+    def __int__(self, db=None, ef=None):
+        if ef is None:
+            ef = openai_ef
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        messages = []
+        messages.append({
+            "role": "user", "content": prompt
+        })
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-0613",
+            messages=messages,
+            temperature=0,
+            max_tokens=1000,
+            top_p=1,
+        )
+        return response["choices"][0]["message"]["content"]
+
+
+class OpenSourceApp(EmbedChain):
+    """
+    The OpenSource app.
+    Same as App, but uses an open source embedding model and LLM.
+
+    Has two function: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    """
+
+    def __init__(self, db=None, ef=None):
+        print("Loading open source embedding model. This may take some time...")
+        if ef is None:
+            ef = sentence_transformer_ef
+        print("Successfully loaded open source embedding model.")
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        global gpt4all_model
+        if gpt4all_model is None:
+            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
+        response = gpt4all_model.generate(
+            prompt=prompt,
+        )
+        return response

+ 3 - 2
embedchain/vectordb/chroma_db.py

@@ -12,7 +12,8 @@ openai_ef = embedding_functions.OpenAIEmbeddingFunction(
 )
 
 class ChromaDB(BaseVectorDB):
-    def __init__(self, db_dir=None):
+    def __init__(self, db_dir=None, ef=None):
+        self.ef = ef if ef is not None else openai_ef
         if db_dir is None:
             db_dir = "db"
         self.client_settings = chromadb.config.Settings(
@@ -27,5 +28,5 @@ class ChromaDB(BaseVectorDB):
 
     def _get_or_create_collection(self):
         return self.client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
+            'embedchain_store', embedding_function=self.ef,
         )

+ 2 - 0
setup.py

@@ -29,5 +29,7 @@ setuptools.setup(
         "beautifulsoup4",
         "pypdf",
         "pytube",
+        "gpt4all",
+        "sentence_transformers",
     ]
 )