test_weaviate.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import unittest
  2. from unittest.mock import patch
  3. from embedchain import App
  4. from embedchain.config import AppConfig
  5. from embedchain.config.vector_db.pinecone import PineconeDBConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.vectordb.weaviate import WeaviateDB
  8. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  9. """A mock embedding function."""
  10. return [[1, 2, 3], [4, 5, 6]]
  11. class TestWeaviateDb(unittest.TestCase):
  12. def test_incorrect_config_throws_error(self):
  13. """Test the init method of the WeaviateDb class throws error for incorrect config"""
  14. with self.assertRaises(TypeError):
  15. WeaviateDB(config=PineconeDBConfig())
  16. @patch("embedchain.vectordb.weaviate.weaviate")
  17. def test_initialize(self, weaviate_mock):
  18. """Test the init method of the WeaviateDb class."""
  19. weaviate_client_mock = weaviate_mock.Client.return_value
  20. weaviate_client_schema_mock = weaviate_client_mock.schema
  21. # Mock that schema doesn't already exist so that a new schema is created
  22. weaviate_client_schema_mock.exists.return_value = False
  23. # Set the embedder
  24. embedder = BaseEmbedder()
  25. embedder.set_vector_dimension(1536)
  26. embedder.set_embedding_fn(mock_embedding_fn)
  27. # Create a Weaviate instance
  28. db = WeaviateDB()
  29. app_config = AppConfig(collect_metrics=False)
  30. App(config=app_config, db=db, embedding_model=embedder)
  31. expected_class_obj = {
  32. "classes": [
  33. {
  34. "class": "Embedchain_store_1536",
  35. "vectorizer": "none",
  36. "properties": [
  37. {
  38. "name": "identifier",
  39. "dataType": ["text"],
  40. },
  41. {
  42. "name": "text",
  43. "dataType": ["text"],
  44. },
  45. {
  46. "name": "metadata",
  47. "dataType": ["Embedchain_store_1536_metadata"],
  48. },
  49. ],
  50. },
  51. {
  52. "class": "Embedchain_store_1536_metadata",
  53. "vectorizer": "none",
  54. "properties": [
  55. {
  56. "name": "data_type",
  57. "dataType": ["text"],
  58. },
  59. {
  60. "name": "doc_id",
  61. "dataType": ["text"],
  62. },
  63. {
  64. "name": "url",
  65. "dataType": ["text"],
  66. },
  67. {
  68. "name": "hash",
  69. "dataType": ["text"],
  70. },
  71. {
  72. "name": "app_id",
  73. "dataType": ["text"],
  74. },
  75. ],
  76. },
  77. ]
  78. }
  79. # Assert that the Weaviate client was initialized
  80. weaviate_mock.Client.assert_called_once()
  81. self.assertEqual(db.index_name, "Embedchain_store_1536")
  82. weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
  83. @patch("embedchain.vectordb.weaviate.weaviate")
  84. def test_get_or_create_db(self, weaviate_mock):
  85. """Test the _get_or_create_db method of the WeaviateDb class."""
  86. weaviate_client_mock = weaviate_mock.Client.return_value
  87. embedder = BaseEmbedder()
  88. embedder.set_vector_dimension(1536)
  89. embedder.set_embedding_fn(mock_embedding_fn)
  90. # Create a Weaviate instance
  91. db = WeaviateDB()
  92. app_config = AppConfig(collect_metrics=False)
  93. App(config=app_config, db=db, embedding_model=embedder)
  94. expected_client = db._get_or_create_db()
  95. self.assertEqual(expected_client, weaviate_client_mock)
  96. @patch("embedchain.vectordb.weaviate.weaviate")
  97. def test_add(self, weaviate_mock):
  98. """Test the add method of the WeaviateDb class."""
  99. weaviate_client_mock = weaviate_mock.Client.return_value
  100. weaviate_client_batch_mock = weaviate_client_mock.batch
  101. weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
  102. # Set the embedder
  103. embedder = BaseEmbedder()
  104. embedder.set_vector_dimension(1536)
  105. embedder.set_embedding_fn(mock_embedding_fn)
  106. # Create a Weaviate instance
  107. db = WeaviateDB()
  108. app_config = AppConfig(collect_metrics=False)
  109. App(config=app_config, db=db, embedding_model=embedder)
  110. documents = ["This is test document"]
  111. metadatas = [None]
  112. ids = ["id_1"]
  113. db.add(documents, metadatas, ids)
  114. # Check if the document was added to the database.
  115. weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
  116. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  117. data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
  118. )
  119. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  120. data_object={"text": documents[0]},
  121. class_name="Embedchain_store_1536_metadata",
  122. vector=[1, 2, 3],
  123. )
  124. @patch("embedchain.vectordb.weaviate.weaviate")
  125. def test_query_without_where(self, weaviate_mock):
  126. """Test the query method of the WeaviateDb class."""
  127. weaviate_client_mock = weaviate_mock.Client.return_value
  128. weaviate_client_query_mock = weaviate_client_mock.query
  129. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  130. # Set the embedder
  131. embedder = BaseEmbedder()
  132. embedder.set_vector_dimension(1536)
  133. embedder.set_embedding_fn(mock_embedding_fn)
  134. # Create a Weaviate instance
  135. db = WeaviateDB()
  136. app_config = AppConfig(collect_metrics=False)
  137. App(config=app_config, db=db, embedding_model=embedder)
  138. # Query for the document.
  139. db.query(input_query="This is a test document.", n_results=1, where={})
  140. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
  141. weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  142. @patch("embedchain.vectordb.weaviate.weaviate")
  143. def test_query_with_where(self, weaviate_mock):
  144. """Test the query method of the WeaviateDb class."""
  145. weaviate_client_mock = weaviate_mock.Client.return_value
  146. weaviate_client_query_mock = weaviate_client_mock.query
  147. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  148. weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
  149. # Set the embedder
  150. embedder = BaseEmbedder()
  151. embedder.set_vector_dimension(1536)
  152. embedder.set_embedding_fn(mock_embedding_fn)
  153. # Create a Weaviate instance
  154. db = WeaviateDB()
  155. app_config = AppConfig(collect_metrics=False)
  156. App(config=app_config, db=db, embedding_model=embedder)
  157. # Query for the document.
  158. db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
  159. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
  160. weaviate_client_query_get_mock.with_where.assert_called_once_with(
  161. {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
  162. )
  163. weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  164. @patch("embedchain.vectordb.weaviate.weaviate")
  165. def test_reset(self, weaviate_mock):
  166. """Test the reset method of the WeaviateDb class."""
  167. weaviate_client_mock = weaviate_mock.Client.return_value
  168. weaviate_client_batch_mock = weaviate_client_mock.batch
  169. # Set the embedder
  170. embedder = BaseEmbedder()
  171. embedder.set_vector_dimension(1536)
  172. embedder.set_embedding_fn(mock_embedding_fn)
  173. # Create a Weaviate instance
  174. db = WeaviateDB()
  175. app_config = AppConfig(collect_metrics=False)
  176. App(config=app_config, db=db, embedding_model=embedder)
  177. # Reset the database.
  178. db.reset()
  179. weaviate_client_batch_mock.delete_objects.assert_called_once_with(
  180. "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  181. )
  182. @patch("embedchain.vectordb.weaviate.weaviate")
  183. def test_count(self, weaviate_mock):
  184. """Test the reset method of the WeaviateDb class."""
  185. weaviate_client_mock = weaviate_mock.Client.return_value
  186. weaviate_client_query = weaviate_client_mock.query
  187. # Set the embedder
  188. embedder = BaseEmbedder()
  189. embedder.set_vector_dimension(1536)
  190. embedder.set_embedding_fn(mock_embedding_fn)
  191. # Create a Weaviate instance
  192. db = WeaviateDB()
  193. app_config = AppConfig(collect_metrics=False)
  194. App(config=app_config, db=db, embedding_model=embedder)
  195. # Reset the database.
  196. db.count()
  197. weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")