test_zilliz_db.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # ruff: noqa: E501
  2. import os
  3. from unittest import mock
  4. from unittest.mock import Mock, patch
  5. import pytest
  6. from embedchain.config import ZillizDBConfig
  7. from embedchain.vectordb.zilliz import ZillizVectorDB
  8. # to run tests, provide the URI and TOKEN in .env file
  class TestZillizVectorDBConfig:
      """Checks that `ZillizDBConfig` takes its uri and token from environment variables."""

      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def test_init_with_uri_and_token(self):
          """
          Test if the `ZillizDBConfig` instance is initialized with the correct uri and token values.
          """
          # These must match the values injected into os.environ by the decorator.
          expected_uri = "mocked_uri"
          expected_token = "mocked_token"

          db_config = ZillizDBConfig()

          # The config object must expose exactly the mocked environment values.
          assert db_config.uri == expected_uri
          assert db_config.token == expected_token

      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def test_init_without_uri(self):
          """
          Test if `ZillizDBConfig` raises an `AttributeError` when ZILLIZ_CLOUD_URI is not set.
          """
          # pop() with a default never raises, and patch.dict restores os.environ
          # when the test exits, so the removal cannot leak into other tests.
          os.environ.pop("ZILLIZ_CLOUD_URI", None)

          with pytest.raises(AttributeError):
              ZillizDBConfig()

      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def test_init_without_token(self):
          """
          Test if `ZillizDBConfig` raises an `AttributeError` when ZILLIZ_CLOUD_TOKEN is not set.
          """
          # Same approach as the URI test: drop the variable inside the patched environment.
          os.environ.pop("ZILLIZ_CLOUD_TOKEN", None)

          with pytest.raises(AttributeError):
              ZillizDBConfig()
  class TestZillizVectorDB:
      """Checks how `ZillizVectorDB` connects when it is constructed."""

      @pytest.fixture
      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def mock_config(self, mocker):
          """Return a mock whose spec is a `ZillizDBConfig` built under the mocked environment."""
          spec_instance = ZillizDBConfig()
          return mocker.Mock(spec=spec_instance)

      @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
      @patch("embedchain.vectordb.zilliz.connections.connect", autospec=True)
      def test_zilliz_vector_db_setup(self, mock_connect, mock_client, mock_config):
          """
          Test if constructing `ZillizVectorDB` passes the config's uri and token to both connection entry points.
          """
          # Both MilvusClient and connections.connect are expected to receive the same credentials.
          expected_kwargs = {"uri": mock_config.uri, "token": mock_config.token}

          # Only the construction side effects matter here, so the instance is discarded.
          ZillizVectorDB(config=mock_config)

          mock_client.assert_called_once_with(**expected_kwargs)
          mock_connect.assert_called_once_with(**expected_kwargs)
  class TestZillizDBCollection:
      """Checks collection naming on `ZillizDBConfig` and the `ZillizVectorDB.query` search path."""

      @pytest.fixture
      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def mock_config(self, mocker):
          """Return a mock whose spec is a `ZillizDBConfig` built under the mocked environment."""
          spec_instance = ZillizDBConfig()
          return mocker.Mock(spec=spec_instance)

      @pytest.fixture
      def mock_embedder(self, mocker):
          """Return an unconstrained mock standing in for the embedder."""
          return mocker.Mock()

      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def test_init_with_default_collection(self):
          """
          Test if `ZillizDBConfig` uses the default collection name when none is given.
          """
          config = ZillizDBConfig()

          assert config.collection_name == "embedchain_store"

      @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
      def test_init_with_custom_collection(self):
          """
          Test if `ZillizDBConfig` keeps a custom collection name passed at construction.
          """
          config = ZillizDBConfig(collection_name="test_collection")

          assert config.collection_name == "test_collection"

      @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
      @patch("embedchain.vectordb.zilliz.connections", autospec=True)
      def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
          """
          Test if `ZillizVectorDB.query` passes the query through to the client search when skip_embedding is True.
          """
          vector_db = ZillizVectorDB(config=mock_config)
          # Give the instance a stand-in collection that reports itself as non-empty.
          vector_db.collection = Mock(is_empty=False)
          assert vector_db.client == mock_client()

          # A single hit, shaped like the nested list the test expects search to return.
          search_hit = {"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}
          # With skip_embedding=True the raw query is expected to reach search unchanged.
          expected_search_kwargs = {
              "collection_name": mock_config.collection_name,
              "data": ["query_text"],
              "limit": 1,
              "output_fields": ["text", "url", "doc_id"],
          }

          with patch.object(vector_db.client, "search", return_value=[[search_hit]]) as mock_search:
              # Without citations only the document text is returned.
              plain_result = vector_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
              mock_search.assert_called_with(**expected_search_kwargs)
              assert plain_result == ["result_doc"]

              # With citations each result is a (text, url, doc_id) tuple; the search call is the same.
              cited_result = vector_db.query(
                  input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
              )
              mock_search.assert_called_with(**expected_search_kwargs)
              assert cited_result == [("result_doc", "url_1", "doc_id_1")]

      @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
      @patch("embedchain.vectordb.zilliz.connections", autospec=True)
      def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
          """
          Test if `ZillizVectorDB.query` searches with the embedder's output when skip_embedding is False.
          """
          vector_db = ZillizVectorDB(config=mock_config)
          # Swap in the mock embedder and fix the vector it produces.
          vector_db.embedder = mock_embedder
          mock_embedder.embedding_fn.return_value = ["query_vector"]
          # Give the instance a stand-in collection that reports itself as non-empty.
          vector_db.collection = Mock(is_empty=False)
          assert vector_db.client == mock_client()

          # A single hit, shaped like the nested list the test expects search to return.
          search_hit = {"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}
          # With skip_embedding=False the embedded vector, not the raw text, must reach search.
          expected_search_kwargs = {
              "collection_name": mock_config.collection_name,
              "data": ["query_vector"],
              "limit": 1,
              "output_fields": ["text", "url", "doc_id"],
          }

          with patch.object(vector_db.client, "search", return_value=[[search_hit]]) as mock_search:
              # Without citations only the document text is returned.
              plain_result = vector_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
              mock_search.assert_called_with(**expected_search_kwargs)
              assert plain_result == ["result_doc"]

              # With citations each result is a (text, url, doc_id) tuple; the search call is the same.
              cited_result = vector_db.query(
                  input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
              )
              mock_search.assert_called_with(**expected_search_kwargs)
              assert cited_result == [("result_doc", "url_1", "doc_id_1")]