@@ -6,6 +6,7 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory
+from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
 
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
@@ -33,17 +34,17 @@ memory = ConversationBufferMemory()
 
 
 class EmbedChain:
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
-        :param db: The instance of the VectorDB subclass.
+        :param config: InitConfig instance to load as configuration.
         """
-        if db is None:
-            db = ChromaDB(ef=ef)
-        self.db_client = db.client
-        self.collection = db.collection
+
+        self.config = config
+        self.db_client = self.config.db.client
+        self.collection = self.config.db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
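With this change the base class no longer constructs its own ChromaDB; it reads the client and collection off the InitConfig it is handed. A minimal caller-side sketch, assuming InitConfig builds the default ChromaDB around whatever embedding function it receives; the OpenAIEmbeddingFunction arguments are the ones removed from App further down, chromadb's embedding_functions helper is the one this module already uses, and the package-root import of App follows the project README:

    import os

    from chromadb.utils import embedding_functions
    from embedchain import App
    from embedchain.config import InitConfig

    # Build an explicit embedding function and hand it to InitConfig;
    # EmbedChain then pulls db.client and db.collection from the config.
    ef = embedding_functions.OpenAIEmbeddingFunction(
        api_key=os.getenv("OPENAI_API_KEY"),
        organization_id=os.getenv("OPENAI_ORGANIZATION"),
        model_name="text-embedding-ada-002",
    )
    app = App(InitConfig(ef=ef))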
@@ -86,7 +87,7 @@ class EmbedChain:
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
 
-    def add(self, data_type, url):
+    def add(self, data_type, url, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -94,13 +95,16 @@
 
         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def add_local(self, data_type, content):
+    def add_local(self, data_type, content, config: AddConfig = None):
         """
         Adds the data you supply to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -108,7 +112,10 @@
 
         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, content])
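Both add and add_local now take the same optional AddConfig, created on demand when omitted. A usage sketch: the remote data-type strings follow the loaders imported at the top, while "text" as a local type is an assumption:

    from embedchain.config import AddConfig

    # Implicit config: AddConfig() is created inside add().
    app.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44")

    # Explicit config: AddConfig() carries no required options yet, but the
    # parameter gives future chunking/embedding settings a stable home.
    app.add("pdf_file", "https://example.com/manual.pdf", config=AddConfig())
    app.add_local("text", "Some pasted notes.", config=AddConfig())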
@@ -210,15 +217,18 @@
         answer = self.get_llm_model_answer(prompt)
         return answer
 
-    def query(self, input_query):
+    def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         answer = self.get_answer_from_llm(prompt)
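query stays backwards compatible: the new argument defaults to None and is replaced by a fresh QueryConfig inside the method. Sketch:

    from embedchain.config import QueryConfig

    answer = app.query("What did the speaker say about pricing?")
    # Identical behaviour, with the config spelled out:
    answer = app.query("What did the speaker say about pricing?", config=QueryConfig())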
@@ -243,14 +253,19 @@
             prompt += suffix_prompt
         return prompt
 
-    def chat(self, input_query):
+    def chat(self, input_query, config: ChatConfig = None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         Maintains last 5 conversations in memory.
+        :param input_query: The query to use.
+        :param config: Optional. The `ChatConfig` instance to use as configuration options.
+        :return: The answer to the query.
         """
+        if config is None:
+            config = ChatConfig()
         context = self.retrieve_from_database(input_query)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
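chat follows the same pattern, with the module-level ConversationBufferMemory carrying the last five exchanges between calls. Sketch:

    from embedchain.config import ChatConfig

    print(app.chat("Who is the speaker?"))
    # Follow-ups can lean on the buffered history; the explicit config is optional.
    print(app.chat("What company did they found?", config=ChatConfig()))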
@@ -274,8 +289,11 @@
         the `max_tokens` parameter.
 
         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The prompt that would be sent to the LLM
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         return prompt
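dry_run deliberately reuses QueryConfig, since it assembles the same prompt as query without calling the LLM. A sketch, assuming the signature is dry_run(input_query, config=None) as the docstring above implies:

    # Preview the exact prompt query() would send, without spending tokens.
    print(app.dry_run("What did the speaker say about pricing?"))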
@@ -291,14 +309,13 @@ class App(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __int__(self, db=None, ef=None):
-        if ef is None:
-            ef = embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
-            )
-        super().__init__(db, ef)
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional.
+        """
+        if config is None:
+            config = InitConfig()
+        super().__init__(config)
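Worth noting: the removed method was spelled __int__, so it never ran as a constructor; instantiating App fell through to EmbedChain.__init__ with its defaults. The renamed __init__ plus the InitConfig default gives App a working zero-argument path:

    from embedchain import App
    from embedchain.config import InitConfig

    app = App()              # config defaults to InitConfig()
    app = App(InitConfig())  # equivalent, spelled out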
 
     def get_llm_model_answer(self, prompt):
         messages = []
@@ -326,14 +343,25 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: Optional. InitConfig instance to load as configuration. `ef` defaults to an open-source embedding function.
+        """
         print("Loading open source embedding model. This may take some time...")
-        if ef is None:
-            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+        if not config or not config.ef:
+            if config is None:
+                config = InitConfig(
+                    ef=embedding_functions.SentenceTransformerEmbeddingFunction(
+                        model_name="all-MiniLM-L6-v2"
+                    )
+                )
+            else:
+                config._set_embedding_function(
+                    embedding_functions.SentenceTransformerEmbeddingFunction(
                 model_name="all-MiniLM-L6-v2"
-            )
+                    ))
         print("Successfully loaded open source embedding model.")
-        super().__init__(db, ef)
+        super().__init__(config)
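OpenSourceApp now distinguishes three call sites: no config at all, a config without an ef, and a config whose ef the caller already set (that last one passes the `if not config or not config.ef` guard untouched). A sketch of each; the package-root import is assumed from the README and the alternative model name is purely illustrative:

    from chromadb.utils import embedding_functions
    from embedchain import OpenSourceApp
    from embedchain.config import InitConfig

    app = OpenSourceApp()              # no config: MiniLM default is built in
    app = OpenSourceApp(InitConfig())  # config without ef: MiniLM injected via _set_embedding_function
    app = OpenSourceApp(InitConfig(    # caller-supplied ef wins
        ef=embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="paraphrase-MiniLM-L6-v2"  # illustrative alternative
        )
    ))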
 
     def get_llm_model_answer(self, prompt):
         from gpt4all import GPT4All