@@ -1,7 +1,9 @@
 import openai
 import os
+from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
+from gpt4all import GPT4All
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 
 
@@ -17,16 +19,23 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.vectordb.chroma_db import ChromaDB
 
 load_dotenv()
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    organization_id=os.getenv("OPENAI_ORGANIZATION"),
+    model_name="text-embedding-ada-002"
+)
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
+
+gpt4all_model = None
 
-embeddings = OpenAIEmbeddings()
 
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
 
 class EmbedChain:
-    def __init__(self, db=None):
+    def __init__(self, db=None, ef=None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
@@ -34,7 +43,7 @@ class EmbedChain:
         :param db: The instance of the VectorDB subclass.
         """
         if db is None:
-            db = ChromaDB()
+            db = ChromaDB(ef=ef)
         self.db_client = db.client
         self.collection = db.collection
         self.user_asks = []
@@ -154,20 +163,9 @@ class EmbedChain:
             )
         ]
 
-    def get_openai_answer(self, prompt):
-        messages = []
-        messages.append({
-            "role": "user", "content": prompt
-        })
-        response = openai.ChatCompletion.create(
-            model="gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=0,
-            max_tokens=1000,
-            top_p=1,
-        )
-        return response["choices"][0]["message"]["content"]
-
+    def get_llm_model_answer(self, prompt):
+        raise NotImplementedError
+
     def retrieve_from_database(self, input_query):
         """
         Queries the vector database based on the given input query.
@@ -186,7 +184,7 @@ class EmbedChain:
         else:
             content = ""
         return content
-
+
     def generate_prompt(self, input_query, context):
         """
         Generates a prompt based on the given query and context, ready to be passed to an LLM
@@ -211,7 +209,7 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_openai_answer(prompt)
+        answer = self.get_llm_model_answer(prompt)
         return answer
 
     def query(self, input_query):
@@ -237,4 +235,50 @@ class App(EmbedChain):
     adds(data_type, url): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
     """
-    pass
+
+    def __init__(self, db=None, ef=None):
+        if ef is None:
+            ef = openai_ef
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        messages = []
+        messages.append({
+            "role": "user", "content": prompt
+        })
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-0613",
+            messages=messages,
+            temperature=0,
+            max_tokens=1000,
+            top_p=1,
+        )
+        return response["choices"][0]["message"]["content"]
+
+
+class OpenSourceApp(EmbedChain):
+    """
+    The OpenSource app.
+    Same as App, but uses an open source embedding model and LLM.
+
+    Has two functions: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    """
+
+    def __init__(self, db=None, ef=None):
+        print("Loading open source embedding model. This may take some time...")
+        if ef is None:
+            ef = sentence_transformer_ef
+        print("Successfully loaded open source embedding model.")
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        global gpt4all_model
+        if gpt4all_model is None:
+            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
+        response = gpt4all_model.generate(
+            prompt=prompt,
+        )
+        return response
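
A minimal usage sketch of the two entry points this diff introduces, assuming the package exports App and OpenSourceApp at the top level and that add(data_type, url) takes a data-type string and a URL as the docstrings describe; the URL below is a placeholder. App expects OPENAI_API_KEY, and optionally OPENAI_ORGANIZATION, in the environment or a .env file, while OpenSourceApp needs no API key but downloads the orca-mini-3b.ggmlv3.q4_0.bin weights on first use.

    from embedchain import App, OpenSourceApp

    # Hosted stack: OpenAI embeddings for retrieval, gpt-3.5-turbo for answers.
    app = App()
    app.add("web_page", "https://example.com/some-article")  # placeholder URL
    print(app.query("What is the article about?"))

    # Local stack: SentenceTransformer embeddings, GPT4All for answers.
    local_app = OpenSourceApp()
    local_app.add("web_page", "https://example.com/some-article")
    print(local_app.query("What is the article about?"))

Either class can also receive a custom embedding function through the new ef argument, which EmbedChain.__init__ passes down to ChromaDB(ef=ef), so swapping embedding backends is a constructor choice rather than a code change.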