# test_weaviate.py -- unit tests for embedchain's Weaviate vector database wrapper.
  1. import unittest
  2. from unittest.mock import patch
  3. from embedchain import App
  4. from embedchain.config import AppConfig
  5. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.vectordb.weaviate import WeaviateDB
  8. def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
  9. """A mock embedding function."""
  10. return [[1, 2, 3], [4, 5, 6]]
  11. class TestWeaviateDb(unittest.TestCase):
  12. def test_incorrect_config_throws_error(self):
  13. """Test the init method of the WeaviateDb class throws error for incorrect config"""
  14. with self.assertRaises(TypeError):
  15. WeaviateDB(config=PineconeDBConfig())
  16. @patch("embedchain.vectordb.weaviate.weaviate")
  17. def test_initialize(self, weaviate_mock):
  18. """Test the init method of the WeaviateDb class."""
  19. weaviate_client_mock = weaviate_mock.Client.return_value
  20. weaviate_client_schema_mock = weaviate_client_mock.schema
  21. # Mock that schema doesn't already exist so that a new schema is created
  22. weaviate_client_schema_mock.exists.return_value = False
  23. # Set the embedder
  24. embedder = BaseEmbedder()
  25. embedder.set_vector_dimension(1536)
  26. embedder.set_embedding_fn(mock_embedding_fn)
  27. # Create a Weaviate instance
  28. db = WeaviateDB()
  29. app_config = AppConfig(collect_metrics=False)
  30. App(config=app_config, db=db, embedding_model=embedder)
  31. expected_class_obj = {
  32. "classes": [
  33. {
  34. "class": "Embedchain_store_1536",
  35. "vectorizer": "none",
  36. "properties": [
  37. {
  38. "name": "identifier",
  39. "dataType": ["text"],
  40. },
  41. {
  42. "name": "text",
  43. "dataType": ["text"],
  44. },
  45. {
  46. "name": "metadata",
  47. "dataType": ["Embedchain_store_1536_metadata"],
  48. },
  49. ],
  50. },
  51. {
  52. "class": "Embedchain_store_1536_metadata",
  53. "vectorizer": "none",
  54. "properties": [
  55. {
  56. "name": "data_type",
  57. "dataType": ["text"],
  58. },
  59. {
  60. "name": "doc_id",
  61. "dataType": ["text"],
  62. },
  63. {
  64. "name": "url",
  65. "dataType": ["text"],
  66. },
  67. {
  68. "name": "hash",
  69. "dataType": ["text"],
  70. },
  71. {
  72. "name": "app_id",
  73. "dataType": ["text"],
  74. },
  75. ],
  76. },
  77. ]
  78. }
  79. # Assert that the Weaviate client was initialized
  80. weaviate_mock.Client.assert_called_once()
  81. self.assertEqual(db.index_name, "Embedchain_store_1536")
  82. weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
  83. @patch("embedchain.vectordb.weaviate.weaviate")
  84. def test_get_or_create_db(self, weaviate_mock):
  85. """Test the _get_or_create_db method of the WeaviateDb class."""
  86. weaviate_client_mock = weaviate_mock.Client.return_value
  87. embedder = BaseEmbedder()
  88. embedder.set_vector_dimension(1536)
  89. embedder.set_embedding_fn(mock_embedding_fn)
  90. # Create a Weaviate instance
  91. db = WeaviateDB()
  92. app_config = AppConfig(collect_metrics=False)
  93. App(config=app_config, db=db, embedding_model=embedder)
  94. expected_client = db._get_or_create_db()
  95. self.assertEqual(expected_client, weaviate_client_mock)
  96. @patch("embedchain.vectordb.weaviate.weaviate")
  97. def test_add(self, weaviate_mock):
  98. """Test the add method of the WeaviateDb class."""
  99. weaviate_client_mock = weaviate_mock.Client.return_value
  100. weaviate_client_batch_mock = weaviate_client_mock.batch
  101. weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
  102. # Set the embedder
  103. embedder = BaseEmbedder()
  104. embedder.set_vector_dimension(1536)
  105. embedder.set_embedding_fn(mock_embedding_fn)
  106. # Create a Weaviate instance
  107. db = WeaviateDB()
  108. app_config = AppConfig(collect_metrics=False)
  109. App(config=app_config, db=db, embedding_model=embedder)
  110. documents = ["This is test document"]
  111. metadatas = [None]
  112. ids = ["id_1"]
  113. db.add(documents, metadatas, ids)
  114. # Check if the document was added to the database.
  115. weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
  116. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  117. data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
  118. )
  119. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  120. data_object={"text": documents[0]},
  121. class_name="Embedchain_store_1536_metadata",
  122. vector=[1, 2, 3],
  123. )
  124. @patch("embedchain.vectordb.weaviate.weaviate")
  125. def test_query_without_where(self, weaviate_mock):
  126. """Test the query method of the WeaviateDb class."""
  127. weaviate_client_mock = weaviate_mock.Client.return_value
  128. weaviate_client_query_mock = weaviate_client_mock.query
  129. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  130. # Set the embedder
  131. embedder = BaseEmbedder()
  132. embedder.set_vector_dimension(1536)
  133. embedder.set_embedding_fn(mock_embedding_fn)
  134. # Create a Weaviate instance
  135. db = WeaviateDB()
  136. app_config = AppConfig(collect_metrics=False)
  137. App(config=app_config, db=db, embedding_model=embedder)
  138. # Query for the document.
  139. db.query(input_query="This is a test document.", n_results=1, where={})
  140. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
  141. weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  142. @patch("embedchain.vectordb.weaviate.weaviate")
  143. def test_query_with_where(self, weaviate_mock):
  144. """Test the query method of the WeaviateDb class."""
  145. weaviate_client_mock = weaviate_mock.Client.return_value
  146. weaviate_client_query_mock = weaviate_client_mock.query
  147. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  148. weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
  149. # Set the embedder
  150. embedder = BaseEmbedder()
  151. embedder.set_vector_dimension(1536)
  152. embedder.set_embedding_fn(mock_embedding_fn)
  153. # Create a Weaviate instance
  154. db = WeaviateDB()
  155. app_config = AppConfig(collect_metrics=False)
  156. App(config=app_config, db=db, embedding_model=embedder)
  157. # Query for the document.
  158. db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
  159. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
  160. weaviate_client_query_get_mock.with_where.assert_called_once_with(
  161. {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
  162. )
  163. weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
  164. @patch("embedchain.vectordb.weaviate.weaviate")
  165. def test_reset(self, weaviate_mock):
  166. """Test the reset method of the WeaviateDb class."""
  167. weaviate_client_mock = weaviate_mock.Client.return_value
  168. weaviate_client_batch_mock = weaviate_client_mock.batch
  169. # Set the embedder
  170. embedder = BaseEmbedder()
  171. embedder.set_vector_dimension(1536)
  172. embedder.set_embedding_fn(mock_embedding_fn)
  173. # Create a Weaviate instance
  174. db = WeaviateDB()
  175. app_config = AppConfig(collect_metrics=False)
  176. App(config=app_config, db=db, embedding_model=embedder)
  177. # Reset the database.
  178. db.reset()
  179. weaviate_client_batch_mock.delete_objects.assert_called_once_with(
  180. "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  181. )
  182. @patch("embedchain.vectordb.weaviate.weaviate")
  183. def test_count(self, weaviate_mock):
  184. """Test the reset method of the WeaviateDb class."""
  185. weaviate_client_mock = weaviate_mock.Client.return_value
  186. weaviate_client_query = weaviate_client_mock.query
  187. # Set the embedder
  188. embedder = BaseEmbedder()
  189. embedder.set_vector_dimension(1536)
  190. embedder.set_embedding_fn(mock_embedding_fn)
  191. # Create a Weaviate instance
  192. db = WeaviateDB()
  193. app_config = AppConfig(collect_metrics=False)
  194. App(config=app_config, db=db, embedding_model=embedder)
  195. # Reset the database.
  196. db.count()
  197. weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")