
feat: add multi-document answers (#63)

cachho, 2 years ago · commit 40dc28406d
3 files changed, 32 insertions(+), 21 deletions(-)
  1. README.md (+2 −1)
  2. embedchain/config/QueryConfig.py (+7 −0)
  3. embedchain/embedchain.py (+23 −20)

README.md (+2 −1)

@@ -492,7 +492,8 @@ _coming soon_
 
 
 |option|description|type|default|
 |---|---|---|---|
-|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: \$query Helpful Answer:")|
+|number_documents|number of documents to be retrieved as context|int|1|
+|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: $query Helpful Answer:")|
 |history|include conversation history from your client or database|any (recommendation: list[str])|None|
 |stream|control if response is streamed back to the user|bool|False|
 |model|OpenAI model|string|gpt-3.5-turbo-0613|
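
A minimal sketch of how the new `number_documents` option might be used from client code. The `App` entry point is the one shown in the embedchain README; the import path for `QueryConfig` is assumed from the file layout in this diff, and the query string and `config=` keyword are placeholders consistent with the `query` signature changed below:

```python
from embedchain import App  # entry point shown in the embedchain README
from embedchain.config import QueryConfig  # import path assumed from the file layout

app = App()  # assumes OPENAI_API_KEY is set and data has already been added
# Retrieve three documents from the vector database as context instead of one.
config = QueryConfig(number_documents=3)
answer = app.query("What does the documentation say about retries?", config=config)
```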

embedchain/config/QueryConfig.py (+7 −0)

@@ -42,6 +42,7 @@ class QueryConfig(BaseConfig):
 
 
     def __init__(
         self,
+        number_documents=None,
         template: Template = None,
         model=None,
         temperature=None,
@@ -53,6 +54,7 @@ class QueryConfig(BaseConfig):
         """
         """
         Initializes the QueryConfig instance.
         Initializes the QueryConfig instance.
 
 
+        :param number_documents: Number of documents to pull from the database as context.
         :param template: Optional. The `Template` instance to use as a template for
         :param template: Optional. The `Template` instance to use as a template for
         prompt.
         prompt.
         :param model: Optional. Controls the OpenAI model used.
         :param model: Optional. Controls the OpenAI model used.
@@ -68,6 +70,11 @@ class QueryConfig(BaseConfig):
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history).
         """
+        if number_documents is None:
+            self.number_documents = 1
+        else:
+            self.number_documents = number_documents
+
         if not history:
             self.history = None
         else:
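
Because the fallback keeps `number_documents` at 1, existing callers that never pass the option retain the old single-document behavior. A quick check of the resulting values, assuming `QueryConfig` is re-exported from `embedchain.config`:

```python
from embedchain.config import QueryConfig  # import path assumed

print(QueryConfig().number_documents)                    # 1, the old behavior
print(QueryConfig(number_documents=5).number_documents)  # 5
```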

embedchain/embedchain.py (+23 −20)

@@ -133,43 +133,44 @@ class EmbedChain:
     def get_llm_model_answer(self, prompt):
         raise NotImplementedError

-    def retrieve_from_database(self, input_query):
+    def retrieve_from_database(self, input_query, config: QueryConfig):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query

         :param input_query: The query to use.
+        :param config: The query configuration.
         :return: The content of the document that matched your query.
         """
         result = self.collection.query(
             query_texts=[
                 input_query,
             ],
-            n_results=1,
+            n_results=config.number_documents,
         )
-        result_formatted = self._format_result(result)
-        if result_formatted:
-            content = result_formatted[0][0].page_content
-        else:
-            content = ""
-        return content
+        results_formatted = self._format_result(result)
+        contents = [result[0].page_content for result in results_formatted]
+        return contents

-    def generate_prompt(self, input_query, context, config: QueryConfig):
+    def generate_prompt(self, input_query, contexts, config: QueryConfig):
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM

         :param input_query: The query to use.
-        :param context: Similar documents to the query used as context.
+        :param contexts: List of similar documents to the query used as context.
         :param config: Optional. The `QueryConfig` instance to use as
         configuration options.
         :return: The prompt
         """
+        context_string = (" | ").join(contexts)
         if not config.history:
-            prompt = config.template.substitute(context=context, query=input_query)
+            prompt = config.template.substitute(
+                context=context_string, query=input_query
+            )
         else:
             prompt = config.template.substitute(
-                context=context, query=input_query, history=config.history
+                context=context_string, query=input_query, history=config.history
             )
         return prompt

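`retrieve_from_database` now returns a list of page contents, and `generate_prompt` joins them with `" | "` before substituting into the template. The joining step in isolation, using only the standard library (the template text is abbreviated from the default shown in the README):

```python
from string import Template

# Abbreviated version of the default QueryConfig template.
template = Template("Use the following pieces of context: $context Query: $query Helpful Answer:")

contexts = ["First matched document.", "Second matched document."]
prompt = template.substitute(context=" | ".join(contexts), query="What changed?")
print(prompt)
# Use the following pieces of context: First matched document. | Second matched document. Query: What changed? Helpful Answer:
```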
@@ -198,8 +199,8 @@ class EmbedChain:
         """
         """
         if config is None:
         if config is None:
             config = QueryConfig()
             config = QueryConfig()
-        context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config)
+        contexts = self.retrieve_from_database(input_query, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
         answer = self.get_answer_from_llm(prompt, config)
         logging.info(f"Answer: {answer}")
         logging.info(f"Answer: {answer}")
@@ -217,16 +218,18 @@ class EmbedChain:
         configuration options.
         :return: The answer to the query.
         """
-        context = self.retrieve_from_database(input_query)
+        if config is None:
+            config = ChatConfig()
+
+        contexts = self.retrieve_from_database(input_query, config)
+
         global memory
         chat_history = memory.load_memory_variables({})["history"]

-        if config is None:
-            config = ChatConfig()
         if chat_history:
             config.set_history(chat_history)

-        prompt = self.generate_prompt(input_query, context, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)

@@ -264,8 +267,8 @@ class EmbedChain:
         """
         """
         if config is None:
         if config is None:
             config = QueryConfig()
             config = QueryConfig()
-        context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config)
+        contexts = self.retrieve_from_database(input_query, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         logging.info(f"Prompt: {prompt}")
         return prompt
         return prompt
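
Since `dry_run` builds the prompt through the same `retrieve_from_database` and `generate_prompt` path without calling the LLM, it offers a cheap way to confirm that multiple documents actually land in the prompt. A sketch under the same assumptions as above (`App` entry point and `add("web_page", ...)` data source from the embedchain README; the URL is a placeholder):

```python
from embedchain import App  # entry point shown in the embedchain README
from embedchain.config import QueryConfig  # import path assumed

app = App()
app.add("web_page", "https://example.com/docs")  # placeholder data source

# Build the prompt with two retrieved documents, without an LLM call.
config = QueryConfig(number_documents=2)
print(app.dry_run("How does retrieval work?", config=config))
```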