|
@@ -1,139 +1,225 @@
|
|
|
-from unittest import mock
|
|
|
-from unittest.mock import patch
|
|
|
+import pytest
|
|
|
|
|
|
-from embedchain import App
|
|
|
-from embedchain.config import AppConfig
|
|
|
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
|
|
-from embedchain.embedder.base import BaseEmbedder
|
|
|
from embedchain.vectordb.pinecone import PineconeDB
|
|
|
|
|
|
|
|
|
-class TestPinecone:
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_init(self, pinecone_mock):
|
|
|
- """Test that the PineconeDB can be initialized."""
|
|
|
- # Create a PineconeDB instance
|
|
|
- PineconeDB()
|
|
|
+@pytest.fixture
|
|
|
+def pinecone_pod_config():
|
|
|
+ return PineconeDBConfig(
|
|
|
+ collection_name="test_collection",
|
|
|
+ api_key="test_api_key",
|
|
|
+ vector_dimension=3,
|
|
|
+ pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def pinecone_serverless_config():
|
|
|
+ return PineconeDBConfig(
|
|
|
+ collection_name="test_collection",
|
|
|
+ api_key="test_api_key",
|
|
|
+ vector_dimension=3,
|
|
|
+ serverless_config={
|
|
|
+ "cloud": "test_cloud",
|
|
|
+ "region": "test_region",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def test_pinecone_init_without_config(monkeypatch):
|
|
|
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
|
|
+ pinecone_db = PineconeDB()
|
|
|
+
|
|
|
+ assert isinstance(pinecone_db, PineconeDB)
|
|
|
+ assert isinstance(pinecone_db.config, PineconeDBConfig)
|
|
|
+ assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
|
|
|
+ monkeypatch.delenv("PINECONE_API_KEY")
|
|
|
|
|
|
- # Assert that the Pinecone client was initialized
|
|
|
- pinecone_mock.init.assert_called_once()
|
|
|
- pinecone_mock.list_indexes.assert_called_once()
|
|
|
- pinecone_mock.Index.assert_called_once()
|
|
|
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_set_embedder(self, pinecone_mock):
|
|
|
- """Test that the embedder can be set."""
|
|
|
+def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
|
|
|
+    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
|
|
+    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
|
|
+    pinecone_db = PineconeDB(config=pinecone_pod_config)
|
|
|
|
|
|
- # Set the embedder
|
|
|
- embedder = BaseEmbedder()
|
|
|
+    assert isinstance(pinecone_db, PineconeDB)
|
|
|
+    assert isinstance(pinecone_db.config, PineconeDBConfig)
|
|
|
|
|
|
- # Create a PineconeDB instance
|
|
|
- db = PineconeDB()
|
|
|
- app_config = AppConfig(collect_metrics=False)
|
|
|
- App(config=app_config, db=db, embedding_model=embedder)
|
|
|
+    assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
|
|
|
+    # Serverless variant: build from the serverless fixture and check that its settings are kept.
|
|
|
+    pinecone_db = PineconeDB(config=pinecone_serverless_config)
|
|
|
+
|
|
|
+    assert isinstance(pinecone_db, PineconeDB)
|
|
|
+    assert isinstance(pinecone_db.config, PineconeDBConfig)
|
|
|
+
|
|
|
+    assert pinecone_db.config.serverless_config == pinecone_serverless_config.serverless_config
|
|
|
+
|
|
|
+
|
|
|
+class MockListIndexes:
|
|
|
+ def names(self):
|
|
|
+ return ["test_collection"]
|
|
|
+
|
|
|
+
|
|
|
+class MockPineconeIndex:  # in-memory stand-in for the Pinecone index object used by PineconeDB
|
|
|
+    # `db` is created per instance in __init__ so separate mocks never share upserted items.
|
|
|
+
|
|
|
+    def __init__(self, *args, **kwargs):
|
|
|
+        self.db = []
|
|
|
+
|
|
|
+    def upsert(self, chunk, **kwargs):  # record every item of the batch
|
|
|
+        self.db.extend(chunk)
|
|
|
+        return
|
|
|
+
|
|
|
+    def delete(self, *args, **kwargs):  # accepted and ignored
|
|
|
+        pass
|
|
|
+
|
|
|
+    def query(self, *args, **kwargs):  # always returns the same two canned matches
|
|
|
+        return {
|
|
|
+            "matches": [
|
|
|
+                {
|
|
|
+                    "metadata": {
|
|
|
+                        "key": "value",
|
|
|
+                        "text": "text_1",
|
|
|
+                    },
|
|
|
+                    "score": 0.1,
|
|
|
+                },
|
|
|
+                {
|
|
|
+                    "metadata": {
|
|
|
+                        "key": "value",
|
|
|
+                        "text": "text_2",
|
|
|
+                    },
|
|
|
+                    "score": 0.2,
|
|
|
+                },
|
|
|
+            ]
|
|
|
+        }
|
|
|
+
|
|
|
+    def fetch(self, *args, **kwargs):  # always returns the same two canned vectors
|
|
|
+        return {
|
|
|
+            "vectors": {
|
|
|
+                "key_1": {
|
|
|
+                    "metadata": {
|
|
|
+                        "source": "1",
|
|
|
+                    }
|
|
|
+                },
|
|
|
+                "key_2": {
|
|
|
+                    "metadata": {
|
|
|
+                        "source": "2",
|
|
|
+                    }
|
|
|
+                },
|
|
|
+            }
|
|
|
+        }
|
|
|
+
|
|
|
+    def describe_index_stats(self, *args, **kwargs):  # count of items upserted into this instance
|
|
|
+        return {"total_vector_count": len(self.db)}
|
|
|
+
|
|
|
+
|
|
|
+class MockPineconeClient:
|
|
|
+ def __init__(*args, **kwargs):
|
|
|
+ pass
|
|
|
|
|
|
- # Assert that the embedder was set
|
|
|
- assert db.embedder == embedder
|
|
|
- pinecone_mock.init.assert_called_once()
|
|
|
+ def list_indexes(self):
|
|
|
+ return MockListIndexes()
|
|
|
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_add_documents(self, pinecone_mock):
|
|
|
- """Test that documents can be added to the database."""
|
|
|
- pinecone_client_mock = pinecone_mock.Index.return_value
|
|
|
+ def create_index(self, *args, **kwargs):
|
|
|
+ pass
|
|
|
|
|
|
- embedding_function = mock.Mock()
|
|
|
- base_embedder = BaseEmbedder()
|
|
|
- base_embedder.set_embedding_fn(embedding_function)
|
|
|
- embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
|
|
|
+ def Index(self, *args, **kwargs):
|
|
|
+ return MockPineconeIndex()
|
|
|
|
|
|
- # Create a PineconeDb instance
|
|
|
+ def delete_index(self, *args, **kwargs):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class MockPinecone:
|
|
|
+ def __init__(*args, **kwargs):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def Pinecone(*args, **kwargs):
|
|
|
+ return MockPineconeClient()
|
|
|
+
|
|
|
+ def PodSpec(*args, **kwargs):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def ServerlessSpec(*args, **kwargs):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class MockEmbedder:
|
|
|
+ def embedding_fn(self, documents):
|
|
|
+ return [[1, 1, 1] for d in documents]
|
|
|
+
|
|
|
+
|
|
|
+def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
|
|
|
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
|
|
+ pinecone_db = PineconeDB(config=pinecone_pod_config)
|
|
|
+ pinecone_db._setup_pinecone_index()
|
|
|
+
|
|
|
+ assert pinecone_db.client is not None
|
|
|
+ assert pinecone_db.config.index_name == "test-collection-3"
|
|
|
+ assert pinecone_db.client.list_indexes().names() == ["test_collection"]
|
|
|
+ assert pinecone_db.pinecone_index is not None
|
|
|
+
|
|
|
+ pinecone_db = PineconeDB(config=pinecone_serverless_config)
|
|
|
+ pinecone_db._setup_pinecone_index()
|
|
|
+
|
|
|
+ assert pinecone_db.client is not None
|
|
|
+ assert pinecone_db.config.index_name == "test-collection-3"
|
|
|
+ assert pinecone_db.client.list_indexes().names() == ["test_collection"]
|
|
|
+ assert pinecone_db.pinecone_index is not None
|
|
|
+
|
|
|
+
|
|
|
+def test_get(monkeypatch):
|
|
|
+ def mock_pinecone_db():
|
|
|
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
|
|
db = PineconeDB()
|
|
|
- app_config = AppConfig(collect_metrics=False)
|
|
|
- App(config=app_config, db=db, embedding_model=base_embedder)
|
|
|
-
|
|
|
- # Add some documents to the database
|
|
|
- documents = ["This is a document.", "This is another document."]
|
|
|
- metadatas = [{}, {}]
|
|
|
- ids = ["doc1", "doc2"]
|
|
|
- db.add(documents, metadatas, ids)
|
|
|
-
|
|
|
- expected_pinecone_upsert_args = [
|
|
|
- {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
|
|
- {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
|
|
|
- ]
|
|
|
- # Assert that the Pinecone client was called to upsert the documents
|
|
|
- pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
|
|
|
-
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_query_documents(self, pinecone_mock):
|
|
|
- """Test that documents can be queried from the database."""
|
|
|
- pinecone_client_mock = pinecone_mock.Index.return_value
|
|
|
-
|
|
|
- embedding_function = mock.Mock()
|
|
|
- base_embedder = BaseEmbedder()
|
|
|
- base_embedder.set_embedding_fn(embedding_function)
|
|
|
- vectors = [[0, 0, 0]]
|
|
|
- embedding_function.return_value = vectors
|
|
|
- # Create a PineconeDB instance
|
|
|
+ db.pinecone_index = MockPineconeIndex()
|
|
|
+ return db
|
|
|
+
|
|
|
+ pinecone_db = mock_pinecone_db()
|
|
|
+ ids = pinecone_db.get(["key_1", "key_2"])
|
|
|
+ assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
|
|
|
+
|
|
|
+
|
|
|
+def test_add(monkeypatch):
|
|
|
+ def mock_pinecone_db():
|
|
|
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
|
|
db = PineconeDB()
|
|
|
- app_config = AppConfig(collect_metrics=False)
|
|
|
- App(config=app_config, db=db, embedding_model=base_embedder)
|
|
|
-
|
|
|
- # Query the database for documents that are similar to "document"
|
|
|
- input_query = ["document"]
|
|
|
- n_results = 1
|
|
|
- db.query(input_query, n_results, where={})
|
|
|
-
|
|
|
- # Assert that the Pinecone client was called to query the database
|
|
|
- pinecone_client_mock.query.assert_called_once_with(
|
|
|
- vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
|
|
|
- )
|
|
|
-
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_reset(self, pinecone_mock):
|
|
|
- """Test that the database can be reset."""
|
|
|
- # Create a PineconeDb instance
|
|
|
+ db.pinecone_index = MockPineconeIndex()
|
|
|
+ db._set_embedder(MockEmbedder())
|
|
|
+ return db
|
|
|
+
|
|
|
+ pinecone_db = mock_pinecone_db()
|
|
|
+ pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
|
|
|
+ assert pinecone_db.count() == 2
|
|
|
+
|
|
|
+ pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
|
|
|
+ assert pinecone_db.count() == 4
|
|
|
+
|
|
|
+
|
|
|
+def test_query(monkeypatch):
|
|
|
+ def mock_pinecone_db():
|
|
|
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
|
|
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
|
|
db = PineconeDB()
|
|
|
- app_config = AppConfig(collect_metrics=False)
|
|
|
- App(config=app_config, db=db, embedding_model=BaseEmbedder())
|
|
|
-
|
|
|
- # Reset the database
|
|
|
- db.reset()
|
|
|
-
|
|
|
- # Assert that the Pinecone client was called to delete the index
|
|
|
- pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
|
|
|
-
|
|
|
- # Assert that the index is recreated
|
|
|
- pinecone_mock.Index.assert_called_with(db.config.index_name)
|
|
|
-
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_custom_index_name_if_it_exists(self, pinecone_mock):
|
|
|
- """Tests custom index name is used if it exists"""
|
|
|
- pinecone_mock.list_indexes.return_value = ["custom_index_name"]
|
|
|
- db_config = PineconeDBConfig(index_name="custom_index_name")
|
|
|
- _ = PineconeDB(config=db_config)
|
|
|
-
|
|
|
- pinecone_mock.list_indexes.assert_called_once()
|
|
|
- pinecone_mock.create_index.assert_not_called()
|
|
|
- pinecone_mock.Index.assert_called_with("custom_index_name")
|
|
|
-
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_custom_index_name_creation(self, pinecone_mock):
|
|
|
- """Test custom index name is created if it doesn't exists already"""
|
|
|
- pinecone_mock.list_indexes.return_value = []
|
|
|
- db_config = PineconeDBConfig(index_name="custom_index_name")
|
|
|
- _ = PineconeDB(config=db_config)
|
|
|
-
|
|
|
- pinecone_mock.list_indexes.assert_called_once()
|
|
|
- pinecone_mock.create_index.assert_called_once()
|
|
|
- pinecone_mock.Index.assert_called_with("custom_index_name")
|
|
|
-
|
|
|
- @patch("embedchain.vectordb.pinecone.pinecone")
|
|
|
- def test_default_index_name_is_used(self, pinecone_mock):
|
|
|
- """Test default index name is used if custom index name is not provided"""
|
|
|
- db_config = PineconeDBConfig(collection_name="my-collection")
|
|
|
- _ = PineconeDB(config=db_config)
|
|
|
-
|
|
|
- pinecone_mock.list_indexes.assert_called_once()
|
|
|
- pinecone_mock.create_index.assert_called_once()
|
|
|
- pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
|
|
|
+ db.pinecone_index = MockPineconeIndex()
|
|
|
+ db._set_embedder(MockEmbedder())
|
|
|
+ return db
|
|
|
+
|
|
|
+ pinecone_db = mock_pinecone_db()
|
|
|
+ # without citations
|
|
|
+ results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
|
|
|
+ assert results == ["text_1", "text_2"]
|
|
|
+ # with citations
|
|
|
+ results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
|
|
|
+ assert results == [
|
|
|
+ ("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
|
|
|
+ ("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
|
|
|
+ ]
|