test_weaviate.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import unittest
  2. from unittest.mock import patch
  3. from embedchain import App
  4. from embedchain.config import AppConfig
  5. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.vectordb.weaviate import WeaviateDB
  8. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  9. """A mock embedding function."""
  10. return [[1, 2, 3], [4, 5, 6]]
  11. class TestWeaviateDb(unittest.TestCase):
  12. def test_incorrect_config_throws_error(self):
  13. """Test the init method of the WeaviateDb class throws error for incorrect config"""
  14. with self.assertRaises(TypeError):
  15. WeaviateDB(config=PineconeDBConfig())
  16. @patch("embedchain.vectordb.weaviate.weaviate")
  17. def test_initialize(self, weaviate_mock):
  18. """Test the init method of the WeaviateDb class."""
  19. weaviate_client_mock = weaviate_mock.Client.return_value
  20. weaviate_client_schema_mock = weaviate_client_mock.schema
  21. # Mock that schema doesn't already exist so that a new schema is created
  22. weaviate_client_schema_mock.exists.return_value = False
  23. # Set the embedder
  24. embedder = BaseEmbedder()
  25. embedder.set_vector_dimension(1536)
  26. embedder.set_embedding_fn(mock_embedding_fn)
  27. # Create a Weaviate instance
  28. db = WeaviateDB()
  29. app_config = AppConfig(collect_metrics=False)
  30. App(config=app_config, db=db, embedding_model=embedder)
  31. expected_class_obj = {
  32. "classes": [
  33. {
  34. "class": "Embedchain_store_1536",
  35. "vectorizer": "none",
  36. "properties": [
  37. {
  38. "name": "identifier",
  39. "dataType": ["text"],
  40. },
  41. {
  42. "name": "text",
  43. "dataType": ["text"],
  44. },
  45. {
  46. "name": "metadata",
  47. "dataType": ["Embedchain_store_1536_metadata"],
  48. },
  49. ],
  50. },
  51. {
  52. "class": "Embedchain_store_1536_metadata",
  53. "vectorizer": "none",
  54. "properties": [
  55. {
  56. "name": "data_type",
  57. "dataType": ["text"],
  58. },
  59. {
  60. "name": "doc_id",
  61. "dataType": ["text"],
  62. },
  63. {
  64. "name": "url",
  65. "dataType": ["text"],
  66. },
  67. {
  68. "name": "hash",
  69. "dataType": ["text"],
  70. },
  71. {
  72. "name": "app_id",
  73. "dataType": ["text"],
  74. },
  75. ],
  76. },
  77. ]
  78. }
  79. # Assert that the Weaviate client was initialized
  80. weaviate_mock.Client.assert_called_once()
  81. self.assertEqual(db.index_name, "Embedchain_store_1536")
  82. weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
  83. @patch("embedchain.vectordb.weaviate.weaviate")
  84. def test_get_or_create_db(self, weaviate_mock):
  85. """Test the _get_or_create_db method of the WeaviateDb class."""
  86. weaviate_client_mock = weaviate_mock.Client.return_value
  87. embedder = BaseEmbedder()
  88. embedder.set_vector_dimension(1536)
  89. embedder.set_embedding_fn(mock_embedding_fn)
  90. # Create a Weaviate instance
  91. db = WeaviateDB()
  92. app_config = AppConfig(collect_metrics=False)
  93. App(config=app_config, db=db, embedding_model=embedder)
  94. expected_client = db._get_or_create_db()
  95. self.assertEqual(expected_client, weaviate_client_mock)
  96. @patch("embedchain.vectordb.weaviate.weaviate")
  97. def test_add(self, weaviate_mock):
  98. """Test the add method of the WeaviateDb class."""
  99. weaviate_client_mock = weaviate_mock.Client.return_value
  100. weaviate_client_batch_mock = weaviate_client_mock.batch
  101. weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
  102. # Set the embedder
  103. embedder = BaseEmbedder()
  104. embedder.set_vector_dimension(1536)
  105. embedder.set_embedding_fn(mock_embedding_fn)
  106. # Create a Weaviate instance
  107. db = WeaviateDB()
  108. app_config = AppConfig(collect_metrics=False)
  109. App(config=app_config, db=db, embedding_model=embedder)
  110. db.BATCH_SIZE = 1
  111. documents = ["This is test document"]
  112. metadatas = [None]
  113. ids = ["id_1"]
  114. db.add(documents, metadatas, ids)
  115. # Check if the document was added to the database.
  116. weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
  117. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  118. data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
  119. )
  120. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  121. data_object={"text": documents[0]},
  122. class_name="Embedchain_store_1536_metadata",
  123. vector=[1, 2, 3],
  124. )
  125. @patch("embedchain.vectordb.weaviate.weaviate")
  126. def test_query_without_where(self, weaviate_mock):
  127. """Test the query method of the WeaviateDb class."""
  128. weaviate_client_mock = weaviate_mock.Client.return_value
  129. weaviate_client_query_mock = weaviate_client_mock.query
  130. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  131. # Set the embedder
  132. embedder = BaseEmbedder()
  133. embedder.set_vector_dimension(1536)
  134. embedder.set_embedding_fn(mock_embedding_fn)
  135. # Create a Weaviate instance
  136. db = WeaviateDB()
  137. app_config = AppConfig(collect_metrics=False)
  138. App(config=app_config, db=db, embedding_model=embedder)
  139. # Query for the document.
  140. db.query(input_query="This is a test document.", n_results=1, where={})
  141. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
  142. weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  143. @patch("embedchain.vectordb.weaviate.weaviate")
  144. def test_query_with_where(self, weaviate_mock):
  145. """Test the query method of the WeaviateDb class."""
  146. weaviate_client_mock = weaviate_mock.Client.return_value
  147. weaviate_client_query_mock = weaviate_client_mock.query
  148. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  149. weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
  150. # Set the embedder
  151. embedder = BaseEmbedder()
  152. embedder.set_vector_dimension(1536)
  153. embedder.set_embedding_fn(mock_embedding_fn)
  154. # Create a Weaviate instance
  155. db = WeaviateDB()
  156. app_config = AppConfig(collect_metrics=False)
  157. App(config=app_config, db=db, embedding_model=embedder)
  158. # Query for the document.
  159. db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
  160. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
  161. weaviate_client_query_get_mock.with_where.assert_called_once_with(
  162. {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
  163. )
  164. weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  165. @patch("embedchain.vectordb.weaviate.weaviate")
  166. def test_reset(self, weaviate_mock):
  167. """Test the reset method of the WeaviateDb class."""
  168. weaviate_client_mock = weaviate_mock.Client.return_value
  169. weaviate_client_batch_mock = weaviate_client_mock.batch
  170. # Set the embedder
  171. embedder = BaseEmbedder()
  172. embedder.set_vector_dimension(1536)
  173. embedder.set_embedding_fn(mock_embedding_fn)
  174. # Create a Weaviate instance
  175. db = WeaviateDB()
  176. app_config = AppConfig(collect_metrics=False)
  177. App(config=app_config, db=db, embedding_model=embedder)
  178. # Reset the database.
  179. db.reset()
  180. weaviate_client_batch_mock.delete_objects.assert_called_once_with(
  181. "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  182. )
  183. @patch("embedchain.vectordb.weaviate.weaviate")
  184. def test_count(self, weaviate_mock):
  185. """Test the reset method of the WeaviateDb class."""
  186. weaviate_client_mock = weaviate_mock.Client.return_value
  187. weaviate_client_query = weaviate_client_mock.query
  188. # Set the embedder
  189. embedder = BaseEmbedder()
  190. embedder.set_vector_dimension(1536)
  191. embedder.set_embedding_fn(mock_embedding_fn)
  192. # Create a Weaviate instance
  193. db = WeaviateDB()
  194. app_config = AppConfig(collect_metrics=False)
  195. App(config=app_config, db=db, embedding_model=embedder)
  196. # Reset the database.
  197. db.count()
  198. weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")