
feat: add multi-document answers (#63)

cachho 2 years ago
parent
commit
40dc28406d
3 changed files with 32 additions and 21 deletions
  1. README.md (+2 -1)
  2. embedchain/config/QueryConfig.py (+7 -0)
  3. embedchain/embedchain.py (+23 -20)

+ 2 - 1
README.md

@@ -492,7 +492,8 @@ _coming soon_
 
 |option|description|type|default|
 |---|---|---|---|
-|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: \$query Helpful Answer:")|
+|number_documents|number of documents to be retrieved as context|int|1|
+|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: $query Helpful Answer:")|
 |history|include conversation history from your client or database|any (recommendation: list[str])|None
 |stream|control if response is streamed back to the user|bool|False|
 |model|OpenAI model|string|gpt-3.5-turbo-0613|
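
For context, a minimal usage sketch of the new `number_documents` option. It assumes `QueryConfig` is exported from `embedchain.config` and that `App.query` accepts a config argument, as in this repository at the time of the commit; the source URL is a placeholder.

```python
from embedchain import App
from embedchain.config import QueryConfig

app = App()
app.add("web_page", "https://example.com/some-article")  # placeholder source

# Retrieve the 3 most similar documents as context instead of the default 1.
config = QueryConfig(number_documents=3)
answer = app.query("What is the article about?", config)
print(answer)
```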

+ 7 - 0
embedchain/config/QueryConfig.py

@@ -42,6 +42,7 @@ class QueryConfig(BaseConfig):
 
     def __init__(
         self,
+        number_documents=None,
         template: Template = None,
         model=None,
         temperature=None,
@@ -53,6 +54,7 @@ class QueryConfig(BaseConfig):
         """
         Initializes the QueryConfig instance.
 
+        :param number_documents: Number of documents to pull from the database as context.
         :param template: Optional. The `Template` instance to use as a template for
         prompt.
         :param model: Optional. Controls the OpenAI model used.
@@ -68,6 +70,11 @@ class QueryConfig(BaseConfig):
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history).
         """
+        if number_documents is None:
+            self.number_documents = 1
+        else:
+            self.number_documents = number_documents
+
         if not history:
             self.history = None
         else:
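
A quick illustration of the defaulting behavior added above (a sketch, assuming the package exports `QueryConfig` from `embedchain.config`):

```python
from embedchain.config import QueryConfig

QueryConfig().number_documents                     # 1, the default when omitted
QueryConfig(number_documents=5).number_documents   # 5
```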

+ 23 - 20
embedchain/embedchain.py

@@ -133,43 +133,44 @@ class EmbedChain:
     def get_llm_model_answer(self, prompt):
         raise NotImplementedError
 
-    def retrieve_from_database(self, input_query):
+    def retrieve_from_database(self, input_query, config: QueryConfig):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
 
         :param input_query: The query to use.
+        :param config: The query configuration.
         :return: The content of the document that matched your query.
         """
         result = self.collection.query(
             query_texts=[
                 input_query,
             ],
-            n_results=1,
+            n_results=config.number_documents,
         )
-        result_formatted = self._format_result(result)
-        if result_formatted:
-            content = result_formatted[0][0].page_content
-        else:
-            content = ""
-        return content
+        results_formatted = self._format_result(result)
+        contents = [result[0].page_content for result in results_formatted]
+        return contents
 
-    def generate_prompt(self, input_query, context, config: QueryConfig):
+    def generate_prompt(self, input_query, contexts, config: QueryConfig):
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
 
         :param input_query: The query to use.
-        :param context: Similar documents to the query used as context.
+        :param contexts: List of similar documents to the query used as context.
         :param config: Optional. The `QueryConfig` instance to use as
         configuration options.
         :return: The prompt
         """
+        context_string = (" | ").join(contexts)
         if not config.history:
-            prompt = config.template.substitute(context=context, query=input_query)
+            prompt = config.template.substitute(
+                context=context_string, query=input_query
+            )
         else:
             prompt = config.template.substitute(
-                context=context, query=input_query, history=config.history
+                context=context_string, query=input_query, history=config.history
             )
         return prompt
 
@@ -198,8 +199,8 @@ class EmbedChain:
         """
         if config is None:
             config = QueryConfig()
-        context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config)
+        contexts = self.retrieve_from_database(input_query, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
         logging.info(f"Answer: {answer}")
@@ -217,16 +218,18 @@ class EmbedChain:
         configuration options.
         :return: The answer to the query.
         """
-        context = self.retrieve_from_database(input_query)
+        if config is None:
+            config = ChatConfig()
+
+        contexts = self.retrieve_from_database(input_query, config)
+
         global memory
         chat_history = memory.load_memory_variables({})["history"]
 
-        if config is None:
-            config = ChatConfig()
         if chat_history:
             config.set_history(chat_history)
 
-        prompt = self.generate_prompt(input_query, context, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
 
@@ -264,8 +267,8 @@ class EmbedChain:
         """
         if config is None:
             config = QueryConfig()
-        context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config)
+        contexts = self.retrieve_from_database(input_query, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         return prompt
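
Taken together, the embedchain.py changes move the pipeline from a single context string to a list of retrieved documents that `generate_prompt` joins with `" | "` before substitution. Below is a self-contained sketch of that joining step; the template text and contexts are stand-ins, not the library's defaults.

```python
from string import Template

# Stand-in template with the same placeholders the library's default template uses.
template = Template(
    "Use the following pieces of context to answer the query at the end. "
    "$context Query: $query Helpful Answer:"
)

# Pretend retrieve_from_database returned two documents (number_documents=2).
contexts = ["Chunk about installation.", "Chunk about configuration."]

# As in the new generate_prompt, the contexts are joined with " | ".
context_string = " | ".join(contexts)
prompt = template.substitute(context=context_string, query="How do I configure it?")
print(prompt)
```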