Ver código fonte

refactor: Add config for init, app and query (#158)

cachho 2 anos atrás
pai
commit
e50c7e6843

+ 58 - 0
README.md

@@ -265,6 +265,64 @@ _The embedding is confirmed to work as expected. It returns the right document,
 
 **The dry run will still consume tokens to embed your query, but it is only ~1/15 of the prompt.**
 
+# Advanced
+
+## Configuration
+Embedchain is made to work out of the box. However, for advanced users we're also offering configuration options. All of these configuration options are optional and have sane defaults.
+
+### Example
+
+Here's the readme example with configuration options.
+
+```python
+import os
+from embedchain import App
+from embedchain.config import InitConfig, AddConfig, QueryConfig
+from chromadb.utils import embedding_functions
+
+# Example: use your own embedding function
+config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            ))
+naval_chat_bot = App(config)
+
+add_config = AddConfig() # Currently no options
+naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config)
+naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
+naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config)
+naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)
+
+naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."), add_config)
+
+query_config = QueryConfig() # Currently no options
+print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config))
+```
+
+### Configs
+This section describes all possible config options.
+
+#### **InitConfig**
+|option|description|type|default|
+|---|---|---|---|
+|ef|embedding function|chromadb.utils.embedding_functions|{text-embedding-ada-002}|
+|db|vector database (experimental)|BaseVectorDB|ChromaDB|
+
+#### **Add Config**
+
+*coming soon*
+
+#### **Query Config**
+
+*coming soon*
+
+#### **Chat Config**
+
+All options for query and...
+
+*coming soon*
+
 # How does it work?
 
 Creating a chat bot over any dataset needs the following steps to happen

+ 8 - 0
embedchain/config/AddConfig.py

@@ -0,0 +1,8 @@
+from embedchain.config.BaseConfig import BaseConfig
+
+class AddConfig(BaseConfig):
+    """
+    Config for the `add` method.
+    """
+    def __init__(self):
+        pass

+ 9 - 0
embedchain/config/BaseConfig.py

@@ -0,0 +1,9 @@
+class BaseConfig:
+    """
+    Base config.
+    """
+    def __init__(self):
+        pass
+
+    def as_dict(self):
+        return vars(self)

+ 8 - 0
embedchain/config/ChatConfig.py

@@ -0,0 +1,8 @@
+from embedchain.config.QueryConfig import QueryConfig
+
+class ChatConfig(QueryConfig):
+    """
+    Config for the `chat` method, inherits from `QueryConfig`.
+    """
+    def __init__(self):
+        pass

+ 36 - 0
embedchain/config/InitConfig.py

@@ -0,0 +1,36 @@
+import os
+
+from embedchain.config.BaseConfig import BaseConfig
+
+class InitConfig(BaseConfig):
+    """
+    Config to initialize an embedchain `App` instance.
+    """
+    def __init__(self, ef=None, db=None):
+        """
+        :param ef: Optional. Embedding function to use.
+        :param db: Optional. (Vector) database to use for embeddings.
+        """
+        # Embedding Function
+        if ef is None:
+            from chromadb.utils import embedding_functions
+            self.ef = embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            )
+        else:
+            self.ef = ef
+
+        if db is None:
+            from embedchain.vectordb.chroma_db import ChromaDB
+            self.db = ChromaDB(ef=self.ef)
+        else:
+            self.db = db
+
+        return
+
+
+    def _set_embedding_function(self, ef):
+        self.ef = ef
+        return        

+ 8 - 0
embedchain/config/QueryConfig.py

@@ -0,0 +1,8 @@
+from embedchain.config.BaseConfig import BaseConfig
+
+class QueryConfig(BaseConfig):
+    """
+    Config for the `query` method.
+    """
+    def __init__(self):
+        pass

+ 5 - 0
embedchain/config/__init__.py

@@ -0,0 +1,5 @@
+from .BaseConfig import BaseConfig
+from .AddConfig import AddConfig
+from .ChatConfig import ChatConfig
+from .InitConfig import InitConfig
+from .QueryConfig import QueryConfig

+ 51 - 23
embedchain/embedchain.py

@@ -6,6 +6,7 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory
+from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
 
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
@@ -33,17 +34,17 @@ memory = ConversationBufferMemory()
 
 
 class EmbedChain:
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
-        :param db: The instance of the VectorDB subclass.
+        :param config: InitConfig instance to load as configuration.
         """
-        if db is None:
-            db = ChromaDB(ef=ef)
-        self.db_client = db.client
-        self.collection = db.collection
+        
+        self.config = config
+        self.db_client = self.config.db.client
+        self.collection = self.config.db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
@@ -86,7 +87,7 @@ class EmbedChain:
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
 
-    def add(self, data_type, url):
+    def add(self, data_type, url, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -94,13 +95,16 @@ class EmbedChain:
 
         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def add_local(self, data_type, content):
+    def add_local(self, data_type, content, config: AddConfig = None):
         """
         Adds the data you supply to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -108,7 +112,10 @@ class EmbedChain:
 
         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, content])
@@ -210,15 +217,18 @@ class EmbedChain:
         answer = self.get_llm_model_answer(prompt)
         return answer
 
-    def query(self, input_query):
+    def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         answer = self.get_answer_from_llm(prompt)
@@ -243,14 +253,19 @@ class EmbedChain:
         prompt += suffix_prompt
         return prompt
 
-    def chat(self, input_query):
+    def chat(self, input_query, config: ChatConfig = None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         Maintains last 5 conversations in memory.
+        :param input_query: The query to use.
+        :param config: Optional. The `ChatConfig` instance to use as configuration options.
+        :return: The answer to the query.
         """
+        if config is None:
+            config = ChatConfig()
         context = self.retrieve_from_database(input_query)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
@@ -274,8 +289,11 @@ class EmbedChain:
         the `max_tokens` parameter.
 
         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The prompt that would be sent to the LLM
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         return prompt
@@ -291,14 +309,13 @@ class App(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __int__(self, db=None, ef=None):
-        if ef is None:
-            ef = embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
-            )
-        super().__init__(db, ef)
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional.
+        """
+        if config is None:
+            config = InitConfig()
+        super().__init__(config)
 
     def get_llm_model_answer(self, prompt):
         messages = []
@@ -326,14 +343,25 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional. `ef` defaults to open source.
+        """
         print("Loading open source embedding model. This may take some time...")
-        if ef is None:
-            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+        if not config or not config.ef:
+            if config is None:
+                config = InitConfig(
+                    ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                        model_name="all-MiniLM-L6-v2"
+                    )
+                )
+            else:
+                config._set_embedding_function(
+                    embedding_functions.SentenceTransformerEmbeddingFunction(
                 model_name="all-MiniLM-L6-v2"
-            )
+            ))
         print("Successfully loaded open source embedding model.")
-        super().__init__(db, ef)
+        super().__init__(config)
 
     def get_llm_model_answer(self, prompt):
         from gpt4all import GPT4All