Sfoglia il codice sorgente

[Bugfix] fix qdrant and weaviate db integration (#1181)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 anno fa
parent
commit
2d9fbd4e49
3 file modificati con 131 aggiunte e 44 eliminazioni
  1. 42 17
      embedchain/vectordb/qdrant.py
  2. 86 24
      embedchain/vectordb/weaviate.py
  3. 3 3
      tests/vectordb/test_qdrant.py

+ 42 - 17
embedchain/vectordb/qdrant.py

@@ -11,6 +11,8 @@ try:
 except ImportError:
     raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
 
+from tqdm import tqdm
+
 from embedchain.config.vectordb.qdrant import QdrantDBConfig
 from embedchain.vectordb.base import BaseVectorDB
 
@@ -48,7 +50,6 @@ class QdrantDB(BaseVectorDB):
             raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
 
         self.collection_name = self._get_or_create_collection()
-        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
         all_collections = self.client.get_collections()
         collection_names = [collection.name for collection in all_collections.collections]
         if self.collection_name not in collection_names:
@@ -82,21 +83,23 @@ class QdrantDB(BaseVectorDB):
         :return: All the existing IDs
         :rtype: Set[str]
         """
-        if ids is None or len(ids) == 0:
-            return {"ids": []}
 
         keys = set(where.keys() if where is not None else set())
 
-        qdrant_must_filters = [
-            models.FieldCondition(
-                key="identifier",
-                match=models.MatchAny(
-                    any=ids,
-                ),
+        qdrant_must_filters = []
+
+        if ids:
+            qdrant_must_filters.append(
+                models.FieldCondition(
+                    key="identifier",
+                    match=models.MatchAny(
+                        any=ids,
+                    ),
+                )
             )
-        ]
-        if len(keys.intersection(self.metadata_keys)) != 0:
-            for key in keys.intersection(self.metadata_keys):
+
+        if len(keys) > 0:
+            for key in keys:
                 qdrant_must_filters.append(
                     models.FieldCondition(
                         key="metadata.{}".format(key),
@@ -108,6 +111,7 @@ class QdrantDB(BaseVectorDB):
 
         offset = 0
         existing_ids = []
+        metadatas = []
         while offset is not None:
             response = self.client.scroll(
                 collection_name=self.collection_name,
@@ -118,7 +122,8 @@ class QdrantDB(BaseVectorDB):
             offset = response[1]
             for doc in response[0]:
                 existing_ids.append(doc.payload["identifier"])
-        return {"ids": existing_ids}
+                metadatas.append(doc.payload["metadata"])
+        return {"ids": existing_ids, "metadatas": metadatas}
 
     def add(
         self,
@@ -143,7 +148,8 @@ class QdrantDB(BaseVectorDB):
             metadata["text"] = document
             qdrant_ids.append(str(uuid.uuid4()))
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
-        for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
+
+        for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
@@ -180,16 +186,17 @@ class QdrantDB(BaseVectorDB):
         keys = set(where.keys() if where is not None else set())
 
         qdrant_must_filters = []
-        if len(keys.intersection(self.metadata_keys)) != 0:
-            for key in keys.intersection(self.metadata_keys):
+        if len(keys) > 0:
+            for key in keys:
                 qdrant_must_filters.append(
                     models.FieldCondition(
-                        key="payload.metadata.{}".format(key),
+                        key="metadata.{}".format(key),
                         match=models.MatchValue(
                             value=where.get(key),
                         ),
                     )
                 )
+
         results = self.client.search(
             collection_name=self.collection_name,
             query_filter=models.Filter(must=qdrant_must_filters),
@@ -228,3 +235,21 @@ class QdrantDB(BaseVectorDB):
             raise TypeError("Collection name must be a string")
         self.config.collection_name = name
         self.collection_name = self._get_or_create_collection()
+
+    @staticmethod
+    def _generate_query(where: dict):
+        must_fields = []
+        for key, value in where.items():
+            must_fields.append(
+                models.FieldCondition(
+                    key=f"metadata.{key}",
+                    match=models.MatchValue(
+                        value=value,
+                    ),
+                )
+            )
+        return models.Filter(must=must_fields)
+
+    def delete(self, where: dict):
+        db_filter = self._generate_query(where)
+        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)

+ 86 - 24
embedchain/vectordb/weaviate.py

@@ -45,6 +45,9 @@ class WeaviateDB(BaseVectorDB):
             auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
             **self.config.extra_params,
         )
+        # Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
+        # This is needed to filter data while querying.
+        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
 
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
@@ -58,7 +61,6 @@ class WeaviateDB(BaseVectorDB):
             raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
 
         self.index_name = self._get_index_name()
-        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
         if not self.client.schema.exists(self.index_name):
             # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
             # The none vectorizer is crucial as we have our own custom embedding function
@@ -127,29 +129,64 @@ class WeaviateDB(BaseVectorDB):
         :return: ids
         :rtype: Set[str]
         """
+        weaviate_where_operands = []
+
+        if ids:
+            for doc_id in ids:
+                weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id})
+
+        keys = set(where.keys() if where is not None else set())
+        if len(keys) > 0:
+            for key in keys:
+                weaviate_where_operands.append(
+                    {
+                        "path": ["metadata", self.index_name + "_metadata", key],
+                        "operator": "Equal",
+                        "valueText": where.get(key),
+                    }
+                )
 
-        if ids is None or len(ids) == 0:
-            return {"ids": []}
+        if len(weaviate_where_operands) == 1:
+            weaviate_where_clause = weaviate_where_operands[0]
+        else:
+            weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
 
         existing_ids = []
+        metadatas = []
         cursor = None
+        offset = 0
         has_iterated_once = False
+        query_metadata_keys = self.metadata_keys.union(keys)
         while cursor is not None or not has_iterated_once:
             has_iterated_once = True
-            results = self._query_with_cursor(
-                self.client.query.get(self.index_name, ["identifier"])
+            results = self._query_with_offset(
+                self.client.query.get(
+                    self.index_name,
+                    [
+                        "identifier",
+                        weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)),
+                    ],
+                )
+                .with_where(weaviate_where_clause)
                 .with_additional(["id"])
-                .with_limit(self.BATCH_SIZE),
-                cursor,
+                .with_limit(limit or self.BATCH_SIZE),
+                offset,
             )
+
             fetched_results = results["data"]["Get"].get(self.index_name, [])
-            if len(fetched_results) == 0:
+            if not fetched_results:
                 break
+
             for result in fetched_results:
                 existing_ids.append(result["identifier"])
+                metadatas.append(result["metadata"][0])
                 cursor = result["_additional"]["id"]
+                offset += 1
 
-        return {"ids": existing_ids}
+            if limit is not None and len(existing_ids) >= limit:
+                break
+
+        return {"ids": existing_ids, "metadatas": metadatas}
 
     def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
         """add data in vector database
@@ -201,21 +238,20 @@ class WeaviateDB(BaseVectorDB):
         query_vector = self.embedder.embedding_fn([input_query])[0]
         keys = set(where.keys() if where is not None else set())
         data_fields = ["text"]
-
+        query_metadata_keys = self.metadata_keys.union(keys)
         if citations:
-            data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
+            data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)))
 
-        if len(keys.intersection(self.metadata_keys)) != 0:
+        if len(keys) > 0:
             weaviate_where_operands = []
             for key in keys:
-                if key in self.metadata_keys:
-                    weaviate_where_operands.append(
-                        {
-                            "path": ["metadata", self.index_name + "_metadata", key],
-                            "operator": "Equal",
-                            "valueText": where.get(key),
-                        }
-                    )
+                weaviate_where_operands.append(
+                    {
+                        "path": ["metadata", self.index_name + "_metadata", key],
+                        "operator": "Equal",
+                        "valueText": where.get(key),
+                    }
+                )
             if len(weaviate_where_operands) == 1:
                 weaviate_where_clause = weaviate_where_operands[0]
             else:
@@ -289,11 +325,37 @@ class WeaviateDB(BaseVectorDB):
         :return: Weaviate index
         :rtype: str
         """
-        return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
+        return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")
 
     @staticmethod
-    def _query_with_cursor(query, cursor):
-        if cursor is not None:
-            query.with_after(cursor)
+    def _query_with_offset(query, offset):
+        if offset:
+            query.with_offset(offset)
         results = query.do()
         return results
+
+    def _generate_query(self, where: dict):
+        weaviate_where_operands = []
+        for key, value in where.items():
+            weaviate_where_operands.append(
+                {
+                    "path": ["metadata", self.index_name + "_metadata", key],
+                    "operator": "Equal",
+                    "valueText": value,
+                }
+            )
+
+        if len(weaviate_where_operands) == 1:
+            weaviate_where_clause = weaviate_where_operands[0]
+        else:
+            weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
+
+        return weaviate_where_clause
+
+    def delete(self, where: dict):
+        """Delete from database.
+        :param where: to filter data
+        :type where: dict[str, any]
+        """
+        query = self._generate_query(where)
+        self.client.batch.delete_objects(self.index_name, where=query)

+ 3 - 3
tests/vectordb/test_qdrant.py

@@ -56,9 +56,9 @@ class TestQdrantDB(unittest.TestCase):
         App(config=app_config, db=db, embedding_model=embedder)
 
         resp = db.get(ids=[], where={})
-        self.assertEqual(resp, {"ids": []})
+        self.assertEqual(resp, {"ids": [], "metadatas": []})
         resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
-        self.assertEqual(resp2, {"ids": []})
+        self.assertEqual(resp2, {"ids": [], "metadatas": []})
 
     @patch("embedchain.vectordb.qdrant.QdrantClient")
     @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
@@ -119,7 +119,7 @@ class TestQdrantDB(unittest.TestCase):
             query_filter=models.Filter(
                 must=[
                     models.FieldCondition(
-                        key="payload.metadata.doc_id",
+                        key="metadata.doc_id",
                         match=models.MatchValue(
                             value="123",
                         ),