test_zilliz_db.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # ruff: noqa: E501
  2. import os
  3. from unittest import mock
  4. from unittest.mock import Mock, patch
  5. import pytest
  6. from embedchain.config import ZillizDBConfig
  7. from embedchain.vectordb.zilliz import ZillizVectorDB
  8. # to run tests, provide the URI and TOKEN in .env file
  9. class TestZillizVectorDBConfig:
  10. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  11. def test_init_with_uri_and_token(self):
  12. """
  13. Test if the `ZillizVectorDBConfig` instance is initialized with the correct uri and token values.
  14. """
  15. # Create a ZillizDBConfig instance with mocked values
  16. expected_uri = "mocked_uri"
  17. expected_token = "mocked_token"
  18. db_config = ZillizDBConfig()
  19. # Assert that the values in the ZillizVectorDB instance match the mocked values
  20. assert db_config.uri == expected_uri
  21. assert db_config.token == expected_token
  22. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  23. def test_init_without_uri(self):
  24. """
  25. Test if the `ZillizVectorDBConfig` instance throws an error when no URI found.
  26. """
  27. try:
  28. del os.environ["ZILLIZ_CLOUD_URI"]
  29. except KeyError:
  30. pass
  31. with pytest.raises(AttributeError):
  32. ZillizDBConfig()
  33. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  34. def test_init_without_token(self):
  35. """
  36. Test if the `ZillizVectorDBConfig` instance throws an error when no Token found.
  37. """
  38. try:
  39. del os.environ["ZILLIZ_CLOUD_TOKEN"]
  40. except KeyError:
  41. pass
  42. # Test if an exception is raised when ZILLIZ_CLOUD_TOKEN is missing
  43. with pytest.raises(AttributeError):
  44. ZillizDBConfig()
  45. class TestZillizVectorDB:
  46. @pytest.fixture
  47. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  48. def mock_config(self, mocker):
  49. return mocker.Mock(spec=ZillizDBConfig())
  50. @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
  51. @patch("embedchain.vectordb.zilliz.connections.connect", autospec=True)
  52. def test_zilliz_vector_db_setup(self, mock_connect, mock_client, mock_config):
  53. """
  54. Test if the `ZillizVectorDB` instance is initialized with the correct uri and token values.
  55. """
  56. # Create an instance of ZillizVectorDB with the mock config
  57. # zilliz_db = ZillizVectorDB(config=mock_config)
  58. ZillizVectorDB(config=mock_config)
  59. # Assert that the MilvusClient and connections.connect were called
  60. mock_client.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
  61. mock_connect.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
  62. class TestZillizDBCollection:
  63. @pytest.fixture
  64. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  65. def mock_config(self, mocker):
  66. return mocker.Mock(spec=ZillizDBConfig())
  67. @pytest.fixture
  68. def mock_embedder(self, mocker):
  69. return mocker.Mock()
  70. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  71. def test_init_with_default_collection(self):
  72. """
  73. Test if the `ZillizVectorDB` instance is initialized with the correct default collection name.
  74. """
  75. # Create a ZillizDBConfig instance
  76. db_config = ZillizDBConfig()
  77. assert db_config.collection_name == "embedchain_store"
  78. @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
  79. def test_init_with_custom_collection(self):
  80. """
  81. Test if the `ZillizVectorDB` instance is initialized with the correct custom collection name.
  82. """
  83. # Create a ZillizDBConfig instance with mocked values
  84. expected_collection = "test_collection"
  85. db_config = ZillizDBConfig(collection_name="test_collection")
  86. assert db_config.collection_name == expected_collection
  87. @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
  88. @patch("embedchain.vectordb.zilliz.connections", autospec=True)
  89. def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
  90. """
  91. Test if the `ZillizVectorDB` instance is takes in the query with skip_embeddings.
  92. """
  93. # Create an instance of ZillizVectorDB with mock config
  94. zilliz_db = ZillizVectorDB(config=mock_config)
  95. # Add a 'collection' attribute to the ZillizVectorDB instance for testing
  96. zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
  97. assert zilliz_db.client == mock_client()
  98. # Mock the MilvusClient search method
  99. with patch.object(zilliz_db.client, "search") as mock_search:
  100. # Mock the search result
  101. mock_search.return_value = [
  102. [
  103. {
  104. "distance": 0.5,
  105. "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
  106. }
  107. ]
  108. ]
  109. # Call the query method with skip_embedding=True
  110. query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
  111. # Assert that MilvusClient.search was called with the correct parameters
  112. mock_search.assert_called_with(
  113. collection_name=mock_config.collection_name,
  114. data=["query_text"],
  115. limit=1,
  116. output_fields=["*"],
  117. )
  118. # Assert that the query result matches the expected result
  119. assert query_result == ["result_doc"]
  120. query_result_with_citations = zilliz_db.query(
  121. input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
  122. )
  123. mock_search.assert_called_with(
  124. collection_name=mock_config.collection_name,
  125. data=["query_text"],
  126. limit=1,
  127. output_fields=["*"],
  128. )
  129. assert query_result_with_citations == [
  130. ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
  131. ]
  132. @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
  133. @patch("embedchain.vectordb.zilliz.connections", autospec=True)
  134. def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
  135. """
  136. Test if the `ZillizVectorDB` instance is takes in the query without skip_embeddings.
  137. """
  138. # Create an instance of ZillizVectorDB with mock config
  139. zilliz_db = ZillizVectorDB(config=mock_config)
  140. # Add a 'embedder' attribute to the ZillizVectorDB instance for testing
  141. zilliz_db.embedder = mock_embedder # Mock the 'collection' object
  142. # Add a 'collection' attribute to the ZillizVectorDB instance for testing
  143. zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
  144. assert zilliz_db.client == mock_client()
  145. # Mock the MilvusClient search method
  146. with patch.object(zilliz_db.client, "search") as mock_search:
  147. # Mock the embedding function
  148. mock_embedder.embedding_fn.return_value = ["query_vector"]
  149. # Mock the search result
  150. mock_search.return_value = [
  151. [
  152. {
  153. "distance": 0.0,
  154. "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
  155. }
  156. ]
  157. ]
  158. # Call the query method with skip_embedding=False
  159. query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
  160. # Assert that MilvusClient.search was called with the correct parameters
  161. mock_search.assert_called_with(
  162. collection_name=mock_config.collection_name,
  163. data=["query_vector"],
  164. limit=1,
  165. output_fields=["*"],
  166. )
  167. # Assert that the query result matches the expected result
  168. assert query_result == ["result_doc"]
  169. query_result_with_citations = zilliz_db.query(
  170. input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
  171. )
  172. mock_search.assert_called_with(
  173. collection_name=mock_config.collection_name,
  174. data=["query_vector"],
  175. limit=1,
  176. output_fields=["*"],
  177. )
  178. assert query_result_with_citations == [
  179. ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0})
  180. ]