Explorar o código

[Feature] Update `db.query` to return source of context (#831)

Deven Patel hai 1 ano
pai
achega
d77e8da3f3

+ 1 - 0
.gitignore

@@ -174,4 +174,5 @@ test-db
 
 notebooks/*.yaml
 .ipynb_checkpoints/
+
 !configs/*.yaml

+ 6 - 2
embedchain/embedchain.py

@@ -500,13 +500,17 @@ class EmbedChain(JSONSerializable):
 
             db_query = ClipProcessor.get_text_features(query=input_query)
 
-        contents = self.db.query(
+        contexts = self.db.query(
             input_query=db_query,
             n_results=query_config.number_documents,
             where=where,
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
         )
-        return contents
+
+        if len(contexts) > 0 and isinstance(contexts[0], tuple):
+            contexts = list(map(lambda x: x[0], contexts))
+
+        return contexts
 
     def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
         """

+ 7 - 5
embedchain/factory.py

@@ -41,15 +41,15 @@ class LlmFactory:
 
 class EmbedderFactory:
     provider_to_class = {
+        "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
         "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
         "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
-        "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
-        "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
         "openai": "embedchain.embedder.openai.OpenAIEmbedder",
+        "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
     }
     provider_to_config_class = {
-        "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
+        "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
     }
 
     @classmethod
@@ -72,16 +72,18 @@ class VectorDBFactory:
         "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
         "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
         "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
-        "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
         "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
+        "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
+        "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
     }
     provider_to_config_class = {
         "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
         "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
         "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
         "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
-        "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
         "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
+        "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
+        "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
     }
 
     @classmethod

+ 14 - 6
embedchain/vectordb/chroma.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -191,7 +191,9 @@ class ChromaDB(BaseVectorDB):
             )
         ]
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         Query contents from vector database based on vector similarity
 
@@ -204,8 +206,8 @@ class ChromaDB(BaseVectorDB):
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
         :raises InvalidDimensionException: Dimensions do not match.
-        :return: The content of the document that matched your query.
-        :rtype: List[str]
+        :return: The content of the document that matched your query, url of the source, doc_id
+        :rtype: List[Tuple[str,str,str]]
         """
         try:
             if skip_embedding:
@@ -231,8 +233,14 @@ class ChromaDB(BaseVectorDB):
                 " embeddings, is used to retrieve an embedding from the database."
             ) from None
         results_formatted = self._format_result(result)
-        contents = [result[0].page_content for result in results_formatted]
-        return contents
+        contexts = []
+        for result in results_formatted:
+            context = result[0].page_content
+            metadata = result[0].metadata
+            source = metadata["url"]
+            doc_id = metadata["doc_id"]
+            contexts.append((context, source, doc_id))
+        return contexts
 
     def set_collection_name(self, name: str):
         """

+ 17 - 7
embedchain/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 try:
     from elasticsearch import Elasticsearch
@@ -135,7 +135,9 @@ class ElasticsearchDB(BaseVectorDB):
         bulk(self.client, docs)
         self.client.indices.refresh(index=self._get_index())
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         query contents from vector data base based on vector similarity
 
@@ -147,8 +149,9 @@ class ElasticsearchDB(BaseVectorDB):
         :type where: Dict[str, any]
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
-        :return: Database contents that are the result of the query
-        :rtype: List[str]
+        :return: The context of the document that matched your query, url of the source, doc_id
+
+        :rtype: List[Tuple[str,str,str]]
         """
         if skip_embedding:
             query_vector = input_query
@@ -156,6 +159,7 @@ class ElasticsearchDB(BaseVectorDB):
             input_query_vector = self.embedder.embedding_fn(input_query)
             query_vector = input_query_vector[0]
 
+        # `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
         query = {
             "script_score": {
                 "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
@@ -167,11 +171,17 @@ class ElasticsearchDB(BaseVectorDB):
         }
         if "app_id" in where:
             app_id = where["app_id"]
-            query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}]
-        _source = ["text"]
+            query["script_score"]["query"] = {"match": {"metadata.app_id": app_id}}
+        _source = ["text", "metadata.url", "metadata.doc_id"]
         response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
         docs = response["hits"]["hits"]
-        contents = [doc["_source"]["text"] for doc in docs]
+        contents = []
+        for doc in docs:
+            context = doc["_source"]["text"]
+            metadata = doc["_source"]["metadata"]
+            source = metadata["url"]
+            doc_id = metadata["doc_id"]
+            contents.append(tuple((context, source, doc_id)))
         return contents
 
     def set_collection_name(self, name: str):

+ 13 - 5
embedchain/vectordb/opensearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 try:
     from opensearchpy import OpenSearch
@@ -145,7 +145,9 @@ class OpenSearchDB(BaseVectorDB):
         bulk(self.client, docs)
         self.client.indices.refresh(index=self._get_index())
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         query contents from vector data base based on vector similarity
 
@@ -157,8 +159,8 @@ class OpenSearchDB(BaseVectorDB):
         :type where: Dict[str, any]
         :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
         :type skip_embedding: bool
-        :return: Database contents that are the result of the query
-        :rtype: List[str]
+        :return: The content of the document that matched your query, url of the source, doc_id
+        :rtype: List[Tuple[str,str,str]]
         """
         # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
         embeddings = OpenAIEmbeddings()
@@ -185,7 +187,13 @@ class OpenSearchDB(BaseVectorDB):
             pre_filter=pre_filter,
             k=n_results,
         )
-        contents = [doc.page_content for doc in docs]
+
+        contents = []
+        for doc in docs:
+            context = doc.page_content
+            source = doc.metadata["url"]
+            doc_id = doc.metadata["doc_id"]
+            contents.append(tuple((context, source, doc_id)))
         return contents
 
     def set_collection_name(self, name: str):

+ 15 - 7
embedchain/vectordb/pinecone.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 try:
     import pinecone
@@ -118,7 +118,9 @@ class PineconeDB(BaseVectorDB):
         for i in range(0, len(docs), self.BATCH_SIZE):
             self.client.upsert(docs[i : i + self.BATCH_SIZE])
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -129,16 +131,22 @@ class PineconeDB(BaseVectorDB):
         :type where: Dict[str, any]
         :param skip_embedding: Optional. if True, input_query is already embedded
         :type skip_embedding: bool
-        :return: Database contents that are the result of the query
-        :rtype: List[str]
+        :return: The content of the document that matched your query, url of the source, doc_id
+        :rtype: List[Tuple[str,str,str]]
         """
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
             query_vector = input_query
-        contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
-        embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
-        return embeddings
+        data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
+        contents = []
+        for doc in data["matches"]:
+            metadata = doc["metadata"]
+            context = metadata["text"]
+            source = metadata["url"]
+            doc_id = metadata["doc_id"]
+            contents.append(tuple((context, source, doc_id)))
+        return contents
 
     def set_collection_name(self, name: str):
         """

+ 24 - 5
embedchain/vectordb/qdrant.py

@@ -1,7 +1,7 @@
 import copy
 import os
 import uuid
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 try:
     from qdrant_client import QdrantClient
@@ -160,7 +160,9 @@ class QdrantDB(BaseVectorDB):
                 ),
             )
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -172,8 +174,8 @@ class QdrantDB(BaseVectorDB):
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         generated or not
         :type skip_embedding: bool
-        :return: Database contents that are the result of the query
-        :rtype: List[str]
+        :return: The context of the document that matched your query, url of the source, doc_id
+        :rtype: List[Tuple[str,str,str]]
         """
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
@@ -199,9 +201,14 @@ class QdrantDB(BaseVectorDB):
             query_vector=query_vector,
             limit=n_results,
         )
+
         response = []
         for result in results:
-            response.append(result.payload.get("text", ""))
+            context = result.payload["text"]
+            metadata = result.payload["metadata"]
+            source = metadata["url"]
+            doc_id = metadata["doc_id"]
+            response.append(tuple((context, source, doc_id)))
         return response
 
     def count(self) -> int:
@@ -211,3 +218,15 @@ class QdrantDB(BaseVectorDB):
     def reset(self):
         self.client.delete_collection(collection_name=self.collection_name)
         self._initialize()
+
+    def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        :param name: Name of the collection.
+        :type name: str
+        """
+        if not isinstance(name, str):
+            raise TypeError("Collection name must be a string")
+        self.config.collection_name = name
+        self.collection_name = self._get_or_create_collection()

+ 11 - 11
embedchain/vectordb/weaviate.py

@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 try:
     import weaviate
@@ -194,7 +194,9 @@ class WeaviateDB(BaseVectorDB):
                 )
                 batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -206,14 +208,15 @@ class WeaviateDB(BaseVectorDB):
         :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
         generated or not
         :type skip_embedding: bool
-        :return: Database contents that are the result of the query
-        :rtype: List[str]
+        :return: The context of the document that matched your query, url of the source, doc_id
+        :rtype: List[Tuple[str,str,str]]
         """
         if not skip_embedding:
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
             query_vector = input_query
         keys = set(where.keys() if where is not None else set())
+        data_fields = ["text"]
         if len(keys.intersection(self.metadata_keys)) != 0:
             weaviate_where_operands = []
             for key in keys:
@@ -231,7 +234,7 @@ class WeaviateDB(BaseVectorDB):
                 weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
 
             results = (
-                self.client.query.get(self.index_name, ["text"])
+                self.client.query.get(self.index_name, data_fields)
                 .with_where(weaviate_where_clause)
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
@@ -239,16 +242,13 @@ class WeaviateDB(BaseVectorDB):
             )
         else:
             results = (
-                self.client.query.get(self.index_name, ["text"])
+                self.client.query.get(self.index_name, data_fields)
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
                 .do()
             )
-        matched_tokens = []
-        for result in results["data"]["Get"].get(self.index_name):
-            matched_tokens.append(result["text"])
-
-        return matched_tokens
+        contexts = results["data"]["Get"].get(self.index_name)
+        return contexts
 
     def set_collection_name(self, name: str):
         """

+ 16 - 8
embedchain/vectordb/zilliz.py

@@ -1,4 +1,5 @@
-from typing import Dict, List, Optional
+import logging
+from typing import Dict, List, Optional, Tuple
 
 from embedchain.config import ZillizDBConfig
 from embedchain.helper.json_serializable import register_deserializable
@@ -61,6 +62,7 @@ class ZillizVectorDB(BaseVectorDB):
         :type name: str
         """
         if utility.has_collection(name):
+            logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
             self.collection = Collection(name)
         else:
             fields = [
@@ -124,7 +126,9 @@ class ZillizVectorDB(BaseVectorDB):
         self.collection.flush()
         self.client.flush(self.config.collection_name)
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+    def query(
+        self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
+    ) -> List[Tuple[str, str, str]]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -135,8 +139,8 @@ class ZillizVectorDB(BaseVectorDB):
         :param where: to filter data
         :type where: str
         :raises InvalidDimensionException: Dimensions do not match.
-        :return: The content of the document that matched your query.
-        :rtype: List[str]
+        :return: The context of the document that matched your query, url of the source, doc_id
+        :rtype: List[Tuple[str,str,str]]
         """
 
         if self.collection.is_empty:
@@ -145,13 +149,14 @@ class ZillizVectorDB(BaseVectorDB):
         if not isinstance(where, str):
             where = None
 
+        output_fields = ["text", "url", "doc_id"]
         if skip_embedding:
             query_vector = input_query
             query_result = self.client.search(
                 collection_name=self.config.collection_name,
                 data=query_vector,
                 limit=n_results,
-                output_fields=["text"],
+                output_fields=output_fields,
             )
 
         else:
@@ -162,13 +167,16 @@ class ZillizVectorDB(BaseVectorDB):
                 collection_name=self.config.collection_name,
                 data=[query_vector],
                 limit=n_results,
-                output_fields=["text"],
+                output_fields=output_fields,
             )
 
         doc_list = []
         for query in query_result:
-            doc_list.append(query[0]["entity"]["text"])
-
+            data = query[0]["entity"]
+            context = data["text"]
+            source = data["url"]
+            doc_id = data["doc_id"]
+            doc_list.append(tuple((context, source, doc_id)))
         return doc_list
 
     def count(self) -> int:

+ 35 - 3
tests/vectordb/test_chroma_db.py

@@ -146,7 +146,7 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
     app_with_settings.db.add(
         embeddings=[[0, 0, 0]],
         documents=["document"],
-        metadatas=[{"value": "somevalue"}],
+        metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
         ids=["id"],
         skip_embedding=True,
     )
@@ -158,13 +158,13 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
         "documents": ["document"],
         "embeddings": None,
         "ids": ["id"],
-        "metadatas": [{"value": "somevalue"}],
+        "metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
     }
 
     assert data == expected_value
 
     data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
-    expected_value = ["document"]
+    expected_value = [("document", "url_1", "doc_id_1")]
 
     assert data == expected_value
     app_with_settings.db.reset()
@@ -299,3 +299,35 @@ def test_chroma_db_collection_reset():
     app2.db.reset()
     app3.db.reset()
     app4.db.reset()
+
+
+def test_chroma_db_collection_query(app_with_settings):
+    app_with_settings.db.reset()
+
+    assert app_with_settings.db.count() == 0
+
+    app_with_settings.db.add(
+        embeddings=[[0, 0, 0]],
+        documents=["document"],
+        metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
+        ids=["id"],
+        skip_embedding=True,
+    )
+
+    assert app_with_settings.db.count() == 1
+
+    app_with_settings.db.add(
+        embeddings=[[0, 1, 0]],
+        documents=["document2"],
+        metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
+        ids=["id2"],
+        skip_embedding=True,
+    )
+
+    assert app_with_settings.db.count() == 2
+
+    data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
+    expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
+
+    assert data == expected_value
+    app_with_settings.db.reset()

+ 30 - 8
tests/vectordb/test_elasticsearch_db.py

@@ -31,7 +31,7 @@ class TestEsDB(unittest.TestCase):
         # Create some dummy data.
         embeddings = [[1, 2, 3], [4, 5, 6]]
         documents = ["This is a document.", "This is another document."]
-        metadatas = [{}, {}]
+        metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
         ids = ["doc_1", "doc_2"]
 
         # Add the data to the database.
@@ -40,8 +40,17 @@ class TestEsDB(unittest.TestCase):
         search_response = {
             "hits": {
                 "hits": [
-                    {"_source": {"text": "This is a document."}, "_score": 0.9},
-                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                    {
+                        "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
+                        "_score": 0.9,
+                    },
+                    {
+                        "_source": {
+                            "text": "This is another document.",
+                            "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
+                        },
+                        "_score": 0.8,
+                    },
                 ]
             }
         }
@@ -54,7 +63,9 @@ class TestEsDB(unittest.TestCase):
         results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
 
         # Assert that the results are correct.
-        self.assertEqual(results, ["This is a document.", "This is another document."])
+        self.assertEqual(
+            results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
+        )
 
     @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
     def test_query_with_skip_embedding(self, mock_client):
@@ -68,7 +79,7 @@ class TestEsDB(unittest.TestCase):
         # Create some dummy data.
         embeddings = [[1, 2, 3], [4, 5, 6]]
         documents = ["This is a document.", "This is another document."]
-        metadatas = [{}, {}]
+        metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
         ids = ["doc_1", "doc_2"]
 
         # Add the data to the database.
@@ -77,8 +88,17 @@ class TestEsDB(unittest.TestCase):
         search_response = {
             "hits": {
                 "hits": [
-                    {"_source": {"text": "This is a document."}, "_score": 0.9},
-                    {"_source": {"text": "This is another document."}, "_score": 0.8},
+                    {
+                        "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
+                        "_score": 0.9,
+                    },
+                    {
+                        "_source": {
+                            "text": "This is another document.",
+                            "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
+                        },
+                        "_score": 0.8,
+                    },
                 ]
             }
         }
@@ -91,7 +111,9 @@ class TestEsDB(unittest.TestCase):
         results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
 
         # Assert that the results are correct.
-        self.assertEqual(results, ["This is a document.", "This is another document."])
+        self.assertEqual(
+            results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
+        )
 
     def test_init_without_url(self):
         # Make sure it's not loaded from env

+ 6 - 6
tests/vectordb/test_zilliz_db.py

@@ -123,7 +123,7 @@ class TestZillizDBCollection:
         # Mock the MilvusClient search method
         with patch.object(zilliz_db.client, "search") as mock_search:
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
+            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
 
             # Call the query method with skip_embedding=True
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
@@ -133,11 +133,11 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 limit=1,
-                output_fields=["text"],
+                output_fields=["text", "url", "doc_id"],
             )
 
             # Assert that the query result matches the expected result
-            assert query_result == ["result_doc"]
+            assert query_result == [("result_doc", "url_1", "doc_id_1")]
 
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
     @patch("embedchain.vectordb.zilliz.connections", autospec=True)
@@ -162,7 +162,7 @@ class TestZillizDBCollection:
             mock_embedder.embedding_fn.return_value = ["query_vector"]
 
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
+            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
 
             # Call the query method with skip_embedding=False
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
@@ -172,8 +172,8 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 limit=1,
-                output_fields=["text"],
+                output_fields=["text", "url", "doc_id"],
             )
 
             # Assert that the query result matches the expected result
-            assert query_result == ["result_doc"]
+            assert query_result == [("result_doc", "url_1", "doc_id_1")]