Prechádzať zdrojové kódy

enable using custom Pinecone index name (#1172)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Peter Jausovec 1 rok pred
rodič
commit
446d0975aa

+ 5 - 0
docs/components/vector-databases.mdx

@@ -189,6 +189,11 @@ vectordb:
 
 </CodeGroup>
 
+<br />
+<Note>
+You can optionally provide `index_name` as a config param in the yaml file to specify the Pinecone index name. If not provided, the index name defaults to `{collection_name}-{vector_dimension}`, lowercased and with underscores replaced by hyphens.
+</Note>
+
 ## Qdrant
 
 In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/).

+ 2 - 0
embedchain/config/vectordb/pinecone.py

@@ -9,6 +9,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
     def __init__(
         self,
         collection_name: Optional[str] = None,
+        index_name: Optional[str] = None,
         dir: Optional[str] = None,
         vector_dimension: int = 1536,
         metric: Optional[str] = "cosine",
@@ -17,4 +18,5 @@ class PineconeDBConfig(BaseVectorDbConfig):
         self.metric = metric
         self.vector_dimension = vector_dimension
         self.extra_params = extra_params
+        self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
         super().__init__(collection_name=collection_name, dir=dir)

+ 7 - 15
embedchain/vectordb/pinecone.py

@@ -53,20 +53,21 @@ class PineconeDB(BaseVectorDB):
         if not self.embedder:
             raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
 
-    # Loads the Pinecone index or creates it if not present.
     def _setup_pinecone_index(self):
+        """
+        Loads the Pinecone index or creates it if not present.
+        """
         pinecone.init(
             api_key=os.environ.get("PINECONE_API_KEY"),
             environment=os.environ.get("PINECONE_ENV"),
             **self.config.extra_params,
         )
-        self.index_name = self._get_index_name()
         indexes = pinecone.list_indexes()
-        if indexes is None or self.index_name not in indexes:
+        if indexes is None or self.config.index_name not in indexes:
             pinecone.create_index(
-                name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
+                name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
             )
-        return pinecone.Index(self.index_name)
+        return pinecone.Index(self.config.index_name)
 
     def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
@@ -193,18 +194,9 @@ class PineconeDB(BaseVectorDB):
         Resets the database. Deletes all embeddings irreversibly.
         """
         # Delete all data from the database
-        pinecone.delete_index(self.index_name)
+        pinecone.delete_index(self.config.index_name)
         self._setup_pinecone_index()
 
-    # Pinecone only allows alphanumeric characters and "-" in the index name
-    def _get_index_name(self) -> str:
-        """Get the Pinecone index for a collection
-
-        :return: Pinecone index
-        :rtype: str
-        """
-        return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
-
     @staticmethod
     def _generate_filter(where: dict):
         query = {}

+ 35 - 2
tests/vectordb/test_pinecone.py

@@ -3,6 +3,7 @@ from unittest.mock import patch
 
 from embedchain import App
 from embedchain.config import AppConfig
+from embedchain.config.vectordb.pinecone import PineconeDBConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.vectordb.pinecone import PineconeDB
 
@@ -100,7 +101,39 @@ class TestPinecone:
         db.reset()
 
         # Assert that the Pinecone client was called to delete the index
-        pinecone_mock.delete_index.assert_called_once_with(db.index_name)
+        pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
 
         # Assert that the index is recreated
-        pinecone_mock.Index.assert_called_with(db.index_name)
+        pinecone_mock.Index.assert_called_with(db.config.index_name)
+
+    @patch("embedchain.vectordb.pinecone.pinecone")
+    def test_custom_index_name_if_it_exists(self, pinecone_mock):
+        """Tests custom index name is used if it exists"""
+        pinecone_mock.list_indexes.return_value = ["custom_index_name"]
+        db_config = PineconeDBConfig(index_name="custom_index_name")
+        _ = PineconeDB(config=db_config)
+
+        pinecone_mock.list_indexes.assert_called_once()
+        pinecone_mock.create_index.assert_not_called()
+        pinecone_mock.Index.assert_called_with("custom_index_name")
+
+    @patch("embedchain.vectordb.pinecone.pinecone")
+    def test_custom_index_name_creation(self, pinecone_mock):
+        """Test that an index with the custom name is created if it doesn't already exist"""
+        pinecone_mock.list_indexes.return_value = []
+        db_config = PineconeDBConfig(index_name="custom_index_name")
+        _ = PineconeDB(config=db_config)
+
+        pinecone_mock.list_indexes.assert_called_once()
+        pinecone_mock.create_index.assert_called_once()
+        pinecone_mock.Index.assert_called_with("custom_index_name")
+
+    @patch("embedchain.vectordb.pinecone.pinecone")
+    def test_default_index_name_is_used(self, pinecone_mock):
+        """Test default index name is used if custom index name is not provided"""
+        db_config = PineconeDBConfig(collection_name="my-collection")
+        _ = PineconeDB(config=db_config)
+
+        pinecone_mock.list_indexes.assert_called_once()
+        pinecone_mock.create_index.assert_called_once()
+        pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")