瀏覽代碼

[Bugfix] fix return type of ec chat (#995)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父節點
當前提交
e84b5034ea
共有 2 個文件被更改,包括 5 次插入8 次删除
  1. 4 7
      embedchain/embedchain.py
  2. 1 1
      pyproject.toml

+ 4 - 7
embedchain/embedchain.py

@@ -565,7 +565,7 @@ class EmbedChain(JSONSerializable):
         dry_run=False,
         dry_run=False,
         where: Optional[Dict[str, str]] = None,
         where: Optional[Dict[str, str]] = None,
         **kwargs: Dict[str, Any],
         **kwargs: Dict[str, Any],
-    ) -> str:
+    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
         """
         """
         Queries the vector database on the given input query.
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         Gets relevant doc based on the query and then passes it to an
@@ -590,13 +590,10 @@ class EmbedChain(JSONSerializable):
         or the dry run result
         or the dry run result
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
         """
-        if "citations" in kwargs:
-            citations = kwargs.pop("citations")
-        else:
-            citations = False
-
+        citations = kwargs.get("citations", False)
+        db_kwargs = {key: value for key, value in kwargs.items() if key != "citations"}
         contexts = self._retrieve_from_database(
         contexts = self._retrieve_from_database(
-            input_query=input_query, config=config, where=where, citations=citations, **kwargs
+            input_query=input_query, config=config, where=where, citations=citations, **db_kwargs
         )
         )
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 [tool.poetry]
 name = "embedchain"
 name = "embedchain"
-version = "0.1.27"
+version = "0.1.28"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
     "Taranjeet Singh <taranjeet@embedchain.ai>",