Explorar o código

[Bugfix] fix pinecone db (#1185)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel hai 1 ano
pai
achega
1a654beea4
Modificáronse 1 ficheiros con 37 adicións e 4 borrados
  1. 37 4
      embedchain/vectordb/pinecone.py

+ 37 - 4
embedchain/vectordb/pinecone.py

@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import Optional, Union
 
@@ -79,12 +80,20 @@ class PineconeDB(BaseVectorDB):
         :rtype: Set[str]
         """
         existing_ids = list()
+        metadatas = []
+
         if ids is not None:
             for i in range(0, len(ids), 1000):
                 result = self.client.fetch(ids=ids[i : i + 1000])
-                batch_existing_ids = list(result.get("vectors").keys())
+                vectors = result.get("vectors")
+                batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
-        return {"ids": existing_ids}
+                metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
+
+        if where is not None:
+            logging.warning("Filtering is not supported by Pinecone")
+
+        return {"ids": existing_ids, "metadatas": metadatas}
 
     def add(
         self,
@@ -114,7 +123,7 @@ class PineconeDB(BaseVectorDB):
                 }
             )
 
-        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."):
+        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
             self.client.upsert(chunk, **kwargs)
 
     def query(
@@ -140,7 +149,10 @@ class PineconeDB(BaseVectorDB):
         :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
-        data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
+        query_filter = self._generate_filter(where)
+        data = self.client.query(
+            vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
+        )
         contexts = []
         for doc in data["matches"]:
             metadata = doc["metadata"]
@@ -192,3 +204,24 @@ class PineconeDB(BaseVectorDB):
         :rtype: str
         """
         return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
+
+    @staticmethod
+    def _generate_filter(where: dict):
+        """Build a Pinecone `$eq` metadata filter from `where`; returns {} when `where` is None or empty."""
+        if not where:
+            return {}
+        return {k: {"$eq": v} for k, v in where.items()}
+
+    def delete(self, where: dict):
+        """Delete vectors whose metadata matches `where`; failures are logged, not raised.
+        :param where: metadata key/value pairs, converted to `$eq` filters
+        :type where: dict
+        """
+        # Deleting with filters is not supported for `starter` index type.
+        # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
+        db_filter = self._generate_filter(where)
+        try:
+            self.client.delete(filter=db_filter)
+        except Exception as e:
+            # Best-effort: report the failure through logging, as the rest of this module does.
+            logging.error(f"Failed to delete from Pinecone: {e}")