Bläddra i källkod

[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)

Deven Patel 1 år sedan
förälder
incheckning
930280f4ce

+ 58 - 16
embedchain/embedchain.py

@@ -4,7 +4,7 @@ import logging
 import os
 import sqlite3
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
@@ -438,7 +438,9 @@ class EmbedChain(JSONSerializable):
             )
         ]
 
-    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
+    def retrieve_from_database(
+        self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
@@ -449,6 +451,8 @@ class EmbedChain(JSONSerializable):
         :type config: Optional[BaseLlmConfig], optional
         :param where: A dictionary of key-value pairs to filter the database results, defaults to None
         :type where: _type_, optional
+        :param citations: A boolean to indicate if db should fetch citation source
+        :type citations: bool
         :return: List of contents of the document that matched your query
         :rtype: List[str]
         """
@@ -478,14 +482,19 @@ class EmbedChain(JSONSerializable):
             n_results=query_config.number_documents,
             where=where,
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
+            citations=citations,
         )
 
-        if len(contexts) > 0 and isinstance(contexts[0], tuple):
-            contexts = list(map(lambda x: x[0], contexts))
-
         return contexts
 
-    def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
+    def query(
+        self,
+        input_query: str,
+        config: BaseLlmConfig = None,
+        dry_run=False,
+        where: Optional[Dict] = None,
+        **kwargs: Dict[str, Any],
+    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -501,15 +510,31 @@ class EmbedChain(JSONSerializable):
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Optional[Dict[str, str]], optional
-        :return: The answer to the query or the dry run result
-        :rtype: str
+        :param kwargs: To read more params for the query function. Ex. we use citations boolean
+        param to return context along with the answer
+        :type kwargs: Dict[str, Any]
+        :return: The answer to the query, with citations if the citation flag is True
+        or the dry run result
+        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
-        answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
+        citations = kwargs.get("citations", False)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
+        else:
+            contexts_data_for_llm_query = contexts
+
+        answer = self.llm.query(
+            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+        )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
-        return answer
+
+        if citations:
+            return answer, contexts
+        else:
+            return answer
 
     def chat(
         self,
@@ -517,6 +542,7 @@ class EmbedChain(JSONSerializable):
         config: Optional[BaseLlmConfig] = None,
         dry_run=False,
         where: Optional[Dict[str, str]] = None,
+        **kwargs: Dict[str, Any],
     ) -> str:
         """
         Queries the vector database on the given input query.
@@ -535,15 +561,31 @@ class EmbedChain(JSONSerializable):
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Optional[Dict[str, str]], optional
-        :return: The answer to the query or the dry run result
-        :rtype: str
+        :param kwargs: To read more params for the query function. Ex. we use citations boolean
+        param to return context along with the answer
+        :type kwargs: Dict[str, Any]
+        :return: The answer to the query, with citations if the citation flag is True
+        or the dry run result
+        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
-        answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
+        citations = kwargs.get("citations", False)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
+        else:
+            contexts_data_for_llm_query = contexts
+
+        answer = self.llm.chat(
+            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+        )
+
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
-        return answer
+        if citations:
+            return answer, contexts
+        else:
+            return answer
 
     def set_collection_name(self, name: str):
         """

+ 1 - 0
embedchain/pipeline.py

@@ -234,6 +234,7 @@ class Pipeline(EmbedChain):
                 n_results=num_documents,
                 where=where,
                 skip_embedding=False,
+                citations=True,
             )
             result = []
             for c in context:

+ 20 - 9
embedchain/vectordb/chroma.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -192,8 +192,13 @@ class ChromaDB(BaseVectorDB):
         ]
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Query contents from vector database based on vector similarity
 
@@ -205,9 +210,12 @@ class ChromaDB(BaseVectorDB):
         :type where: Dict[str, Any]
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
         :raises InvalidDimensionException: Dimensions do not match.
-        :return: The content of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         try:
             if skip_embedding:
@@ -236,10 +244,13 @@ class ChromaDB(BaseVectorDB):
         contexts = []
         for result in results_formatted:
             context = result[0].page_content
-            metadata = result[0].metadata
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            contexts.append((context, source, doc_id))
+            if citations:
+                metadata = result[0].metadata
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append((context, source, doc_id))
+            else:
+                contexts.append(context)
         return contexts
 
     def set_collection_name(self, name: str):

+ 22 - 11
embedchain/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 try:
     from elasticsearch import Elasticsearch
@@ -136,8 +136,13 @@ class ElasticsearchDB(BaseVectorDB):
         self.client.indices.refresh(index=self._get_index())
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector data base based on vector similarity
 
@@ -150,8 +155,11 @@ class ElasticsearchDB(BaseVectorDB):
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
         :return: The context of the document that matched your query, url of the source, doc_id
-
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         if skip_embedding:
             query_vector = input_query
@@ -175,14 +183,17 @@ class ElasticsearchDB(BaseVectorDB):
         _source = ["text", "metadata.url", "metadata.doc_id"]
         response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
         docs = response["hits"]["hits"]
-        contents = []
+        contexts = []
         for doc in docs:
             context = doc["_source"]["text"]
-            metadata = doc["_source"]["metadata"]
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            contents.append(tuple((context, source, doc_id)))
-        return contents
+            if citations:
+                metadata = doc["_source"]["metadata"]
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
     def set_collection_name(self, name: str):
         """

+ 21 - 10
embedchain/vectordb/opensearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 try:
     from opensearchpy import OpenSearch
@@ -146,8 +146,13 @@ class OpenSearchDB(BaseVectorDB):
         self.client.indices.refresh(index=self._get_index())
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector data base based on vector similarity
 
@@ -159,8 +164,11 @@ class OpenSearchDB(BaseVectorDB):
         :type where: Dict[str, any]
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
-        :return: The content of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
         embeddings = OpenAIEmbeddings()
@@ -188,13 +196,16 @@ class OpenSearchDB(BaseVectorDB):
             k=n_results,
         )
 
-        contents = []
+        contexts = []
         for doc in docs:
             context = doc.page_content
-            source = doc.metadata["url"]
-            doc_id = doc.metadata["doc_id"]
-            contents.append(tuple((context, source, doc_id)))
-        return contents
+            if citations:
+                source = doc.metadata["url"]
+                doc_id = doc.metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
     def set_collection_name(self, name: str):
         """

+ 21 - 10
embedchain/vectordb/pinecone.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 try:
     import pinecone
@@ -119,8 +119,13 @@ class PineconeDB(BaseVectorDB):
             self.client.upsert(docs[i : i + self.BATCH_SIZE])
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -131,22 +136,28 @@ class PineconeDB(BaseVectorDB):
         :type where: Dict[str, any]
         :param skip_embedding: Optional. if True, input_query is already embedded
         :type skip_embedding: bool
-        :return: The content of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
             query_vector = input_query
         data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
-        contents = []
+        contexts = []
         for doc in data["matches"]:
             metadata = doc["metadata"]
             context = metadata["text"]
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            contents.append(tuple((context, source, doc_id)))
-        return contents
+            if citations:
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
     def set_collection_name(self, name: str):
         """

+ 22 - 11
embedchain/vectordb/qdrant.py

@@ -1,7 +1,7 @@
 import copy
 import os
 import uuid
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 try:
     from qdrant_client import QdrantClient
@@ -161,8 +161,13 @@ class QdrantDB(BaseVectorDB):
             )
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -174,8 +179,11 @@ class QdrantDB(BaseVectorDB):
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         generated or not
         :type skip_embedding: bool
-        :return: The context of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
@@ -202,14 +210,17 @@ class QdrantDB(BaseVectorDB):
             limit=n_results,
         )
 
-        response = []
+        contexts = []
         for result in results:
             context = result.payload["text"]
-            metadata = result.payload["metadata"]
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            response.append(tuple((context, source, doc_id)))
-        return response
+            if citations:
+                metadata = result.payload["metadata"]
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
     def count(self) -> int:
         response = self.client.get_collection(collection_name=self.collection_name)

+ 35 - 11
embedchain/vectordb/weaviate.py

@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 try:
     import weaviate
@@ -58,10 +58,14 @@ class WeaviateDB(BaseVectorDB):
             raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
 
         self.index_name = self._get_index_name()
-        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
+        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
         if not self.client.schema.exists(self.index_name):
             # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
             # The none vectorizer is crucial as we have our own custom embedding function
+            """
+            TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
+            Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
+            """
             class_obj = {
                 "classes": [
                     {
@@ -106,10 +110,6 @@ class WeaviateDB(BaseVectorDB):
                                 "name": "app_id",
                                 "dataType": ["text"],
                             },
-                            {
-                                "name": "text",
-                                "dataType": ["text"],
-                            },
                         ],
                     },
                 ]
@@ -195,8 +195,13 @@ class WeaviateDB(BaseVectorDB):
                 batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -208,15 +213,23 @@ class WeaviateDB(BaseVectorDB):
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         generated or not
         :type skip_embedding: bool
-        :return: The context of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
             query_vector = input_query
+
         keys = set(where.keys() if where is not None else set())
         data_fields = ["text"]
+
+        if citations:
+            data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
+
         if len(keys.intersection(self.metadata_keys)) != 0:
             weaviate_where_operands = []
             for key in keys:
@@ -247,7 +260,18 @@ class WeaviateDB(BaseVectorDB):
                 .with_limit(n_results)
                 .do()
             )
-        contexts = results["data"]["Get"].get(self.index_name)
+
+        docs = results["data"]["Get"].get(self.index_name)
+        contexts = []
+        for doc in docs:
+            context = doc["text"]
+            if citations:
+                metadata = doc["metadata"][0]
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append((context, source, doc_id))
+            else:
+                contexts.append(context)
         return contexts
 
     def set_collection_name(self, name: str):

+ 21 - 10
embedchain/vectordb/zilliz.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 from embedchain.config import ZillizDBConfig
 from embedchain.helper.json_serializable import register_deserializable
@@ -127,8 +127,13 @@ class ZillizVectorDB(BaseVectorDB):
         self.client.flush(self.config.collection_name)
 
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -139,8 +144,11 @@ class ZillizVectorDB(BaseVectorDB):
         :param where: to filter data
         :type where: str
         :raises InvalidDimensionException: Dimensions do not match.
-        :return: The context of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
 
         if self.collection.is_empty:
@@ -170,14 +178,17 @@ class ZillizVectorDB(BaseVectorDB):
                 output_fields=output_fields,
             )
 
-        doc_list = []
+        contexts = []
         for query in query_result:
             data = query[0]["entity"]
             context = data["text"]
-            source = data["url"]
-            doc_id = data["doc_id"]
-            doc_list.append(tuple((context, source, doc_id)))
-        return doc_list
+            if citations:
+                source = data["url"]
+                doc_id = data["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
     def count(self) -> int:
         """

+ 1 - 1
poetry.lock

@@ -7141,4 +7141,4 @@ whatsapp = ["flask", "twilio"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "0b83ba3fd2485b3b4aa3c6a7534b214378d349538f7eb63c65768aafecdfad60"
+content-hash = "0b83ba3fd2485b3b4aa3c6a7534b214378d349538f7eb63c65768aafecdfad60"

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.0.88"
+version = "0.0.89"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 16 - 6
tests/vectordb/test_chroma_db.py

@@ -163,10 +163,12 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
 
     assert data == expected_value
 
-    data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
-    expected_value = [("document", "url_1", "doc_id_1")]
+    data_without_citations = app_with_settings.db.query(
+        input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
+    )
+    expected_value_without_citations = ["document"]
+    assert data_without_citations == expected_value_without_citations
 
-    assert data == expected_value
     app_with_settings.db.reset()
 
 
@@ -326,8 +328,16 @@ def test_chroma_db_collection_query(app_with_settings):
 
     assert app_with_settings.db.count() == 2
 
-    data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
-    expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
+    data_without_citations = app_with_settings.db.query(
+        input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
+    )
+    expected_value_without_citations = ["document", "document2"]
+    assert data_without_citations == expected_value_without_citations
+
+    data_with_citations = app_with_settings.db.query(
+        input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
+    )
+    expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
+    assert data_with_citations == expected_value_with_citations
 
-    assert data == expected_value
     app_with_settings.db.reset()

+ 10 - 8
tests/vectordb/test_elasticsearch_db.py

@@ -60,12 +60,16 @@ class TestEsDB(unittest.TestCase):
 
         # Query the database for the documents that are most similar to the query "This is a document".
         query = ["This is a document"]
-        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+        results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+        expected_results_without_citations = ["This is a document.", "This is another document."]
+        self.assertEqual(results_without_citations, expected_results_without_citations)
 
-        # Assert that the results are correct.
-        self.assertEqual(
-            results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
-        )
+        results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
+        expected_results_with_citations = [
+            ("This is a document.", "url_1", "doc_id_1"),
+            ("This is another document.", "url_2", "doc_id_2"),
+        ]
+        self.assertEqual(results_with_citations, expected_results_with_citations)
 
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     def test_query_with_skip_embedding(self, mock_client):
@@ -111,9 +115,7 @@ class TestEsDB(unittest.TestCase):
         results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
 
         # Assert that the results are correct.
-        self.assertEqual(
-            results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
-        )
+        self.assertEqual(results, ["This is a document.", "This is another document."])
 
     def test_init_without_url(self):
         # Make sure it's not loaded from env

+ 0 - 4
tests/vectordb/test_weaviate.py

@@ -75,10 +75,6 @@ class TestWeaviateDb(unittest.TestCase):
                             "name": "app_id",
                             "dataType": ["text"],
                         },
-                        {
-                            "name": "text",
-                            "dataType": ["text"],
-                        },
                     ],
                 },
             ]

+ 30 - 4
tests/vectordb/test_zilliz_db.py

@@ -129,7 +129,7 @@ class TestZillizDBCollection:
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
 
             # Assert that MilvusClient.search was called with the correct parameters
-            mock_search.assert_called_once_with(
+            mock_search.assert_called_with(
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 limit=1,
@@ -137,7 +137,20 @@ class TestZillizDBCollection:
             )
 
             # Assert that the query result matches the expected result
-            assert query_result == [("result_doc", "url_1", "doc_id_1")]
+            assert query_result == ["result_doc"]
+
+            query_result_with_citations = zilliz_db.query(
+                input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
+            )
+
+            mock_search.assert_called_with(
+                collection_name=mock_config.collection_name,
+                data=["query_text"],
+                limit=1,
+                output_fields=["text", "url", "doc_id"],
+            )
+
+            assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
 
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
     @patch("embedchain.vectordb.zilliz.connections", autospec=True)
@@ -168,7 +181,7 @@ class TestZillizDBCollection:
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
 
             # Assert that MilvusClient.search was called with the correct parameters
-            mock_search.assert_called_once_with(
+            mock_search.assert_called_with(
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 limit=1,
@@ -176,4 +189,17 @@ class TestZillizDBCollection:
             )
 
             # Assert that the query result matches the expected result
-            assert query_result == [("result_doc", "url_1", "doc_id_1")]
+            assert query_result == ["result_doc"]
+
+            query_result_with_citations = zilliz_db.query(
+                input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
+            )
+
+            mock_search.assert_called_with(
+                collection_name=mock_config.collection_name,
+                data=["query_vector"],
+                limit=1,
+                output_fields=["text", "url", "doc_id"],
+            )
+
+            assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]