
Refactor query endpoint into 3 parts (#42)

Query endpoint now consists of 3 sub-functions:
- retrieve the relevant data from the database
- generate the prompt
- get the answer by passing the prompt built from the retrieved data to the LLM (sketched below)
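A minimal sketch of the resulting call pattern. The `App` entry point and the `"web_page"` source type are assumptions about embedchain's public API, not part of this diff; the three stage methods are the ones introduced below:

    from embedchain import App  # assumed entry point; not part of this diff

    app = App()
    app.add("web_page", "https://example.com/docs")  # hypothetical source

    # The three stages can now be driven one at a time...
    context = app.retrieve_from_database("How do I set this up?")
    prompt = app.generate_prompt("How do I set this up?", context)
    answer = app.get_answer_from_llm(prompt)

    # ...or all at once through the unchanged public endpoint:
    answer = app.query("How do I set this up?")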
cachho · 2 years ago · parent commit cf99dce940
1 changed file with 35 additions and 12 deletions

embedchain/embedchain.py (+35, -12)

@@ -163,21 +163,47 @@ class EmbedChain:
             top_p=1,
         )
         return response["choices"][0]["message"]["content"]
+
+    def retrieve_from_database(self, input_query):
+        """
+        Queries the vector database based on the given input query.
+        Gets the most relevant document for the query.
 
-    def get_answer_from_llm(self, query, context):
+        :param input_query: The query to use.
+        :return: The content of the document that matched your query.
         """
-        Gets an answer based on the given query and context by passing it
-        to an LLM.
+        result = self.collection.query(
+            query_texts=[input_query],
+            n_results=1,
+        )
+        result_formatted = self._format_result(result)
+        content = result_formatted[0][0].page_content
+        return content
+
+    def generate_prompt(self, input_query, context):
+        """
+        Generates a prompt based on the given query and context, ready to be passed to an LLM.
 
-        :param query: The query to use.
+        :param input_query: The query to use.
         :param context: Similar documents to the query used as context.
-        :return: The answer.
+        :return: The prompt.
         """
         prompt = f"""Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
         {context}
-        Query: {query}
+        Query: {input_query}
         Helpful Answer:
         """
+        return prompt
+
+    def get_answer_from_llm(self, prompt):
+        """
+        Gets an answer based on the given prompt by passing it
+        to an LLM.
+
+        :param prompt: The prompt to pass to the LLM, containing the
+            query and its context.
+        :return: The answer.
+        """
         answer = self.get_openai_answer(prompt)
         return answer
 
@@ -190,12 +216,9 @@ class EmbedChain:
         :param input_query: The query to use.
         :return: The answer to the query.
         """
-        result = self.collection.query(
-            query_texts=[input_query,],
-            n_results=1,
-        )
-        result_formatted = self._format_result(result)
-        answer = self.get_answer_from_llm(input_query, result_formatted[0][0].page_content)
+        context = self.retrieve_from_database(input_query)
+        prompt = self.generate_prompt(input_query, context)
+        answer = self.get_answer_from_llm(prompt)
         return answer
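One payoff of the split: prompt construction can now be unit-tested without a database or an OpenAI key. A minimal sketch, assuming pytest and the same `App` class as above; bypassing `__init__` is a test-only shortcut, not something this commit does:

    from embedchain import App  # assumed entry point; not part of this diff

    def test_generate_prompt_embeds_query_and_context():
        # Skip __init__ so no DB client or API key is needed (test-only shortcut).
        app = App.__new__(App)
        prompt = app.generate_prompt("What is X?", "X is a thing.")
        assert "X is a thing." in prompt        # context appears in the prompt
        assert "Query: What is X?" in prompt    # query appears in the prompt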