Browse Source

[Bugfix] fix return type of ec chat (#995)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 2 years ago
parent
commit
e84b5034ea
2 changed files with 5 additions and 8 deletions
  1. 4 7
      embedchain/embedchain.py
  2. 1 1
      pyproject.toml

+ 4 - 7
embedchain/embedchain.py

@@ -565,7 +565,7 @@ class EmbedChain(JSONSerializable):
         dry_run=False,
         dry_run=False,
         where: Optional[Dict[str, str]] = None,
         where: Optional[Dict[str, str]] = None,
         **kwargs: Dict[str, Any],
         **kwargs: Dict[str, Any],
-    ) -> str:
+    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
         """
         """
         Queries the vector database on the given input query.
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         Gets relevant doc based on the query and then passes it to an
@@ -590,13 +590,10 @@ class EmbedChain(JSONSerializable):
         or the dry run result
         or the dry run result
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
         """
-        if "citations" in kwargs:
-            citations = kwargs.pop("citations")
-        else:
-            citations = False
-
+        citations = kwargs.get("citations", False)
+        db_kwargs = {key: value for key, value in kwargs.items() if key != "citations"}
         contexts = self._retrieve_from_database(
         contexts = self._retrieve_from_database(
-            input_query=input_query, config=config, where=where, citations=citations, **kwargs
+            input_query=input_query, config=config, where=where, citations=citations, **db_kwargs
         )
         )
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 [tool.poetry]
 name = "embedchain"
 name = "embedchain"
-version = "0.1.27"
+version = "0.1.28"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
     "Taranjeet Singh <taranjeet@embedchain.ai>",