test_weaviate.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import unittest
  2. from unittest.mock import patch
  3. from embedchain import App
  4. from embedchain.config import AppConfig
  5. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.vectordb.weaviate import WeaviateDB
  8. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  9. """A mock embedding function."""
  10. return [[1, 2, 3], [4, 5, 6]]
  11. class TestWeaviateDb(unittest.TestCase):
  12. def test_incorrect_config_throws_error(self):
  13. """Test the init method of the WeaviateDb class throws error for incorrect config"""
  14. with self.assertRaises(TypeError):
  15. WeaviateDB(config=PineconeDBConfig())
  16. @patch("embedchain.vectordb.weaviate.weaviate")
  17. def test_initialize(self, weaviate_mock):
  18. """Test the init method of the WeaviateDb class."""
  19. weaviate_client_mock = weaviate_mock.Client.return_value
  20. weaviate_client_schema_mock = weaviate_client_mock.schema
  21. # Mock that schema doesn't already exist so that a new schema is created
  22. weaviate_client_schema_mock.exists.return_value = False
  23. # Set the embedder
  24. embedder = BaseEmbedder()
  25. embedder.set_vector_dimension(1526)
  26. embedder.set_embedding_fn(mock_embedding_fn)
  27. # Create a Weaviate instance
  28. db = WeaviateDB()
  29. app_config = AppConfig(collect_metrics=False)
  30. App(config=app_config, db=db, embedding_model=embedder)
  31. expected_class_obj = {
  32. "classes": [
  33. {
  34. "class": "Embedchain_store_1526",
  35. "vectorizer": "none",
  36. "properties": [
  37. {
  38. "name": "identifier",
  39. "dataType": ["text"],
  40. },
  41. {
  42. "name": "text",
  43. "dataType": ["text"],
  44. },
  45. {
  46. "name": "metadata",
  47. "dataType": ["Embedchain_store_1526_metadata"],
  48. },
  49. ],
  50. },
  51. {
  52. "class": "Embedchain_store_1526_metadata",
  53. "vectorizer": "none",
  54. "properties": [
  55. {
  56. "name": "data_type",
  57. "dataType": ["text"],
  58. },
  59. {
  60. "name": "doc_id",
  61. "dataType": ["text"],
  62. },
  63. {
  64. "name": "url",
  65. "dataType": ["text"],
  66. },
  67. {
  68. "name": "hash",
  69. "dataType": ["text"],
  70. },
  71. {
  72. "name": "app_id",
  73. "dataType": ["text"],
  74. },
  75. ],
  76. },
  77. ]
  78. }
  79. # Assert that the Weaviate client was initialized
  80. weaviate_mock.Client.assert_called_once()
  81. self.assertEqual(db.index_name, "Embedchain_store_1526")
  82. weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
  83. @patch("embedchain.vectordb.weaviate.weaviate")
  84. def test_get_or_create_db(self, weaviate_mock):
  85. """Test the _get_or_create_db method of the WeaviateDb class."""
  86. weaviate_client_mock = weaviate_mock.Client.return_value
  87. embedder = BaseEmbedder()
  88. embedder.set_vector_dimension(1526)
  89. embedder.set_embedding_fn(mock_embedding_fn)
  90. # Create a Weaviate instance
  91. db = WeaviateDB()
  92. app_config = AppConfig(collect_metrics=False)
  93. App(config=app_config, db=db, embedding_model=embedder)
  94. expected_client = db._get_or_create_db()
  95. self.assertEqual(expected_client, weaviate_client_mock)
  96. @patch("embedchain.vectordb.weaviate.weaviate")
  97. def test_add(self, weaviate_mock):
  98. """Test the add method of the WeaviateDb class."""
  99. weaviate_client_mock = weaviate_mock.Client.return_value
  100. weaviate_client_batch_mock = weaviate_client_mock.batch
  101. weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
  102. # Set the embedder
  103. embedder = BaseEmbedder()
  104. embedder.set_vector_dimension(1526)
  105. embedder.set_embedding_fn(mock_embedding_fn)
  106. # Create a Weaviate instance
  107. db = WeaviateDB()
  108. app_config = AppConfig(collect_metrics=False)
  109. App(config=app_config, db=db, embedding_model=embedder)
  110. db.BATCH_SIZE = 1
  111. embeddings = [[1, 2, 3], [4, 5, 6]]
  112. documents = ["This is a test document.", "This is another test document."]
  113. metadatas = [None, None]
  114. ids = ["123", "456"]
  115. db.add(embeddings, documents, metadatas, ids)
  116. # Check if the document was added to the database.
  117. weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
  118. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  119. data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0]
  120. )
  121. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  122. data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1]
  123. )
  124. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  125. data_object={"identifier": ids[0], "text": documents[0]},
  126. class_name="Embedchain_store_1526",
  127. vector=embeddings[0],
  128. )
  129. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  130. data_object={"identifier": ids[1], "text": documents[1]},
  131. class_name="Embedchain_store_1526",
  132. vector=embeddings[1],
  133. )
  134. @patch("embedchain.vectordb.weaviate.weaviate")
  135. def test_query_without_where(self, weaviate_mock):
  136. """Test the query method of the WeaviateDb class."""
  137. weaviate_client_mock = weaviate_mock.Client.return_value
  138. weaviate_client_query_mock = weaviate_client_mock.query
  139. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  140. # Set the embedder
  141. embedder = BaseEmbedder()
  142. embedder.set_vector_dimension(1526)
  143. embedder.set_embedding_fn(mock_embedding_fn)
  144. # Create a Weaviate instance
  145. db = WeaviateDB()
  146. app_config = AppConfig(collect_metrics=False)
  147. App(config=app_config, db=db, embedding_model=embedder)
  148. # Query for the document.
  149. db.query(input_query=["This is a test document."], n_results=1, where={})
  150. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
  151. weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  152. @patch("embedchain.vectordb.weaviate.weaviate")
  153. def test_query_with_where(self, weaviate_mock):
  154. """Test the query method of the WeaviateDb class."""
  155. weaviate_client_mock = weaviate_mock.Client.return_value
  156. weaviate_client_query_mock = weaviate_client_mock.query
  157. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  158. weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
  159. # Set the embedder
  160. embedder = BaseEmbedder()
  161. embedder.set_vector_dimension(1526)
  162. embedder.set_embedding_fn(mock_embedding_fn)
  163. # Create a Weaviate instance
  164. db = WeaviateDB()
  165. app_config = AppConfig(collect_metrics=False)
  166. App(config=app_config, db=db, embedding_model=embedder)
  167. # Query for the document.
  168. db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
  169. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
  170. weaviate_client_query_get_mock.with_where.assert_called_once_with(
  171. {"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
  172. )
  173. weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  174. @patch("embedchain.vectordb.weaviate.weaviate")
  175. def test_reset(self, weaviate_mock):
  176. """Test the reset method of the WeaviateDb class."""
  177. weaviate_client_mock = weaviate_mock.Client.return_value
  178. weaviate_client_batch_mock = weaviate_client_mock.batch
  179. # Set the embedder
  180. embedder = BaseEmbedder()
  181. embedder.set_vector_dimension(1526)
  182. embedder.set_embedding_fn(mock_embedding_fn)
  183. # Create a Weaviate instance
  184. db = WeaviateDB()
  185. app_config = AppConfig(collect_metrics=False)
  186. App(config=app_config, db=db, embedding_model=embedder)
  187. # Reset the database.
  188. db.reset()
  189. weaviate_client_batch_mock.delete_objects.assert_called_once_with(
  190. "Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  191. )
  192. @patch("embedchain.vectordb.weaviate.weaviate")
  193. def test_count(self, weaviate_mock):
  194. """Test the reset method of the WeaviateDb class."""
  195. weaviate_client_mock = weaviate_mock.Client.return_value
  196. weaviate_client_query = weaviate_client_mock.query
  197. # Set the embedder
  198. embedder = BaseEmbedder()
  199. embedder.set_vector_dimension(1526)
  200. embedder.set_embedding_fn(mock_embedding_fn)
  201. # Create a Weaviate instance
  202. db = WeaviateDB()
  203. app_config = AppConfig(collect_metrics=False)
  204. App(config=app_config, db=db, embedding_model=embedder)
  205. # Reset the database.
  206. db.count()
  207. weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")