@@ -9,8 +9,7 @@ from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory

 from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import (CODE_DOCS_PAGE_PROMPT_TEMPLATE,
-                                           DEFAULT_PROMPT)
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter

 gpt4all_model = None
@@ -37,6 +36,7 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.user_asks = []
         self.is_code_docs_instance = False
+        self.online = False

     def add(self, data_type, url, metadata=None, config: AddConfig = None):
         """
@@ -163,7 +163,10 @@ class EmbedChain:
         contents = [result[0].page_content for result in results_formatted]
         return contents

-    def generate_prompt(self, input_query, contexts, config: QueryConfig):
+    def _append_search_and_context(self, context, web_search_result):
+        return f"{context}\nWeb Search Result: {web_search_result}"
+
+    def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
@@ -175,6 +178,9 @@ class EmbedChain:
         :return: The prompt
         """
         context_string = (" | ").join(contexts)
+        web_search_result = kwargs.get("web_search_result", "")
+        if web_search_result:
+            context_string = self._append_search_and_context(context_string, web_search_result)
         if not config.history:
             prompt = config.template.substitute(context=context_string, query=input_query)
         else:
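
The context assembly is easy to check in isolation. Below is a minimal, self-contained sketch of the same logic; the function and variable names are local stand-ins for illustration, not embedchain APIs:

```python
# Standalone sketch of the context assembly in generate_prompt above.
def append_search_and_context(context: str, web_search_result: str) -> str:
    # Mirrors EmbedChain._append_search_and_context from the diff.
    return f"{context}\nWeb Search Result: {web_search_result}"

contexts = ["chunk about installation", "chunk about configuration"]
context_string = " | ".join(contexts)

web_search_result = "DuckDuckGo snippet for the query"  # "" when online is off
if web_search_result:
    context_string = append_search_and_context(context_string, web_search_result)

print(context_string)
# chunk about installation | chunk about configuration
# Web Search Result: DuckDuckGo snippet for the query
```
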
@@ -193,6 +199,12 @@ class EmbedChain:

         return self.get_llm_model_answer(prompt, config)

+    def access_search_and_get_results(self, input_query):
+        from langchain.tools import DuckDuckGoSearchRun
+        search = DuckDuckGoSearchRun()
+        logging.info(f"Access search to get answers for {input_query}")
+        return search.run(input_query)
+
     def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
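
Note that `DuckDuckGoSearchRun` is imported lazily inside the method, so the extra dependency is only paid when online mode is actually used. A quick way to sanity-check the tool on its own (this assumes the `duckduckgo-search` package is installed, which the langchain tool needs; DuckDuckGo itself requires no API key):

```python
# Trying the search tool in isolation; assumes `langchain` and the
# `duckduckgo-search` package are installed. No API key is needed.
from langchain.tools import DuckDuckGoSearchRun

search = DuckDuckGoSearchRun()
result = search.run("embedchain open source framework")
print(result[:300])  # a single string of concatenated result snippets
```
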
@@ -209,8 +221,11 @@ class EmbedChain:
         if self.is_code_docs_instance:
             config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
             config.number_documents = 5
+        k = {}
+        if self.online:
+            k["web_search_result"] = self.access_search_and_get_results(input_query)
         contexts = self.retrieve_from_database(input_query, config)
-        prompt = self.generate_prompt(input_query, contexts, config)
+        prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")

         answer = self.get_answer_from_llm(prompt, config)
@@ -245,7 +260,10 @@ class EmbedChain:
         if self.is_code_docs_instance:
             config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
             config.number_documents = 5
+        k = {}
+        if self.online:
+            k["web_search_result"] = self.access_search_and_get_results(input_query)
         contexts = self.retrieve_from_database(input_query, config)

         global memory
         chat_history = memory.load_memory_variables({})["history"]
@@ -253,7 +271,7 @@ class EmbedChain:
         if chat_history:
             config.set_history(chat_history)

-        prompt = self.generate_prompt(input_query, contexts, config)
+        prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
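
Putting the pieces together: the feature is opt-in, since `self.online` defaults to `False` in `__init__`, so both `query` and `chat` behave exactly as before unless the flag is flipped. A hedged end-to-end sketch, assuming the usual `App` entry point and an illustrative source URL; setting the attribute directly is one way in, though the PR may expose a config-level switch elsewhere:

```python
# Usage sketch: `online` defaults to False, so web search is opt-in.
# Assumes the standard embedchain App entry point (and its usual LLM
# credentials); the URL below is illustrative.
from embedchain import App

app = App()
app.add("web_page", "https://docs.embedchain.ai")
app.online = True  # query()/chat() will now also run a DuckDuckGo search
print(app.query("What data types can embedchain ingest?"))
```
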