test_pinecone.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import pytest
  2. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  3. from embedchain.vectordb.pinecone import PineconeDB
  4. @pytest.fixture
  5. def pinecone_pod_config():
  6. return PineconeDBConfig(
  7. collection_name="test_collection",
  8. api_key="test_api_key",
  9. vector_dimension=3,
  10. pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
  11. )
  12. @pytest.fixture
  13. def pinecone_serverless_config():
  14. return PineconeDBConfig(
  15. collection_name="test_collection",
  16. api_key="test_api_key",
  17. vector_dimension=3,
  18. serverless_config={
  19. "cloud": "test_cloud",
  20. "region": "test_region",
  21. },
  22. )
  23. def test_pinecone_init_without_config(monkeypatch):
  24. monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
  25. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
  26. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
  27. pinecone_db = PineconeDB()
  28. assert isinstance(pinecone_db, PineconeDB)
  29. assert isinstance(pinecone_db.config, PineconeDBConfig)
  30. assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
  31. monkeypatch.delenv("PINECONE_API_KEY")
  32. def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
  33. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
  34. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
  35. pinecone_db = PineconeDB(config=pinecone_pod_config)
  36. assert isinstance(pinecone_db, PineconeDB)
  37. assert isinstance(pinecone_db.config, PineconeDBConfig)
  38. assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
  39. pinecone_db = PineconeDB(config=pinecone_pod_config)
  40. assert isinstance(pinecone_db, PineconeDB)
  41. assert isinstance(pinecone_db.config, PineconeDBConfig)
  42. assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config
  43. class MockListIndexes:
  44. def names(self):
  45. return ["test_collection"]
  46. class MockPineconeIndex:
  47. db = []
  48. def __init__(*args, **kwargs):
  49. pass
  50. def upsert(self, chunk, **kwargs):
  51. self.db.extend([c for c in chunk])
  52. return
  53. def delete(self, *args, **kwargs):
  54. pass
  55. def query(self, *args, **kwargs):
  56. return {
  57. "matches": [
  58. {
  59. "metadata": {
  60. "key": "value",
  61. "text": "text_1",
  62. },
  63. "score": 0.1,
  64. },
  65. {
  66. "metadata": {
  67. "key": "value",
  68. "text": "text_2",
  69. },
  70. "score": 0.2,
  71. },
  72. ]
  73. }
  74. def fetch(self, *args, **kwargs):
  75. return {
  76. "vectors": {
  77. "key_1": {
  78. "metadata": {
  79. "source": "1",
  80. }
  81. },
  82. "key_2": {
  83. "metadata": {
  84. "source": "2",
  85. }
  86. },
  87. }
  88. }
  89. def describe_index_stats(self, *args, **kwargs):
  90. return {"total_vector_count": len(self.db)}
  91. class MockPineconeClient:
  92. def __init__(*args, **kwargs):
  93. pass
  94. def list_indexes(self):
  95. return MockListIndexes()
  96. def create_index(self, *args, **kwargs):
  97. pass
  98. def Index(self, *args, **kwargs):
  99. return MockPineconeIndex()
  100. def delete_index(self, *args, **kwargs):
  101. pass
  102. class MockPinecone:
  103. def __init__(*args, **kwargs):
  104. pass
  105. def Pinecone(*args, **kwargs):
  106. return MockPineconeClient()
  107. def PodSpec(*args, **kwargs):
  108. pass
  109. def ServerlessSpec(*args, **kwargs):
  110. pass
  111. class MockEmbedder:
  112. def embedding_fn(self, documents):
  113. return [[1, 1, 1] for d in documents]
  114. def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
  115. monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
  116. monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
  117. pinecone_db = PineconeDB(config=pinecone_pod_config)
  118. pinecone_db._setup_pinecone_index()
  119. assert pinecone_db.client is not None
  120. assert pinecone_db.config.index_name == "test-collection-3"
  121. assert pinecone_db.client.list_indexes().names() == ["test_collection"]
  122. assert pinecone_db.pinecone_index is not None
  123. pinecone_db = PineconeDB(config=pinecone_serverless_config)
  124. pinecone_db._setup_pinecone_index()
  125. assert pinecone_db.client is not None
  126. assert pinecone_db.config.index_name == "test-collection-3"
  127. assert pinecone_db.client.list_indexes().names() == ["test_collection"]
  128. assert pinecone_db.pinecone_index is not None
  129. def test_get(monkeypatch):
  130. def mock_pinecone_db():
  131. monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
  132. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
  133. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
  134. db = PineconeDB()
  135. db.pinecone_index = MockPineconeIndex()
  136. return db
  137. pinecone_db = mock_pinecone_db()
  138. ids = pinecone_db.get(["key_1", "key_2"])
  139. assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
  140. def test_add(monkeypatch):
  141. def mock_pinecone_db():
  142. monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
  143. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
  144. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
  145. db = PineconeDB()
  146. db.pinecone_index = MockPineconeIndex()
  147. db._set_embedder(MockEmbedder())
  148. return db
  149. pinecone_db = mock_pinecone_db()
  150. pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
  151. assert pinecone_db.count() == 2
  152. pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
  153. assert pinecone_db.count() == 4
  154. def test_query(monkeypatch):
  155. def mock_pinecone_db():
  156. monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
  157. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
  158. monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
  159. db = PineconeDB()
  160. db.pinecone_index = MockPineconeIndex()
  161. db._set_embedder(MockEmbedder())
  162. return db
  163. pinecone_db = mock_pinecone_db()
  164. # without citations
  165. results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
  166. assert results == ["text_1", "text_2"]
  167. # with citations
  168. results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
  169. assert results == [
  170. ("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
  171. ("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
  172. ]