Pārlūkot izejas kodu

[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)

Deven Patel 1 gadu atpakaļ
vecāks
revīzija
930280f4ce

+ 58 - 16
embedchain/embedchain.py

@@ -4,7 +4,7 @@ import logging
 import os
 import os
 import sqlite3
 import sqlite3
 from pathlib import Path
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 
 from dotenv import load_dotenv
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.docstore.document import Document
@@ -438,7 +438,9 @@ class EmbedChain(JSONSerializable):
             )
             )
         ]
         ]
 
 
-    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
+    def retrieve_from_database(
+        self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         Queries the vector database based on the given input query.
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
         Gets relevant doc based on the query
@@ -449,6 +451,8 @@ class EmbedChain(JSONSerializable):
         :type config: Optional[BaseLlmConfig], optional
         :type config: Optional[BaseLlmConfig], optional
         :param where: A dictionary of key-value pairs to filter the database results, defaults to None
         :param where: A dictionary of key-value pairs to filter the database results, defaults to None
         :type where: _type_, optional
         :type where: _type_, optional
+        :param citations: A boolean to indicate if db should fetch citation source
+        :type citations: bool
         :return: List of contents of the document that matched your query
         :return: List of contents of the document that matched your query
         :rtype: List[str]
         :rtype: List[str]
         """
         """
@@ -478,14 +482,19 @@ class EmbedChain(JSONSerializable):
             n_results=query_config.number_documents,
             n_results=query_config.number_documents,
             where=where,
             where=where,
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
+            citations=citations,
         )
         )
 
 
-        if len(contexts) > 0 and isinstance(contexts[0], tuple):
-            contexts = list(map(lambda x: x[0], contexts))
-
         return contexts
         return contexts
 
 
-    def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
+    def query(
+        self,
+        input_query: str,
+        config: BaseLlmConfig = None,
+        dry_run=False,
+        where: Optional[Dict] = None,
+        **kwargs: Dict[str, Any],
+    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
         """
         """
         Queries the vector database based on the given input query.
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         Gets relevant doc based on the query and then passes it to an
@@ -501,15 +510,31 @@ class EmbedChain(JSONSerializable):
         :type dry_run: bool, optional
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Optional[Dict[str, str]], optional
         :type where: Optional[Dict[str, str]], optional
-        :return: The answer to the query or the dry run result
-        :rtype: str
+        :param kwargs: To read more params for the query function. Ex. we use citations boolean
+        param to return context along with the answer
+        :type kwargs: Dict[str, Any]
+        :return: The answer to the query, with citations if the citation flag is True
+        or the dry run result
+        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
         """
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
-        answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
+        citations = kwargs.get("citations", False)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
+        else:
+            contexts_data_for_llm_query = contexts
+
+        answer = self.llm.query(
+            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+        )
 
 
         # Send anonymous telemetry
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
-        return answer
+
+        if citations:
+            return answer, contexts
+        else:
+            return answer
 
 
     def chat(
     def chat(
         self,
         self,
@@ -517,6 +542,7 @@ class EmbedChain(JSONSerializable):
         config: Optional[BaseLlmConfig] = None,
         config: Optional[BaseLlmConfig] = None,
         dry_run=False,
         dry_run=False,
         where: Optional[Dict[str, str]] = None,
         where: Optional[Dict[str, str]] = None,
+        **kwargs: Dict[str, Any],
     ) -> str:
     ) -> str:
         """
         """
         Queries the vector database on the given input query.
         Queries the vector database on the given input query.
@@ -535,15 +561,31 @@ class EmbedChain(JSONSerializable):
         :type dry_run: bool, optional
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Optional[Dict[str, str]], optional
         :type where: Optional[Dict[str, str]], optional
-        :return: The answer to the query or the dry run result
-        :rtype: str
+        :param kwargs: To read more params for the query function. Ex. we use citations boolean
+        param to return context along with the answer
+        :type kwargs: Dict[str, Any]
+        :return: The answer to the query, with citations if the citation flag is True
+        or the dry run result
+        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
         """
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
-        answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
+        citations = kwargs.get("citations", False)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
+        else:
+            contexts_data_for_llm_query = contexts
+
+        answer = self.llm.chat(
+            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+        )
+
         # Send anonymous telemetry
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
 
-        return answer
+        if citations:
+            return answer, contexts
+        else:
+            return answer
 
 
     def set_collection_name(self, name: str):
     def set_collection_name(self, name: str):
         """
         """

+ 1 - 0
embedchain/pipeline.py

@@ -234,6 +234,7 @@ class Pipeline(EmbedChain):
                 n_results=num_documents,
                 n_results=num_documents,
                 where=where,
                 where=where,
                 skip_embedding=False,
                 skip_embedding=False,
+                citations=True,
             )
             )
             result = []
             result = []
             for c in context:
             for c in context:

+ 20 - 9
embedchain/vectordb/chroma.py

@@ -1,5 +1,5 @@
 import logging
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 
 from chromadb import Collection, QueryResult
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
 from langchain.docstore.document import Document
@@ -192,8 +192,13 @@ class ChromaDB(BaseVectorDB):
         ]
         ]
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         Query contents from vector database based on vector similarity
         Query contents from vector database based on vector similarity
 
 
@@ -205,9 +210,12 @@ class ChromaDB(BaseVectorDB):
         :type where: Dict[str, Any]
         :type where: Dict[str, Any]
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
         :type skip_embedding: bool
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
         :raises InvalidDimensionException: Dimensions do not match.
         :raises InvalidDimensionException: Dimensions do not match.
-        :return: The content of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
         try:
         try:
             if skip_embedding:
             if skip_embedding:
@@ -236,10 +244,13 @@ class ChromaDB(BaseVectorDB):
         contexts = []
         contexts = []
         for result in results_formatted:
         for result in results_formatted:
             context = result[0].page_content
             context = result[0].page_content
-            metadata = result[0].metadata
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            contexts.append((context, source, doc_id))
+            if citations:
+                metadata = result[0].metadata
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append((context, source, doc_id))
+            else:
+                contexts.append(context)
         return contexts
         return contexts
 
 
     def set_collection_name(self, name: str):
     def set_collection_name(self, name: str):

+ 22 - 11
embedchain/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import logging
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 
 try:
 try:
     from elasticsearch import Elasticsearch
     from elasticsearch import Elasticsearch
@@ -136,8 +136,13 @@ class ElasticsearchDB(BaseVectorDB):
         self.client.indices.refresh(index=self._get_index())
         self.client.indices.refresh(index=self._get_index())
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         query contents from vector data base based on vector similarity
         query contents from vector data base based on vector similarity
 
 
@@ -150,8 +155,11 @@ class ElasticsearchDB(BaseVectorDB):
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
         :type skip_embedding: bool
         :return: The context of the document that matched your query, url of the source, doc_id
         :return: The context of the document that matched your query, url of the source, doc_id
-
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
         if skip_embedding:
         if skip_embedding:
             query_vector = input_query
             query_vector = input_query
@@ -175,14 +183,17 @@ class ElasticsearchDB(BaseVectorDB):
         _source = ["text", "metadata.url", "metadata.doc_id"]
         _source = ["text", "metadata.url", "metadata.doc_id"]
         response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
         response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
         docs = response["hits"]["hits"]
         docs = response["hits"]["hits"]
-        contents = []
+        contexts = []
         for doc in docs:
         for doc in docs:
             context = doc["_source"]["text"]
             context = doc["_source"]["text"]
-            metadata = doc["_source"]["metadata"]
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            contents.append(tuple((context, source, doc_id)))
-        return contents
+            if citations:
+                metadata = doc["_source"]["metadata"]
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
 
     def set_collection_name(self, name: str):
     def set_collection_name(self, name: str):
         """
         """

+ 21 - 10
embedchain/vectordb/opensearch.py

@@ -1,5 +1,5 @@
 import logging
 import logging
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 
 try:
 try:
     from opensearchpy import OpenSearch
     from opensearchpy import OpenSearch
@@ -146,8 +146,13 @@ class OpenSearchDB(BaseVectorDB):
         self.client.indices.refresh(index=self._get_index())
         self.client.indices.refresh(index=self._get_index())
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         query contents from vector data base based on vector similarity
         query contents from vector data base based on vector similarity
 
 
@@ -159,8 +164,11 @@ class OpenSearchDB(BaseVectorDB):
         :type where: Dict[str, any]
         :type where: Dict[str, any]
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
         :type skip_embedding: bool
-        :return: The content of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
         # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
         # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
         embeddings = OpenAIEmbeddings()
         embeddings = OpenAIEmbeddings()
@@ -188,13 +196,16 @@ class OpenSearchDB(BaseVectorDB):
             k=n_results,
             k=n_results,
         )
         )
 
 
-        contents = []
+        contexts = []
         for doc in docs:
         for doc in docs:
             context = doc.page_content
             context = doc.page_content
-            source = doc.metadata["url"]
-            doc_id = doc.metadata["doc_id"]
-            contents.append(tuple((context, source, doc_id)))
-        return contents
+            if citations:
+                source = doc.metadata["url"]
+                doc_id = doc.metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
 
     def set_collection_name(self, name: str):
     def set_collection_name(self, name: str):
         """
         """

+ 21 - 10
embedchain/vectordb/pinecone.py

@@ -1,5 +1,5 @@
 import os
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 
 try:
 try:
     import pinecone
     import pinecone
@@ -119,8 +119,13 @@ class PineconeDB(BaseVectorDB):
             self.client.upsert(docs[i : i + self.BATCH_SIZE])
             self.client.upsert(docs[i : i + self.BATCH_SIZE])
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :param input_query: list of query string
@@ -131,22 +136,28 @@ class PineconeDB(BaseVectorDB):
         :type where: Dict[str, any]
         :type where: Dict[str, any]
         :param skip_embedding: Optional. if True, input_query is already embedded
         :param skip_embedding: Optional. if True, input_query is already embedded
         :type skip_embedding: bool
         :type skip_embedding: bool
-        :return: The content of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
         if not skip_embedding:
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
         else:
             query_vector = input_query
             query_vector = input_query
         data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
         data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
-        contents = []
+        contexts = []
         for doc in data["matches"]:
         for doc in data["matches"]:
             metadata = doc["metadata"]
             metadata = doc["metadata"]
             context = metadata["text"]
             context = metadata["text"]
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            contents.append(tuple((context, source, doc_id)))
-        return contents
+            if citations:
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
 
     def set_collection_name(self, name: str):
     def set_collection_name(self, name: str):
         """
         """

+ 22 - 11
embedchain/vectordb/qdrant.py

@@ -1,7 +1,7 @@
 import copy
 import copy
 import os
 import os
 import uuid
 import uuid
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 
 try:
 try:
     from qdrant_client import QdrantClient
     from qdrant_client import QdrantClient
@@ -161,8 +161,13 @@ class QdrantDB(BaseVectorDB):
             )
             )
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :param input_query: list of query string
@@ -174,8 +179,11 @@ class QdrantDB(BaseVectorDB):
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         generated or not
         generated or not
         :type skip_embedding: bool
         :type skip_embedding: bool
-        :return: The context of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
         if not skip_embedding:
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
             query_vector = self.embedder.embedding_fn([input_query])[0]
@@ -202,14 +210,17 @@ class QdrantDB(BaseVectorDB):
             limit=n_results,
             limit=n_results,
         )
         )
 
 
-        response = []
+        contexts = []
         for result in results:
         for result in results:
             context = result.payload["text"]
             context = result.payload["text"]
-            metadata = result.payload["metadata"]
-            source = metadata["url"]
-            doc_id = metadata["doc_id"]
-            response.append(tuple((context, source, doc_id)))
-        return response
+            if citations:
+                metadata = result.payload["metadata"]
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
 
     def count(self) -> int:
     def count(self) -> int:
         response = self.client.get_collection(collection_name=self.collection_name)
         response = self.client.get_collection(collection_name=self.collection_name)

+ 35 - 11
embedchain/vectordb/weaviate.py

@@ -1,6 +1,6 @@
 import copy
 import copy
 import os
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 
 try:
 try:
     import weaviate
     import weaviate
@@ -58,10 +58,14 @@ class WeaviateDB(BaseVectorDB):
             raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
             raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
 
 
         self.index_name = self._get_index_name()
         self.index_name = self._get_index_name()
-        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
+        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
         if not self.client.schema.exists(self.index_name):
         if not self.client.schema.exists(self.index_name):
             # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
             # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
             # The none vectorizer is crucial as we have our own custom embedding function
             # The none vectorizer is crucial as we have our own custom embedding function
+            """
+            TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
+            Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
+            """
             class_obj = {
             class_obj = {
                 "classes": [
                 "classes": [
                     {
                     {
@@ -106,10 +110,6 @@ class WeaviateDB(BaseVectorDB):
                                 "name": "app_id",
                                 "name": "app_id",
                                 "dataType": ["text"],
                                 "dataType": ["text"],
                             },
                             },
-                            {
-                                "name": "text",
-                                "dataType": ["text"],
-                            },
                         ],
                         ],
                     },
                     },
                 ]
                 ]
@@ -195,8 +195,13 @@ class WeaviateDB(BaseVectorDB):
                 batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
                 batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :param input_query: list of query string
@@ -208,15 +213,23 @@ class WeaviateDB(BaseVectorDB):
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         generated or not
         generated or not
         :type skip_embedding: bool
         :type skip_embedding: bool
-        :return: The context of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
         if not skip_embedding:
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
         else:
             query_vector = input_query
             query_vector = input_query
+
         keys = set(where.keys() if where is not None else set())
         keys = set(where.keys() if where is not None else set())
         data_fields = ["text"]
         data_fields = ["text"]
+
+        if citations:
+            data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
+
         if len(keys.intersection(self.metadata_keys)) != 0:
         if len(keys.intersection(self.metadata_keys)) != 0:
             weaviate_where_operands = []
             weaviate_where_operands = []
             for key in keys:
             for key in keys:
@@ -247,7 +260,18 @@ class WeaviateDB(BaseVectorDB):
                 .with_limit(n_results)
                 .with_limit(n_results)
                 .do()
                 .do()
             )
             )
-        contexts = results["data"]["Get"].get(self.index_name)
+
+        docs = results["data"]["Get"].get(self.index_name)
+        contexts = []
+        for doc in docs:
+            context = doc["text"]
+            if citations:
+                metadata = doc["metadata"][0]
+                source = metadata["url"]
+                doc_id = metadata["doc_id"]
+                contexts.append((context, source, doc_id))
+            else:
+                contexts.append(context)
         return contexts
         return contexts
 
 
     def set_collection_name(self, name: str):
     def set_collection_name(self, name: str):

+ 21 - 10
embedchain/vectordb/zilliz.py

@@ -1,5 +1,5 @@
 import logging
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 
 from embedchain.config import ZillizDBConfig
 from embedchain.config import ZillizDBConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.helper.json_serializable import register_deserializable
@@ -127,8 +127,13 @@ class ZillizVectorDB(BaseVectorDB):
         self.client.flush(self.config.collection_name)
         self.client.flush(self.config.collection_name)
 
 
     def query(
     def query(
-        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
-    ) -> List[Tuple[str, str, str]]:
+        self,
+        input_query: List[str],
+        n_results: int,
+        where: Dict[str, any],
+        skip_embedding: bool,
+        citations: bool = False,
+    ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         """
         Query contents from vector data base based on vector similarity
         Query contents from vector data base based on vector similarity
 
 
@@ -139,8 +144,11 @@ class ZillizVectorDB(BaseVectorDB):
         :param where: to filter data
         :param where: to filter data
         :type where: str
         :type where: str
         :raises InvalidDimensionException: Dimensions do not match.
         :raises InvalidDimensionException: Dimensions do not match.
-        :return: The context of the document that matched your query, url of the source, doc_id
-        :rtype: List[Tuple[str,str,str]]
+        :param citations: we use citations boolean param to return context along with the answer.
+        :type citations: bool, default is False.
+        :return: The content of the document that matched your query,
+        along with url of the source and doc_id (if citations flag is true)
+        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         """
 
 
         if self.collection.is_empty:
         if self.collection.is_empty:
@@ -170,14 +178,17 @@ class ZillizVectorDB(BaseVectorDB):
                 output_fields=output_fields,
                 output_fields=output_fields,
             )
             )
 
 
-        doc_list = []
+        contexts = []
         for query in query_result:
         for query in query_result:
             data = query[0]["entity"]
             data = query[0]["entity"]
             context = data["text"]
             context = data["text"]
-            source = data["url"]
-            doc_id = data["doc_id"]
-            doc_list.append(tuple((context, source, doc_id)))
-        return doc_list
+            if citations:
+                source = data["url"]
+                doc_id = data["doc_id"]
+                contexts.append(tuple((context, source, doc_id)))
+            else:
+                contexts.append(context)
+        return contexts
 
 
     def count(self) -> int:
     def count(self) -> int:
         """
         """

+ 1 - 1
poetry.lock

@@ -7141,4 +7141,4 @@ whatsapp = ["flask", "twilio"]
 [metadata]
 [metadata]
 lock-version = "2.0"
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
 python-versions = ">=3.9,<3.13"
-content-hash = "0b83ba3fd2485b3b4aa3c6a7534b214378d349538f7eb63c65768aafecdfad60"
+content-hash = "0b83ba3fd2485b3b4aa3c6a7534b214378d349538f7eb63c65768aafecdfad60"

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 [tool.poetry]
 name = "embedchain"
 name = "embedchain"
-version = "0.0.88"
+version = "0.0.89"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 16 - 6
tests/vectordb/test_chroma_db.py

@@ -163,10 +163,12 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
 
 
     assert data == expected_value
     assert data == expected_value
 
 
-    data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
-    expected_value = [("document", "url_1", "doc_id_1")]
+    data_without_citations = app_with_settings.db.query(
+        input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
+    )
+    expected_value_without_citations = ["document"]
+    assert data_without_citations == expected_value_without_citations
 
 
-    assert data == expected_value
     app_with_settings.db.reset()
     app_with_settings.db.reset()
 
 
 
 
@@ -326,8 +328,16 @@ def test_chroma_db_collection_query(app_with_settings):
 
 
     assert app_with_settings.db.count() == 2
     assert app_with_settings.db.count() == 2
 
 
-    data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
-    expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
+    data_without_citations = app_with_settings.db.query(
+        input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
+    )
+    expected_value_without_citations = ["document", "document2"]
+    assert data_without_citations == expected_value_without_citations
+
+    data_with_citations = app_with_settings.db.query(
+        input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
+    )
+    expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
+    assert data_with_citations == expected_value_with_citations
 
 
-    assert data == expected_value
     app_with_settings.db.reset()
     app_with_settings.db.reset()

+ 10 - 8
tests/vectordb/test_elasticsearch_db.py

@@ -60,12 +60,16 @@ class TestEsDB(unittest.TestCase):
 
 
         # Query the database for the documents that are most similar to the query "This is a document".
         # Query the database for the documents that are most similar to the query "This is a document".
         query = ["This is a document"]
         query = ["This is a document"]
-        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+        results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+        expected_results_without_citations = ["This is a document.", "This is another document."]
+        self.assertEqual(results_without_citations, expected_results_without_citations)
 
 
-        # Assert that the results are correct.
-        self.assertEqual(
-            results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
-        )
+        results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
+        expected_results_with_citations = [
+            ("This is a document.", "url_1", "doc_id_1"),
+            ("This is another document.", "url_2", "doc_id_2"),
+        ]
+        self.assertEqual(results_with_citations, expected_results_with_citations)
 
 
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     def test_query_with_skip_embedding(self, mock_client):
     def test_query_with_skip_embedding(self, mock_client):
@@ -111,9 +115,7 @@ class TestEsDB(unittest.TestCase):
         results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
         results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
 
 
         # Assert that the results are correct.
         # Assert that the results are correct.
-        self.assertEqual(
-            results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
-        )
+        self.assertEqual(results, ["This is a document.", "This is another document."])
 
 
     def test_init_without_url(self):
     def test_init_without_url(self):
         # Make sure it's not loaded from env
         # Make sure it's not loaded from env

+ 0 - 4
tests/vectordb/test_weaviate.py

@@ -75,10 +75,6 @@ class TestWeaviateDb(unittest.TestCase):
                             "name": "app_id",
                             "name": "app_id",
                             "dataType": ["text"],
                             "dataType": ["text"],
                         },
                         },
-                        {
-                            "name": "text",
-                            "dataType": ["text"],
-                        },
                     ],
                     ],
                 },
                 },
             ]
             ]

+ 30 - 4
tests/vectordb/test_zilliz_db.py

@@ -129,7 +129,7 @@ class TestZillizDBCollection:
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
 
 
             # Assert that MilvusClient.search was called with the correct parameters
             # Assert that MilvusClient.search was called with the correct parameters
-            mock_search.assert_called_once_with(
+            mock_search.assert_called_with(
                 collection_name=mock_config.collection_name,
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 data=["query_text"],
                 limit=1,
                 limit=1,
@@ -137,7 +137,20 @@ class TestZillizDBCollection:
             )
             )
 
 
             # Assert that the query result matches the expected result
             # Assert that the query result matches the expected result
-            assert query_result == [("result_doc", "url_1", "doc_id_1")]
+            assert query_result == ["result_doc"]
+
+            query_result_with_citations = zilliz_db.query(
+                input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
+            )
+
+            mock_search.assert_called_with(
+                collection_name=mock_config.collection_name,
+                data=["query_text"],
+                limit=1,
+                output_fields=["text", "url", "doc_id"],
+            )
+
+            assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
 
 
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
     @patch("embedchain.vectordb.zilliz.connections", autospec=True)
     @patch("embedchain.vectordb.zilliz.connections", autospec=True)
@@ -168,7 +181,7 @@ class TestZillizDBCollection:
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
 
 
             # Assert that MilvusClient.search was called with the correct parameters
             # Assert that MilvusClient.search was called with the correct parameters
-            mock_search.assert_called_once_with(
+            mock_search.assert_called_with(
                 collection_name=mock_config.collection_name,
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 data=["query_vector"],
                 limit=1,
                 limit=1,
@@ -176,4 +189,17 @@ class TestZillizDBCollection:
             )
             )
 
 
             # Assert that the query result matches the expected result
             # Assert that the query result matches the expected result
-            assert query_result == [("result_doc", "url_1", "doc_id_1")]
+            assert query_result == ["result_doc"]
+
+            query_result_with_citations = zilliz_db.query(
+                input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
+            )
+
+            mock_search.assert_called_with(
+                collection_name=mock_config.collection_name,
+                data=["query_vector"],
+                limit=1,
+                output_fields=["text", "url", "doc_id"],
+            )
+
+            assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]