|
@@ -133,43 +133,44 @@ class EmbedChain:
|
|
|
def get_llm_model_answer(self, prompt):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
- def retrieve_from_database(self, input_query):
|
|
|
+ def retrieve_from_database(self, input_query, config: QueryConfig):
|
|
|
"""
|
|
|
Queries the vector database based on the given input query.
|
|
|
Gets relevant doc based on the query
|
|
|
|
|
|
:param input_query: The query to use.
|
|
|
+ :param config: The query configuration.
|
|
|
:return: The content of the document that matched your query.
|
|
|
"""
|
|
|
result = self.collection.query(
|
|
|
query_texts=[
|
|
|
input_query,
|
|
|
],
|
|
|
- n_results=1,
|
|
|
+ n_results=config.number_documents,
|
|
|
)
|
|
|
- result_formatted = self._format_result(result)
|
|
|
- if result_formatted:
|
|
|
- content = result_formatted[0][0].page_content
|
|
|
- else:
|
|
|
- content = ""
|
|
|
- return content
|
|
|
+ results_formatted = self._format_result(result)
|
|
|
+ contents = [result[0].page_content for result in results_formatted]
|
|
|
+ return contents
|
|
|
|
|
|
- def generate_prompt(self, input_query, context, config: QueryConfig):
|
|
|
+ def generate_prompt(self, input_query, contexts, config: QueryConfig):
|
|
|
"""
|
|
|
Generates a prompt based on the given query and context, ready to be
|
|
|
passed to an LLM
|
|
|
|
|
|
:param input_query: The query to use.
|
|
|
- :param context: Similar documents to the query used as context.
|
|
|
+ :param contexts: List of similar documents to the query used as context.
|
|
|
:param config: Optional. The `QueryConfig` instance to use as
|
|
|
configuration options.
|
|
|
:return: The prompt
|
|
|
"""
|
|
|
+ context_string = (" | ").join(contexts)
|
|
|
if not config.history:
|
|
|
- prompt = config.template.substitute(context=context, query=input_query)
|
|
|
+ prompt = config.template.substitute(
|
|
|
+ context=context_string, query=input_query
|
|
|
+ )
|
|
|
else:
|
|
|
prompt = config.template.substitute(
|
|
|
- context=context, query=input_query, history=config.history
|
|
|
+ context=context_string, query=input_query, history=config.history
|
|
|
)
|
|
|
return prompt
|
|
|
|
|
@@ -198,8 +199,8 @@ class EmbedChain:
|
|
|
"""
|
|
|
if config is None:
|
|
|
config = QueryConfig()
|
|
|
- context = self.retrieve_from_database(input_query)
|
|
|
- prompt = self.generate_prompt(input_query, context, config)
|
|
|
+ contexts = self.retrieve_from_database(input_query, config)
|
|
|
+ prompt = self.generate_prompt(input_query, contexts, config)
|
|
|
logging.info(f"Prompt: {prompt}")
|
|
|
answer = self.get_answer_from_llm(prompt, config)
|
|
|
logging.info(f"Answer: {answer}")
|
|
@@ -217,16 +218,18 @@ class EmbedChain:
|
|
|
configuration options.
|
|
|
:return: The answer to the query.
|
|
|
"""
|
|
|
- context = self.retrieve_from_database(input_query)
|
|
|
+ if config is None:
|
|
|
+ config = ChatConfig()
|
|
|
+
|
|
|
+ contexts = self.retrieve_from_database(input_query, config)
|
|
|
+
|
|
|
global memory
|
|
|
chat_history = memory.load_memory_variables({})["history"]
|
|
|
|
|
|
- if config is None:
|
|
|
- config = ChatConfig()
|
|
|
if chat_history:
|
|
|
config.set_history(chat_history)
|
|
|
|
|
|
- prompt = self.generate_prompt(input_query, context, config)
|
|
|
+ prompt = self.generate_prompt(input_query, contexts, config)
|
|
|
logging.info(f"Prompt: {prompt}")
|
|
|
answer = self.get_answer_from_llm(prompt, config)
|
|
|
|
|
@@ -264,8 +267,8 @@ class EmbedChain:
|
|
|
"""
|
|
|
if config is None:
|
|
|
config = QueryConfig()
|
|
|
- context = self.retrieve_from_database(input_query)
|
|
|
- prompt = self.generate_prompt(input_query, context, config)
|
|
|
+ contexts = self.retrieve_from_database(input_query, config)
|
|
|
+ prompt = self.generate_prompt(input_query, contexts, config)
|
|
|
logging.info(f"Prompt: {prompt}")
|
|
|
return prompt
|
|
|
|