@@ -4,7 +4,7 @@ import logging
 import os
 import sqlite3
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

 from dotenv import load_dotenv
 from langchain.docstore.document import Document
|
@@ -438,7 +438,9 @@ class EmbedChain(JSONSerializable):
             )
         ]

-    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
+    def retrieve_from_database(
+        self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
@@ -449,6 +451,8 @@ class EmbedChain(JSONSerializable):
         :type config: Optional[BaseLlmConfig], optional
         :param where: A dictionary of key-value pairs to filter the database results, defaults to None
         :type where: _type_, optional
+        :param citations: A boolean to indicate if the db should fetch the citation source, defaults to False
+        :type citations: bool
         :return: List of contents of the document that matched your query
         :rtype: List[str]
         """
@@ -478,14 +482,19 @@ class EmbedChain(JSONSerializable):
             n_results=query_config.number_documents,
             where=where,
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
+            citations=citations,
         )

-        if len(contexts) > 0 and isinstance(contexts[0], tuple):
-            contexts = list(map(lambda x: x[0], contexts))
-
         return contexts

-    def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
+    def query(
+        self,
+        input_query: str,
+        config: BaseLlmConfig = None,
+        dry_run=False,
+        where: Optional[Dict] = None,
+        **kwargs: Dict[str, Any],
+    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -501,15 +510,31 @@ class EmbedChain(JSONSerializable):
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Optional[Dict[str, str]], optional
-        :return: The answer to the query or the dry run result
-        :rtype: str
+        :param kwargs: Extra params to pass to the query function, e.g. the citations boolean
+        to return the retrieved context along with the answer
+        :type kwargs: Dict[str, Any]
+        :return: The answer to the query, together with the contexts if the citations flag is True,
+        or the dry run result
+        :rtype: str if citations is False, otherwise Tuple[str, List[Tuple[str, str, str]]]
         """
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
-        answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
+        citations = kwargs.get("citations", False)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
+        else:
+            contexts_data_for_llm_query = contexts
+
+        answer = self.llm.query(
+            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+        )

         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
-        return answer
+
+        if citations:
+            return answer, contexts
+        else:
+            return answer

     def chat(
         self,
@@ -517,6 +542,7 @@ class EmbedChain(JSONSerializable):
         config: Optional[BaseLlmConfig] = None,
         dry_run=False,
         where: Optional[Dict[str, str]] = None,
+        **kwargs: Dict[str, Any],
     ) -> str:
         """
         Queries the vector database on the given input query.
@@ -535,15 +561,31 @@ class EmbedChain(JSONSerializable):
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Optional[Dict[str, str]], optional
-        :return: The answer to the query or the dry run result
-        :rtype: str
+        :param kwargs: Extra params to pass to the chat function, e.g. the citations boolean
+        to return the retrieved context along with the answer
+        :type kwargs: Dict[str, Any]
+        :return: The answer to the query, together with the contexts if the citations flag is True,
+        or the dry run result
+        :rtype: str if citations is False, otherwise Tuple[str, List[Tuple[str, str, str]]]
         """
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
-        answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
+        citations = kwargs.get("citations", False)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
+        else:
+            contexts_data_for_llm_query = contexts
+
+        answer = self.llm.chat(
+            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+        )
+

         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)

-        return answer
+        if citations:
+            return answer, contexts
+        else:
+            return answer

     def set_collection_name(self, name: str):
         """
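
For reference, a minimal usage sketch of the new citations flag, assuming the standard embedchain App entry point and an app that already has data added; the question and source URL below are placeholders, and the fields inside each context tuple follow the Tuple[str, str, str] return type introduced above.

from embedchain import App

app = App()
app.add("https://www.forbes.com/profile/elon-musk")  # placeholder data source

# Default behaviour is unchanged: query() returns only the answer string.
answer = app.query("What is the net worth of Elon Musk?")

# With citations=True, query() returns the answer plus the retrieved contexts,
# each context being a tuple whose first element is the chunk text sent to the LLM.
answer, contexts = app.query("What is the net worth of Elon Musk?", citations=True)
for context in contexts:
    print(context[0])

The same keyword works for chat(), since both methods read citations from **kwargs and forward it to retrieve_from_database.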