# test_weaviate.py (9.5 KB)

  1. import unittest
  2. from unittest.mock import patch
  3. from embedchain import App
  4. from embedchain.config import AppConfig
  5. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.vectordb.weaviate import WeaviateDB
  8. class TestWeaviateDb(unittest.TestCase):
  9. def test_incorrect_config_throws_error(self):
  10. """Test the init method of the WeaviateDb class throws error for incorrect config"""
  11. with self.assertRaises(TypeError):
  12. WeaviateDB(config=PineconeDBConfig())
  13. @patch("embedchain.vectordb.weaviate.weaviate")
  14. def test_initialize(self, weaviate_mock):
  15. """Test the init method of the WeaviateDb class."""
  16. weaviate_client_mock = weaviate_mock.Client.return_value
  17. weaviate_client_schema_mock = weaviate_client_mock.schema
  18. # Mock that schema doesn't already exist so that a new schema is created
  19. weaviate_client_schema_mock.exists.return_value = False
  20. # Set the embedder
  21. embedder = BaseEmbedder()
  22. embedder.set_vector_dimension(1526)
  23. # Create a Weaviate instance
  24. db = WeaviateDB()
  25. app_config = AppConfig(collect_metrics=False)
  26. App(config=app_config, db=db, embedder=embedder)
  27. expected_class_obj = {
  28. "classes": [
  29. {
  30. "class": "Embedchain_store_1526",
  31. "vectorizer": "none",
  32. "properties": [
  33. {
  34. "name": "identifier",
  35. "dataType": ["text"],
  36. },
  37. {
  38. "name": "text",
  39. "dataType": ["text"],
  40. },
  41. {
  42. "name": "metadata",
  43. "dataType": ["Embedchain_store_1526_metadata"],
  44. },
  45. ],
  46. },
  47. {
  48. "class": "Embedchain_store_1526_metadata",
  49. "vectorizer": "none",
  50. "properties": [
  51. {
  52. "name": "data_type",
  53. "dataType": ["text"],
  54. },
  55. {
  56. "name": "doc_id",
  57. "dataType": ["text"],
  58. },
  59. {
  60. "name": "url",
  61. "dataType": ["text"],
  62. },
  63. {
  64. "name": "hash",
  65. "dataType": ["text"],
  66. },
  67. {
  68. "name": "app_id",
  69. "dataType": ["text"],
  70. },
  71. ],
  72. },
  73. ]
  74. }
  75. # Assert that the Weaviate client was initialized
  76. weaviate_mock.Client.assert_called_once()
  77. self.assertEqual(db.index_name, "Embedchain_store_1526")
  78. weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
  79. @patch("embedchain.vectordb.weaviate.weaviate")
  80. def test_get_or_create_db(self, weaviate_mock):
  81. """Test the _get_or_create_db method of the WeaviateDb class."""
  82. weaviate_client_mock = weaviate_mock.Client.return_value
  83. embedder = BaseEmbedder()
  84. embedder.set_vector_dimension(1526)
  85. # Create a Weaviate instance
  86. db = WeaviateDB()
  87. app_config = AppConfig(collect_metrics=False)
  88. App(config=app_config, db=db, embedder=embedder)
  89. expected_client = db._get_or_create_db()
  90. self.assertEqual(expected_client, weaviate_client_mock)
  91. @patch("embedchain.vectordb.weaviate.weaviate")
  92. def test_add(self, weaviate_mock):
  93. """Test the add method of the WeaviateDb class."""
  94. weaviate_client_mock = weaviate_mock.Client.return_value
  95. weaviate_client_batch_mock = weaviate_client_mock.batch
  96. weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
  97. # Set the embedder
  98. embedder = BaseEmbedder()
  99. embedder.set_vector_dimension(1526)
  100. # Create a Weaviate instance
  101. db = WeaviateDB()
  102. app_config = AppConfig(collect_metrics=False)
  103. App(config=app_config, db=db, embedder=embedder)
  104. db.BATCH_SIZE = 1
  105. embeddings = [[1, 2, 3], [4, 5, 6]]
  106. documents = ["This is a test document.", "This is another test document."]
  107. metadatas = [None, None]
  108. ids = ["123", "456"]
  109. skip_embedding = True
  110. db.add(embeddings, documents, metadatas, ids, skip_embedding)
  111. # Check if the document was added to the database.
  112. weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
  113. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  114. data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0]
  115. )
  116. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  117. data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1]
  118. )
  119. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  120. data_object={"identifier": ids[0], "text": documents[0]},
  121. class_name="Embedchain_store_1526",
  122. vector=embeddings[0],
  123. )
  124. weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
  125. data_object={"identifier": ids[1], "text": documents[1]},
  126. class_name="Embedchain_store_1526",
  127. vector=embeddings[1],
  128. )
  129. @patch("embedchain.vectordb.weaviate.weaviate")
  130. def test_query_without_where(self, weaviate_mock):
  131. """Test the query method of the WeaviateDb class."""
  132. weaviate_client_mock = weaviate_mock.Client.return_value
  133. weaviate_client_query_mock = weaviate_client_mock.query
  134. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  135. # Set the embedder
  136. embedder = BaseEmbedder()
  137. embedder.set_vector_dimension(1526)
  138. # Create a Weaviate instance
  139. db = WeaviateDB()
  140. app_config = AppConfig(collect_metrics=False)
  141. App(config=app_config, db=db, embedder=embedder)
  142. # Query for the document.
  143. db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
  144. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
  145. weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
  146. {"vector": ["This is a test document."]}
  147. )
  148. @patch("embedchain.vectordb.weaviate.weaviate")
  149. def test_query_with_where(self, weaviate_mock):
  150. """Test the query method of the WeaviateDb class."""
  151. weaviate_client_mock = weaviate_mock.Client.return_value
  152. weaviate_client_query_mock = weaviate_client_mock.query
  153. weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
  154. weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
  155. # Set the embedder
  156. embedder = BaseEmbedder()
  157. embedder.set_vector_dimension(1526)
  158. # Create a Weaviate instance
  159. db = WeaviateDB()
  160. app_config = AppConfig(collect_metrics=False)
  161. App(config=app_config, db=db, embedder=embedder)
  162. # Query for the document.
  163. db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
  164. weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
  165. weaviate_client_query_get_mock.with_where.assert_called_once_with(
  166. {"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
  167. )
  168. weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
  169. {"vector": ["This is a test document."]}
  170. )
  171. @patch("embedchain.vectordb.weaviate.weaviate")
  172. def test_reset(self, weaviate_mock):
  173. """Test the reset method of the WeaviateDb class."""
  174. weaviate_client_mock = weaviate_mock.Client.return_value
  175. weaviate_client_batch_mock = weaviate_client_mock.batch
  176. # Set the embedder
  177. embedder = BaseEmbedder()
  178. embedder.set_vector_dimension(1526)
  179. # Create a Weaviate instance
  180. db = WeaviateDB()
  181. app_config = AppConfig(collect_metrics=False)
  182. App(config=app_config, db=db, embedder=embedder)
  183. # Reset the database.
  184. db.reset()
  185. weaviate_client_batch_mock.delete_objects.assert_called_once_with(
  186. "Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  187. )
  188. @patch("embedchain.vectordb.weaviate.weaviate")
  189. def test_count(self, weaviate_mock):
  190. """Test the reset method of the WeaviateDb class."""
  191. weaviate_client_mock = weaviate_mock.Client.return_value
  192. weaviate_client_query = weaviate_client_mock.query
  193. # Set the embedder
  194. embedder = BaseEmbedder()
  195. embedder.set_vector_dimension(1526)
  196. # Create a Weaviate instance
  197. db = WeaviateDB()
  198. app_config = AppConfig(collect_metrics=False)
  199. App(config=app_config, db=db, embedder=embedder)
  200. # Reset the database.
  201. db.count()
  202. weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")