123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- import pytest
- from embedchain.config.vectordb.pinecone import PineconeDBConfig
- from embedchain.vectordb.pinecone import PineconeDB
- @pytest.fixture
- def pinecone_pod_config():
- return PineconeDBConfig(
- collection_name="test_collection",
- api_key="test_api_key",
- vector_dimension=3,
- pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
- )
- @pytest.fixture
- def pinecone_serverless_config():
- return PineconeDBConfig(
- collection_name="test_collection",
- api_key="test_api_key",
- vector_dimension=3,
- serverless_config={
- "cloud": "test_cloud",
- "region": "test_region",
- },
- )
- def test_pinecone_init_without_config(monkeypatch):
- monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
- pinecone_db = PineconeDB()
- assert isinstance(pinecone_db, PineconeDB)
- assert isinstance(pinecone_db.config, PineconeDBConfig)
- assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
- monkeypatch.delenv("PINECONE_API_KEY")
- def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
- pinecone_db = PineconeDB(config=pinecone_pod_config)
- assert isinstance(pinecone_db, PineconeDB)
- assert isinstance(pinecone_db.config, PineconeDBConfig)
- assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
- pinecone_db = PineconeDB(config=pinecone_pod_config)
- assert isinstance(pinecone_db, PineconeDB)
- assert isinstance(pinecone_db.config, PineconeDBConfig)
- assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config
- class MockListIndexes:
- def names(self):
- return ["test_collection"]
- class MockPineconeIndex:
- db = []
- def __init__(*args, **kwargs):
- pass
- def upsert(self, chunk, **kwargs):
- self.db.extend([c for c in chunk])
- return
- def delete(self, *args, **kwargs):
- pass
- def query(self, *args, **kwargs):
- return {
- "matches": [
- {
- "metadata": {
- "key": "value",
- "text": "text_1",
- },
- "score": 0.1,
- },
- {
- "metadata": {
- "key": "value",
- "text": "text_2",
- },
- "score": 0.2,
- },
- ]
- }
- def fetch(self, *args, **kwargs):
- return {
- "vectors": {
- "key_1": {
- "metadata": {
- "source": "1",
- }
- },
- "key_2": {
- "metadata": {
- "source": "2",
- }
- },
- }
- }
- def describe_index_stats(self, *args, **kwargs):
- return {"total_vector_count": len(self.db)}
- class MockPineconeClient:
- def __init__(*args, **kwargs):
- pass
- def list_indexes(self):
- return MockListIndexes()
- def create_index(self, *args, **kwargs):
- pass
- def Index(self, *args, **kwargs):
- return MockPineconeIndex()
- def delete_index(self, *args, **kwargs):
- pass
- class MockPinecone:
- def __init__(*args, **kwargs):
- pass
- def Pinecone(*args, **kwargs):
- return MockPineconeClient()
- def PodSpec(*args, **kwargs):
- pass
- def ServerlessSpec(*args, **kwargs):
- pass
- class MockEmbedder:
- def embedding_fn(self, documents):
- return [[1, 1, 1] for d in documents]
- def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
- monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
- monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
- pinecone_db = PineconeDB(config=pinecone_pod_config)
- pinecone_db._setup_pinecone_index()
- assert pinecone_db.client is not None
- assert pinecone_db.config.index_name == "test-collection-3"
- assert pinecone_db.client.list_indexes().names() == ["test_collection"]
- assert pinecone_db.pinecone_index is not None
- pinecone_db = PineconeDB(config=pinecone_serverless_config)
- pinecone_db._setup_pinecone_index()
- assert pinecone_db.client is not None
- assert pinecone_db.config.index_name == "test-collection-3"
- assert pinecone_db.client.list_indexes().names() == ["test_collection"]
- assert pinecone_db.pinecone_index is not None
- def test_get(monkeypatch):
- def mock_pinecone_db():
- monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
- db = PineconeDB()
- db.pinecone_index = MockPineconeIndex()
- return db
- pinecone_db = mock_pinecone_db()
- ids = pinecone_db.get(["key_1", "key_2"])
- assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
- def test_add(monkeypatch):
- def mock_pinecone_db():
- monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
- db = PineconeDB()
- db.pinecone_index = MockPineconeIndex()
- db._set_embedder(MockEmbedder())
- return db
- pinecone_db = mock_pinecone_db()
- pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
- assert pinecone_db.count() == 2
- pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
- assert pinecone_db.count() == 4
- def test_query(monkeypatch):
- def mock_pinecone_db():
- monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
- monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
- db = PineconeDB()
- db.pinecone_index = MockPineconeIndex()
- db._set_embedder(MockEmbedder())
- return db
- pinecone_db = mock_pinecone_db()
- # without citations
- results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
- assert results == ["text_1", "text_2"]
- # with citations
- results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
- assert results == [
- ("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
- ("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
- ]
|