# test_pinecone.py (4.1 KB)
from unittest import mock
from unittest.mock import patch

from embedchain import App
from embedchain.config import AppConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pinecone import PineconeDB


  7. class TestPinecone:
  8. @patch("embedchain.vectordb.pinecone.pinecone")
  9. def test_init(self, pinecone_mock):
  10. """Test that the PineconeDB can be initialized."""
  11. # Create a PineconeDB instance
  12. PineconeDB()
  13. # Assert that the Pinecone client was initialized
  14. pinecone_mock.init.assert_called_once()
  15. pinecone_mock.list_indexes.assert_called_once()
  16. pinecone_mock.Index.assert_called_once()
  17. @patch("embedchain.vectordb.pinecone.pinecone")
  18. def test_set_embedder(self, pinecone_mock):
  19. """Test that the embedder can be set."""
  20. # Set the embedder
  21. embedder = BaseEmbedder()
  22. # Create a PineconeDB instance
  23. db = PineconeDB()
  24. app_config = AppConfig(collect_metrics=False)
  25. App(config=app_config, db=db, embedder=embedder)
  26. # Assert that the embedder was set
  27. assert db.embedder == embedder
  28. pinecone_mock.init.assert_called_once()
  29. @patch("embedchain.vectordb.pinecone.pinecone")
  30. def test_add_documents(self, pinecone_mock):
  31. """Test that documents can be added to the database."""
  32. pinecone_client_mock = pinecone_mock.Index.return_value
  33. embedding_function = mock.Mock()
  34. base_embedder = BaseEmbedder()
  35. base_embedder.set_embedding_fn(embedding_function)
  36. vectors = [[0, 0, 0], [1, 1, 1]]
  37. embedding_function.return_value = vectors
  38. # Create a PineconeDb instance
  39. db = PineconeDB()
  40. app_config = AppConfig(collect_metrics=False)
  41. App(config=app_config, db=db, embedder=base_embedder)
  42. # Add some documents to the database
  43. documents = ["This is a document.", "This is another document."]
  44. metadatas = [{}, {}]
  45. ids = ["doc1", "doc2"]
  46. db.add(vectors, documents, metadatas, ids, True)
  47. expected_pinecone_upsert_args = [
  48. {"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
  49. {"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
  50. ]
  51. # Assert that the Pinecone client was called to upsert the documents
  52. pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
  53. @patch("embedchain.vectordb.pinecone.pinecone")
  54. def test_query_documents(self, pinecone_mock):
  55. """Test that documents can be queried from the database."""
  56. pinecone_client_mock = pinecone_mock.Index.return_value
  57. embedding_function = mock.Mock()
  58. base_embedder = BaseEmbedder()
  59. base_embedder.set_embedding_fn(embedding_function)
  60. vectors = [[0, 0, 0]]
  61. embedding_function.return_value = vectors
  62. # Create a PineconeDB instance
  63. db = PineconeDB()
  64. app_config = AppConfig(collect_metrics=False)
  65. App(config=app_config, db=db, embedder=base_embedder)
  66. # Query the database for documents that are similar to "document"
  67. input_query = ["document"]
  68. n_results = 1
  69. db.query(input_query, n_results, where={}, skip_embedding=False)
  70. # Assert that the Pinecone client was called to query the database
  71. pinecone_client_mock.query.assert_called_once_with(
  72. vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
  73. )
  74. @patch("embedchain.vectordb.pinecone.pinecone")
  75. def test_reset(self, pinecone_mock):
  76. """Test that the database can be reset."""
  77. # Create a PineconeDb instance
  78. db = PineconeDB()
  79. app_config = AppConfig(collect_metrics=False)
  80. App(config=app_config, db=db, embedder=BaseEmbedder())
  81. # Reset the database
  82. db.reset()
  83. # Assert that the Pinecone client was called to delete the index
  84. pinecone_mock.delete_index.assert_called_once_with(db.index_name)
  85. # Assert that the index is recreated
  86. pinecone_mock.Index.assert_called_with(db.index_name)