feat: Add browse-the-internet (online) functionality. (#291)

Taranjeet Singh 2 years ago
commit 81c8cc62a2
1 changed file with 24 additions and 6 deletions

embedchain/embedchain.py (+24 -6)
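
For context before the diff: the new access_search_and_get_results helper uses LangChain's DuckDuckGoSearchRun tool, which requires the duckduckgo-search package to be installed. A minimal standalone sketch of the tool this commit wires in (the query string is just an example):

    # pip install duckduckgo-search  (the package LangChain's wrapper needs)
    from langchain.tools import DuckDuckGoSearchRun

    search = DuckDuckGoSearchRun()
    # run() returns a single string of concatenated result snippets; the
    # diff below appends that string to the retrieved context as
    # "Web Search Result: ...".
    print(search.run("what is retrieval augmented generation"))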

@@ -9,8 +9,7 @@ from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 
 from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import (CODE_DOCS_PAGE_PROMPT_TEMPLATE,
-                                           DEFAULT_PROMPT)
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter
 
 gpt4all_model = None
@@ -37,6 +36,7 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.user_asks = []
         self.is_code_docs_instance = False
+        self.online = False
 
     def add(self, data_type, url, metadata=None, config: AddConfig = None):
         """
@@ -163,7 +163,10 @@ class EmbedChain:
         contents = [result[0].page_content for result in results_formatted]
         return contents
 
-    def generate_prompt(self, input_query, contexts, config: QueryConfig):
+    def _append_search_and_context(self, context, web_search_result):
+        return f"{context}\nWeb Search Result: {web_search_result}"
+
+    def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
@@ -175,6 +178,9 @@ class EmbedChain:
         :return: The prompt
         """
         context_string = (" | ").join(contexts)
+        web_search_result = kwargs.get("web_search_result", "")
+        if web_search_result:
+            context_string = self._append_search_and_context(context_string, web_search_result)
         if not config.history:
             prompt = config.template.substitute(context=context_string, query=input_query)
         else:
@@ -193,6 +199,12 @@ class EmbedChain:
 
         return self.get_llm_model_answer(prompt, config)
 
+    def access_search_and_get_results(self, input_query):
+        from langchain.tools import DuckDuckGoSearchRun
+        search = DuckDuckGoSearchRun()
+        logging.info(f"Access search to get answers for {input_query}")
+        return search.run(input_query)
+
     def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
@@ -209,8 +221,11 @@ class EmbedChain:
         if self.is_code_docs_instance:
             config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
             config.number_documents = 5
+        k = {}
+        if self.online:
+            k["web_search_result"] = self.access_search_and_get_results(input_query)
         contexts = self.retrieve_from_database(input_query, config)
-        prompt = self.generate_prompt(input_query, contexts, config)
+        prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
 
         answer = self.get_answer_from_llm(prompt, config)
@@ -245,7 +260,10 @@ class EmbedChain:
         if self.is_code_docs_instance:
             config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
             config.number_documents = 5
+        k = {}
+        if self.online:
+            k["web_search_result"] = self.access_search_and_get_results(input_query)
         contexts = self.retrieve_from_database(input_query, config)
 
         global memory
         chat_history = memory.load_memory_variables({})["history"]
@@ -253,7 +271,7 @@ class EmbedChain:
         if chat_history:
             config.set_history(chat_history)
 
-        prompt = self.generate_prompt(input_query, contexts, config)
+        prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
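
Taken together, a hedged end-to-end sketch of the new behavior (App is embedchain's public entry point and subclasses EmbedChain; toggling the online attribute directly is an assumption, since the commit initializes the flag but adds no public setter):

    from embedchain import App  # App subclasses EmbedChain

    app = App()
    app.add("web_page", "https://example.com/docs")
    # Assumption: callers flip the new flag directly; the diff only
    # initializes self.online = False and reads it in query()/chat().
    app.online = True
    # With online set, query() fetches a DuckDuckGo result and appends it
    # to the retrieved context before the prompt is built.
    print(app.query("What does the latest release add?"))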