Bläddra i källkod

[Feature] Pinecone Vector DB support (#723)

Rupesh Bansal 1 år sedan
förälder
incheckning
a7a61fae1d

+ 18 - 0
embedchain/config/vectordb/pinecone.py

@@ -0,0 +1,18 @@
+from typing import Optional
+
+from embedchain.config.vectordb.base import BaseVectorDbConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class PineconeDbConfig(BaseVectorDbConfig):
+    def __init__(
+        self,
+        collection_name: Optional[str] = None,
+        dir: Optional[str] = None,
+        dimension: Optional[int] = 1536,
+        metric: Optional[str] = "cosine",
+    ):
+        self.dimension = dimension
+        self.metric = metric
+        super().__init__(collection_name=collection_name, dir=dir)

+ 0 - 2
embedchain/embedchain.py

@@ -339,7 +339,6 @@ class EmbedChain(JSONSerializable):
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         new_doc_id = embeddings_data["doc_id"]
-
         if existing_doc_id and existing_doc_id == new_doc_id:
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0
@@ -404,7 +403,6 @@ class EmbedChain(JSONSerializable):
             skip_embedding=(chunker.data_type == DataType.IMAGES),
         )
         count_new_chunks = self.db.count() - chunks_before_addition
-        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
 
     def _format_result(self, results):

+ 2 - 4
embedchain/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional
 
 try:
     from elasticsearch import Elasticsearch
@@ -74,9 +74,7 @@ class ElasticsearchDB(BaseVectorDB):
     def _get_or_create_collection(self, name):
         """Note: nothing to return here. Discuss later"""
 
-    def get(
-        self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
-    ) -> Set[str]:
+    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 

+ 180 - 0
embedchain/vectordb/pineconedb.py

@@ -0,0 +1,180 @@
+import copy
+import os
+from typing import Dict, List, Optional
+
+try:
+    import pinecone
+except ImportError:
+    raise ImportError(
+        "Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
+    ) from None
+
+from embedchain.config.vectordb.pinecone import PineconeDbConfig
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.vectordb.base import BaseVectorDB
+
+
+@register_deserializable
+class PineconeDb(BaseVectorDB):
+    BATCH_SIZE = 100
+
+    """
+    Pinecone as vector database
+    """
+
+    def __init__(
+        self,
+        config: Optional[PineconeDbConfig] = None,
+    ):
+        """Pinecone as vector database.
+
+        :param config: Pinecone database config, defaults to None
+        :type config: PineconeDbConfig, optional
+        :raises ValueError: No config provided
+        """
+        if config is None:
+            self.config = PineconeDbConfig()
+        else:
+            if not isinstance(config, PineconeDbConfig):
+                raise TypeError(
+                    "config is not a `PineconeDbConfig` instance. "
+                    "Please make sure the type is right and that you are passing an instance."
+                )
+            self.config = config
+        self.client = self._setup_pinecone_index()
+        # Call parent init here because embedder is needed
+        super().__init__(config=self.config)
+
+    def _initialize(self):
+        """
+        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
+        """
+        if not self.embedder:
+            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
+
+    # Loads the Pinecone index or creates it if not present.
+    def _setup_pinecone_index(self):
+        pinecone.init(
+            api_key=os.environ.get("PINECONE_API_KEY"),
+            environment=os.environ.get("PINECONE_ENV"),
+        )
+        self.index_name = self._get_index_name()
+        indexes = pinecone.list_indexes()
+        if indexes is None or self.index_name not in indexes:
+            pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
+        return pinecone.Index(self.index_name)
+
+    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+        """
+        Get existing doc ids present in vector database
+
+        :param ids: _list of doc ids to check for existence
+        :type ids: List[str]
+        :param where: to filter data
+        :type where: Dict[str, any]
+        :return: ids
+        :rtype: Set[str]
+        """
+        existing_ids = list()
+        if ids is not None:
+            for i in range(0, len(ids), 1000):
+                result = self.client.fetch(ids=ids[i : i + 1000])
+                batch_existing_ids = list(result.get("vectors").keys())
+                existing_ids.extend(batch_existing_ids)
+
+        return {"ids": existing_ids}
+
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ):
+        """add data in vector database
+
+        :param documents: list of texts to add
+        :type documents: List[str]
+        :param metadatas: list of metadata associated with docs
+        :type metadatas: List[object]
+        :param ids: ids of docs
+        :type ids: List[str]
+        """
+        docs = []
+        if embeddings is None:
+            embeddings = self.embedder.embedding_fn(documents)
+        for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
+            metadata["text"] = text
+            docs.append(
+                {
+                    "id": id,
+                    "values": embedding,
+                    "metadata": copy.deepcopy(metadata),
+                }
+            )
+
+        for i in range(0, len(docs), self.BATCH_SIZE):
+            self.client.upsert(docs[i : i + self.BATCH_SIZE])
+
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+        """
+        query contents from vector database based on vector similarity
+
+        :param input_query: list of query string
+        :type input_query: List[str]
+        :param n_results: no of similar documents to fetch from database
+        :type n_results: int
+        :param where: Optional. to filter data
+        :type where: Dict[str, any]
+        :return: Database contents that are the result of the query
+        :rtype: List[str]
+        """
+        if not skip_embedding:
+            query_vector = self.embedder.embedding_fn([input_query])[0]
+        else:
+            query_vector = input_query
+        contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
+        embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
+        return embeddings
+
+    def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        :param name: Name of the collection.
+        :type name: str
+        """
+        if not isinstance(name, str):
+            raise TypeError("Collection name must be a string")
+        self.config.collection_name = name
+
+    def count(self) -> int:
+        """
+        Count number of documents/chunks embedded in the database.
+
+        :return: number of documents
+        :rtype: int
+        """
+        return self.client.describe_index_stats()["total_vector_count"]
+
+    def _get_or_create_db(self):
+        """Called during initialization"""
+        return self.client
+
+    def reset(self):
+        """
+        Resets the database. Deletes all embeddings irreversibly.
+        """
+        # Delete all data from the database
+        pinecone.delete_index(self.index_name)
+        self._setup_pinecone_index()
+
+    # Pinecone only allows alphanumeric characters and "-" in the index name
+    def _get_index_name(self) -> str:
+        """Get the Pinecone index for a collection
+
+        :return: Pinecone index
+        :rtype: str
+        """
+        return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")

+ 2 - 0
pyproject.toml

@@ -112,6 +112,7 @@ discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 cohere = { version = "^4.27", optional= true }
 docx2txt = "^0.8"
+pinecone-client = "^2.2.4"
 unstructured = {extras = ["local-inference"], version = "^0.10.18"}
 pillow = { version = "10.0.1", optional = true }
 torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
@@ -142,6 +143,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
+pinecone = ["pinecone-client"]
 images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
 huggingface_hub=["huggingface_hub"]
 cohere = ["cohere"]

+ 106 - 0
tests/vectordb/test_pinecone_db.py

@@ -0,0 +1,106 @@
+from unittest import mock
+from unittest.mock import patch
+
+from embedchain import App
+from embedchain.config import AppConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.vectordb.pineconedb import PineconeDb
+
+
+class TestPineconeDb:
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_init(self, pinecone_mock):
+        """Test that the PineconeDb can be initialized."""
+        # Create a PineconeDb instance
+        PineconeDb()
+
+        # Assert that the Pinecone client was initialized
+        pinecone_mock.init.assert_called_once()
+        pinecone_mock.list_indexes.assert_called_once()
+        pinecone_mock.Index.assert_called_once()
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_set_embedder(self, pinecone_mock):
+        """Test that the embedder can be set."""
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        # Assert that the embedder was set
+        assert db.embedder == embedder
+        pinecone_mock.init.assert_called_once()
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_add_documents(self, pinecone_mock):
+        """Test that documents can be added to the database."""
+        pinecone_client_mock = pinecone_mock.Index.return_value
+
+        embedding_function = mock.Mock()
+        base_embedder = BaseEmbedder()
+        base_embedder.set_embedding_fn(embedding_function)
+        vectors = [[0, 0, 0], [1, 1, 1]]
+        embedding_function.return_value = vectors
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=base_embedder)
+
+        # Add some documents to the database
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc1", "doc2"]
+        db.add(vectors, documents, metadatas, ids, True)
+
+        expected_pinecone_upsert_args = [
+            {"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
+            {"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
+        ]
+        # Assert that the Pinecone client was called to upsert the documents
+        pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_query_documents(self, pinecone_mock):
+        """Test that documents can be queried from the database."""
+        pinecone_client_mock = pinecone_mock.Index.return_value
+
+        embedding_function = mock.Mock()
+        base_embedder = BaseEmbedder()
+        base_embedder.set_embedding_fn(embedding_function)
+        vectors = [[0, 0, 0]]
+        embedding_function.return_value = vectors
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=base_embedder)
+
+        # Query the database for documents that are similar to "document"
+        input_query = ["document"]
+        n_results = 1
+        db.query(input_query, n_results, where={}, skip_embedding=False)
+
+        # Assert that the Pinecone client was called to query the database
+        pinecone_client_mock.query.assert_called_once_with(
+            vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
+        )
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_reset(self, pinecone_mock):
+        """Test that the database can be reset."""
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=BaseEmbedder())
+
+        # Reset the database
+        db.reset()
+
+        # Assert that the Pinecone client was called to delete the index
+        pinecone_mock.delete_index.assert_called_once_with(db.index_name)
+
+        # Assert that the index is recreated
+        pinecone_mock.Index.assert_called_with(db.index_name)