浏览代码

[Feature] Return score when doing search in vectorDB (#1060)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父节点
当前提交
c0aafd38c9

+ 2 - 2
embedchain/embedchain.py

@@ -515,7 +515,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[Dict] = None,
         where: Optional[Dict] = None,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Dict[str, Any],
         **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
+    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
         """
         """
         Queries the vector database based on the given input query.
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         Gets relevant doc based on the query and then passes it to an
@@ -566,7 +566,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[Dict[str, str]] = None,
         where: Optional[Dict[str, str]] = None,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Dict[str, Any],
         **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
+    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
         """
         """
         Queries the vector database on the given input query.
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         Gets relevant doc based on the query and then passes it to an

+ 2 - 1
embedchain/vectordb/chroma.py

@@ -200,7 +200,7 @@ class ChromaDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         Query contents from vector database based on vector similarity
         Query contents from vector database based on vector similarity
 
 
@@ -250,6 +250,7 @@ class ChromaDB(BaseVectorDB):
             context = result[0].page_content
             context = result[0].page_content
             if citations:
             if citations:
                 metadata = result[0].metadata
                 metadata = result[0].metadata
+                metadata["score"] = result[1]
                 contexts.append((context, metadata))
                 contexts.append((context, metadata))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)

+ 2 - 1
embedchain/vectordb/elasticsearch.py

@@ -164,7 +164,7 @@ class ElasticsearchDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
 
 
@@ -210,6 +210,7 @@ class ElasticsearchDB(BaseVectorDB):
             context = doc["_source"]["text"]
             context = doc["_source"]["text"]
             if citations:
             if citations:
                 metadata = doc["_source"]["metadata"]
                 metadata = doc["_source"]["metadata"]
+                metadata["score"] = doc["_score"]
                 contexts.append(tuple((context, metadata)))
                 contexts.append(tuple((context, metadata)))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)

+ 6 - 4
embedchain/vectordb/opensearch.py

@@ -169,7 +169,7 @@ class OpenSearchDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
 
 
@@ -202,7 +202,7 @@ class OpenSearchDB(BaseVectorDB):
         if "app_id" in where:
         if "app_id" in where:
             app_id = where["app_id"]
             app_id = where["app_id"]
             pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}}
             pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}}
-        docs = docsearch.similarity_search(
+        docs = docsearch.similarity_search_with_score(
             input_query,
             input_query,
             search_type="script_scoring",
             search_type="script_scoring",
             space_type="cosinesimil",
             space_type="cosinesimil",
@@ -215,10 +215,12 @@ class OpenSearchDB(BaseVectorDB):
         )
         )
 
 
         contexts = []
         contexts = []
-        for doc in docs:
+        for doc, score in docs:
             context = doc.page_content
             context = doc.page_content
             if citations:
             if citations:
-                contexts.append(tuple((context, doc.metadata)))
+                metadata = doc.metadata
+                metadata["score"] = score
+                contexts.append(tuple((context, metadata)))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)
         return contexts
         return contexts

+ 2 - 1
embedchain/vectordb/pinecone.py

@@ -127,7 +127,7 @@ class PineconeDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, any]],
         **kwargs: Optional[Dict[str, any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :param input_query: list of query string
@@ -154,6 +154,7 @@ class PineconeDB(BaseVectorDB):
             metadata = doc["metadata"]
             metadata = doc["metadata"]
             context = metadata["text"]
             context = metadata["text"]
             if citations:
             if citations:
+                metadata["score"] = doc["score"]
                 contexts.append(tuple((context, metadata)))
                 contexts.append(tuple((context, metadata)))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)

+ 2 - 1
embedchain/vectordb/qdrant.py

@@ -170,7 +170,7 @@ class QdrantDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :param input_query: list of query string
@@ -219,6 +219,7 @@ class QdrantDB(BaseVectorDB):
             context = result.payload["text"]
             context = result.payload["text"]
             if citations:
             if citations:
                 metadata = result.payload["metadata"]
                 metadata = result.payload["metadata"]
+                metadata["score"] = result.score
                 contexts.append(tuple((context, metadata)))
                 contexts.append(tuple((context, metadata)))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)

+ 5 - 1
embedchain/vectordb/weaviate.py

@@ -205,7 +205,7 @@ class WeaviateDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         query contents from vector database based on vector similarity
         query contents from vector database based on vector similarity
         :param input_query: list of query string
         :param input_query: list of query string
@@ -255,6 +255,7 @@ class WeaviateDB(BaseVectorDB):
                 .with_where(weaviate_where_clause)
                 .with_where(weaviate_where_clause)
                 .with_near_vector({"vector": query_vector})
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
                 .with_limit(n_results)
+                .with_additional(["distance"])
                 .do()
                 .do()
             )
             )
         else:
         else:
@@ -262,6 +263,7 @@ class WeaviateDB(BaseVectorDB):
                 self.client.query.get(self.index_name, data_fields)
                 self.client.query.get(self.index_name, data_fields)
                 .with_near_vector({"vector": query_vector})
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
                 .with_limit(n_results)
+                .with_additional(["distance"])
                 .do()
                 .do()
             )
             )
 
 
@@ -271,6 +273,8 @@ class WeaviateDB(BaseVectorDB):
             context = doc["text"]
             context = doc["text"]
             if citations:
             if citations:
                 metadata = doc["metadata"][0]
                 metadata = doc["metadata"][0]
+                score = doc["_additional"]["distance"]
+                metadata["score"] = score
                 contexts.append((context, metadata))
                 contexts.append((context, metadata))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)

+ 10 - 4
embedchain/vectordb/zilliz.py

@@ -135,7 +135,7 @@ class ZillizVectorDB(BaseVectorDB):
         skip_embedding: bool,
         skip_embedding: bool,
         citations: bool = False,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
         Query contents from vector database based on vector similarity
         Query contents from vector database based on vector similarity
 
 
@@ -159,7 +159,7 @@ class ZillizVectorDB(BaseVectorDB):
         if not isinstance(where, str):
         if not isinstance(where, str):
             where = None
             where = None
 
 
-        output_fields = ["text", "url", "doc_id"]
+        output_fields = ["*"]
         if skip_embedding:
         if skip_embedding:
             query_vector = input_query
             query_vector = input_query
             query_result = self.client.search(
             query_result = self.client.search(
@@ -181,12 +181,18 @@ class ZillizVectorDB(BaseVectorDB):
                 output_fields=output_fields,
                 output_fields=output_fields,
                 **kwargs,
                 **kwargs,
             )
             )
-
+        query_result = query_result[0]
         contexts = []
         contexts = []
         for query in query_result:
         for query in query_result:
-            data = query[0]["entity"]
+            data = query["entity"]
+            score = query["distance"]
             context = data["text"]
             context = data["text"]
+
+            if "embeddings" in data:
+                data.pop("embeddings")
+
             if citations:
             if citations:
+                data["score"] = score
                 contexts.append(tuple((context, data)))
                 contexts.append(tuple((context, data)))
             else:
             else:
                 contexts.append(context)
                 contexts.append(context)

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 [tool.poetry]
 name = "embedchain"
 name = "embedchain"
-version = "0.1.44"
+version = "0.1.45"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 16 - 2
tests/vectordb/test_chroma_db.py

@@ -342,8 +342,22 @@ def test_chroma_db_collection_query(app_with_settings):
         input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
         input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
     )
     )
     expected_value_with_citations = [
     expected_value_with_citations = [
-        ("document", {"url": "url_1", "doc_id": "doc_id_1"}),
-        ("document2", {"url": "url_2", "doc_id": "doc_id_2"}),
+        (
+            "document",
+            {
+                "url": "url_1",
+                "doc_id": "doc_id_1",
+                "score": 0.0,
+            },
+        ),
+        (
+            "document2",
+            {
+                "url": "url_2",
+                "doc_id": "doc_id_2",
+                "score": 1.0,
+            },
+        ),
     ]
     ]
     assert data_with_citations == expected_value_with_citations
     assert data_with_citations == expected_value_with_citations
 
 

+ 2 - 2
tests/vectordb/test_elasticsearch_db.py

@@ -66,8 +66,8 @@ class TestEsDB(unittest.TestCase):
 
 
         results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
         results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
         expected_results_with_citations = [
         expected_results_with_citations = [
-            ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1"}),
-            ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2"}),
+            ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
+            ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
         ]
         ]
         self.assertEqual(results_with_citations, expected_results_with_citations)
         self.assertEqual(results_with_citations, expected_results_with_citations)
 
 

+ 22 - 8
tests/vectordb/test_zilliz_db.py

@@ -123,7 +123,14 @@ class TestZillizDBCollection:
         # Mock the MilvusClient search method
         # Mock the MilvusClient search method
         with patch.object(zilliz_db.client, "search") as mock_search:
         with patch.object(zilliz_db.client, "search") as mock_search:
             # Mock the search result
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
+            mock_search.return_value = [
+                [
+                    {
+                        "distance": 0.5,
+                        "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
+                    }
+                ]
+            ]
 
 
             # Call the query method with skip_embedding=True
             # Call the query method with skip_embedding=True
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
@@ -133,7 +140,7 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 data=["query_text"],
                 limit=1,
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
             )
 
 
             # Assert that the query result matches the expected result
             # Assert that the query result matches the expected result
@@ -147,11 +154,11 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 data=["query_text"],
                 limit=1,
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
             )
 
 
             assert query_result_with_citations == [
             assert query_result_with_citations == [
-                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"})
+                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
             ]
             ]
 
 
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@@ -177,7 +184,14 @@ class TestZillizDBCollection:
             mock_embedder.embedding_fn.return_value = ["query_vector"]
             mock_embedder.embedding_fn.return_value = ["query_vector"]
 
 
             # Mock the search result
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
+            mock_search.return_value = [
+                [
+                    {
+                        "distance": 0.0,
+                        "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
+                    }
+                ]
+            ]
 
 
             # Call the query method with skip_embedding=False
             # Call the query method with skip_embedding=False
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
@@ -187,7 +201,7 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 data=["query_vector"],
                 limit=1,
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
             )
 
 
             # Assert that the query result matches the expected result
             # Assert that the query result matches the expected result
@@ -201,9 +215,9 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 collection_name=mock_config.collection_name,
                 data=["query_vector"],
                 data=["query_vector"],
                 limit=1,
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )
             )
 
 
             assert query_result_with_citations == [
             assert query_result_with_citations == [
-                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"})
+                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0})
             ]
             ]