Pārlūkot izejas kodu

[Feature] Return score when doing search in vectorDB (#1060)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 gadu atpakaļ
vecāks
revīzija
c0aafd38c9

+ 2 - 2
embedchain/embedchain.py

@@ -515,7 +515,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[Dict] = None,
         citations: bool = False,
         **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
+    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -566,7 +566,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[Dict[str, str]] = None,
         citations: bool = False,
         **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
+    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an

+ 2 - 1
embedchain/vectordb/chroma.py

@@ -200,7 +200,7 @@ class ChromaDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         Query contents from vector database based on vector similarity
 
@@ -250,6 +250,7 @@ class ChromaDB(BaseVectorDB):
             context = result[0].page_content
             if citations:
                 metadata = result[0].metadata
+                metadata["score"] = result[1]
                 contexts.append((context, metadata))
             else:
                 contexts.append(context)

+ 2 - 1
embedchain/vectordb/elasticsearch.py

@@ -164,7 +164,7 @@ class ElasticsearchDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector data base based on vector similarity
 
@@ -210,6 +210,7 @@ class ElasticsearchDB(BaseVectorDB):
             context = doc["_source"]["text"]
             if citations:
                 metadata = doc["_source"]["metadata"]
+                metadata["score"] = doc["_score"]
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)

+ 6 - 4
embedchain/vectordb/opensearch.py

@@ -169,7 +169,7 @@ class OpenSearchDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector data base based on vector similarity
 
@@ -202,7 +202,7 @@ class OpenSearchDB(BaseVectorDB):
         if "app_id" in where:
             app_id = where["app_id"]
             pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}}
-        docs = docsearch.similarity_search(
+        docs = docsearch.similarity_search_with_score(
             input_query,
             search_type="script_scoring",
             space_type="cosinesimil",
@@ -215,10 +215,12 @@ class OpenSearchDB(BaseVectorDB):
         )
 
         contexts = []
-        for doc in docs:
+        for doc, score in docs:
             context = doc.page_content
             if citations:
-                contexts.append(tuple((context, doc.metadata)))
+                metadata = doc.metadata
+                metadata["score"] = score
+                contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)
         return contexts

+ 2 - 1
embedchain/vectordb/pinecone.py

@@ -127,7 +127,7 @@ class PineconeDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -154,6 +154,7 @@ class PineconeDB(BaseVectorDB):
             metadata = doc["metadata"]
             context = metadata["text"]
             if citations:
+                metadata["score"] = doc["score"]
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)

+ 2 - 1
embedchain/vectordb/qdrant.py

@@ -170,7 +170,7 @@ class QdrantDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -219,6 +219,7 @@ class QdrantDB(BaseVectorDB):
             context = result.payload["text"]
             if citations:
                 metadata = result.payload["metadata"]
+                metadata["score"] = result.score
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)

+ 5 - 1
embedchain/vectordb/weaviate.py

@@ -205,7 +205,7 @@ class WeaviateDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -255,6 +255,7 @@ class WeaviateDB(BaseVectorDB):
                 .with_where(weaviate_where_clause)
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
+                .with_additional(["distance"])
                 .do()
             )
         else:
@@ -262,6 +263,7 @@ class WeaviateDB(BaseVectorDB):
                 self.client.query.get(self.index_name, data_fields)
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
+                .with_additional(["distance"])
                 .do()
             )
 
@@ -271,6 +273,8 @@ class WeaviateDB(BaseVectorDB):
             context = doc["text"]
             if citations:
                 metadata = doc["metadata"][0]
+                score = doc["_additional"]["distance"]
+                metadata["score"] = score
                 contexts.append((context, metadata))
             else:
                 contexts.append(context)

+ 10 - 4
embedchain/vectordb/zilliz.py

@@ -135,7 +135,7 @@ class ZillizVectorDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         Query contents from vector data base based on vector similarity
 
@@ -159,7 +159,7 @@ class ZillizVectorDB(BaseVectorDB):
         if not isinstance(where, str):
             where = None
 
-        output_fields = ["text", "url", "doc_id"]
+        output_fields = ["*"]
         if skip_embedding:
             query_vector = input_query
             query_result = self.client.search(
@@ -181,12 +181,18 @@ class ZillizVectorDB(BaseVectorDB):
                 output_fields=output_fields,
                 **kwargs,
             )
-
+        query_result = query_result[0]
         contexts = []
         for query in query_result:
-            data = query[0]["entity"]
+            data = query["entity"]
+            score = query["distance"]
             context = data["text"]
+
+            if "embeddings" in data:
+                data.pop("embeddings")
+
             if citations:
+                data["score"] = score
                 contexts.append(tuple((context, data)))
             else:
                 contexts.append(context)

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.44"
+version = "0.1.45"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 16 - 2
tests/vectordb/test_chroma_db.py

@@ -342,8 +342,22 @@ def test_chroma_db_collection_query(app_with_settings):
         input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
     )
     expected_value_with_citations = [
-        ("document", {"url": "url_1", "doc_id": "doc_id_1"}),
-        ("document2", {"url": "url_2", "doc_id": "doc_id_2"}),
+        (
+            "document",
+            {
+                "url": "url_1",
+                "doc_id": "doc_id_1",
+                "score": 0.0,
+            },
+        ),
+        (
+            "document2",
+            {
+                "url": "url_2",
+                "doc_id": "doc_id_2",
+                "score": 1.0,
+            },
+        ),
     ]
     assert data_with_citations == expected_value_with_citations
 

+ 2 - 2
tests/vectordb/test_elasticsearch_db.py

@@ -66,8 +66,8 @@ class TestEsDB(unittest.TestCase):
 
         results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
         expected_results_with_citations = [
-            ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1"}),
-            ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2"}),
+            ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
+            ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
         ]
         self.assertEqual(results_with_citations, expected_results_with_citations)
 

+ 22 - 8
tests/vectordb/test_zilliz_db.py

@@ -123,7 +123,14 @@ class TestZillizDBCollection:
         # Mock the MilvusClient search method
         with patch.object(zilliz_db.client, "search") as mock_search:
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
+            mock_search.return_value = [
+                [
+                    {
+                        "distance": 0.5,
+                        "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
+                    }
+                ]
+            ]
 
             # Call the query method with skip_embedding=True
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
@@ -133,7 +140,7 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
 
             # Assert that the query result matches the expected result
@@ -147,11 +154,11 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
 
             assert query_result_with_citations == [
-                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"})
+                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
             ]
 
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@@ -177,7 +184,14 @@ class TestZillizDBCollection:
             mock_embedder.embedding_fn.return_value = ["query_vector"]
 
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
+            mock_search.return_value = [
+                [
+                    {
+                        "distance": 0.0,
+                        "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
+                    }
+                ]
+            ]
 
             # Call the query method with skip_embedding=False
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
@@ -187,7 +201,7 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
 
             # Assert that the query result matches the expected result
@@ -201,9 +215,9 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
 
             assert query_result_with_citations == [
-                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"})
+                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0})
             ]