|
@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
|
|
"""
|
|
"""
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
- def retrieve_from_database(self, input_query, config: QueryConfig):
|
|
|
|
|
|
+ def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
|
|
"""
|
|
"""
|
|
Queries the vector database based on the given input query.
|
|
Queries the vector database based on the given input query.
|
|
Gets relevant doc based on the query
|
|
Gets relevant doc based on the query
|
|
|
|
|
|
:param input_query: The query to use.
|
|
:param input_query: The query to use.
|
|
:param config: The query configuration.
|
|
:param config: The query configuration.
|
|
|
|
+ :param where: Optional. A dictionary of key-value pairs to filter the database results.
|
|
:return: The content of the document that matched your query.
|
|
:return: The content of the document that matched your query.
|
|
"""
|
|
"""
|
|
- where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
|
|
|
|
|
|
+
|
|
|
|
+ if where is not None:
|
|
|
|
+ where = where
|
|
|
|
+ elif config is not None and config.where is not None:
|
|
|
|
+ where = config.where
|
|
|
|
+ else:
|
|
|
|
+ where = {}
|
|
|
|
+
|
|
|
|
+ if self.config.id is not None:
|
|
|
|
+ where.update({"app_id": self.config.id})
|
|
|
|
+
|
|
contents = self.db.query(
|
|
contents = self.db.query(
|
|
input_query=input_query,
|
|
input_query=input_query,
|
|
n_results=config.number_documents,
|
|
n_results=config.number_documents,
|
|
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
|
|
logging.info(f"Access search to get answers for {input_query}")
|
|
logging.info(f"Access search to get answers for {input_query}")
|
|
return search.run(input_query)
|
|
return search.run(input_query)
|
|
|
|
|
|
- def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
|
|
|
|
|
+ def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
|
|
"""
|
|
"""
|
|
Queries the vector database based on the given input query.
|
|
Queries the vector database based on the given input query.
|
|
Gets relevant doc based on the query and then passes it to an
|
|
Gets relevant doc based on the query and then passes it to an
|
|
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
|
|
by the vector database's doc retrieval.
|
|
by the vector database's doc retrieval.
|
|
The only thing the dry run does not consider is the cut-off due to
|
|
The only thing the dry run does not consider is the cut-off due to
|
|
the `max_tokens` parameter.
|
|
the `max_tokens` parameter.
|
|
|
|
+ :param where: Optional. A dictionary of key-value pairs to filter the database results.
|
|
:return: The answer to the query.
|
|
:return: The answer to the query.
|
|
"""
|
|
"""
|
|
if config is None:
|
|
if config is None:
|
|
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
|
|
k = {}
|
|
k = {}
|
|
if self.online:
|
|
if self.online:
|
|
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
|
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
|
- contexts = self.retrieve_from_database(input_query, config)
|
|
|
|
|
|
+ contexts = self.retrieve_from_database(input_query, config, where)
|
|
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
|
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
|
logging.info(f"Prompt: {prompt}")
|
|
logging.info(f"Prompt: {prompt}")
|
|
|
|
|
|
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
|
|
yield chunk
|
|
yield chunk
|
|
logging.info(f"Answer: {streamed_answer}")
|
|
logging.info(f"Answer: {streamed_answer}")
|
|
|
|
|
|
- def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
|
|
|
|
|
+ def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
|
|
"""
|
|
"""
|
|
Queries the vector database on the given input query.
|
|
Queries the vector database on the given input query.
|
|
Gets relevant doc based on the query and then passes it to an
|
|
Gets relevant doc based on the query and then passes it to an
|
|
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
|
|
by the vector database's doc retrieval.
|
|
by the vector database's doc retrieval.
|
|
The only thing the dry run does not consider is the cut-off due to
|
|
The only thing the dry run does not consider is the cut-off due to
|
|
the `max_tokens` parameter.
|
|
the `max_tokens` parameter.
|
|
|
|
+ :param where: Optional. A dictionary of key-value pairs to filter the database results.
|
|
:return: The answer to the query.
|
|
:return: The answer to the query.
|
|
"""
|
|
"""
|
|
if config is None:
|
|
if config is None:
|
|
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
|
|
k = {}
|
|
k = {}
|
|
if self.online:
|
|
if self.online:
|
|
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
|
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
|
- contexts = self.retrieve_from_database(input_query, config)
|
|
|
|
|
|
+ contexts = self.retrieve_from_database(input_query, config, where)
|
|
|
|
|
|
chat_history = self.memory.load_memory_variables({})["history"]
|
|
chat_history = self.memory.load_memory_variables({})["history"]
|
|
|
|
|