# test_weaviate.py (9.7 KB)
  1. import unittest
  2. from unittest.mock import patch
  3. from embedchain import App
  4. from embedchain.config import AppConfig
  5. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  6. from embedchain.embedder.base import BaseEmbedder
  7. from embedchain.vectordb.weaviate import WeaviateDB
  class TestWeaviateDb(unittest.TestCase):
      """Unit tests for ``WeaviateDB`` driven through a mocked ``weaviate`` module.

      Every test that builds a database patches
      ``embedchain.vectordb.weaviate.weaviate`` and then inspects the calls
      recorded on the mocked client (``weaviate_mock.Client.return_value``).
      """

      # Dimension configured on the embedder in every test; the class names
      # asserted below carry the same number as a suffix.
      VECTOR_DIMENSION = 1526
      INDEX_NAME = "Embedchain_store_1526"
      METADATA_INDEX_NAME = "Embedchain_store_1526_metadata"

      def _create_db(self):
          """Build a ``WeaviateDB`` and hand it to an ``App`` with a fixed-dimension embedder.

          Call this only while the ``weaviate`` patch is active, and after the
          test has configured any return values it needs on the mock, because
          the database is constructed here.

          Returns:
              The ``WeaviateDB`` instance that was passed to ``App``.
          """
          embedder = BaseEmbedder()
          embedder.set_vector_dimension(self.VECTOR_DIMENSION)
          db = WeaviateDB()
          App(config=AppConfig(collect_metrics=False), db=db, embedder=embedder)
          return db

      @staticmethod
      def _text_property(name):
          """Return the expected schema entry for a ``text`` property called ``name``."""
          return {"name": name, "dataType": ["text"]}

      def test_incorrect_config_throws_error(self):
          """Passing a config of another vector DB (Pinecone) raises ``TypeError``."""
          with self.assertRaises(TypeError):
              WeaviateDB(config=PineconeDBConfig())

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_initialize(self, weaviate_mock):
          """Initialization creates the client once and registers the expected schema."""
          schema_mock = weaviate_mock.Client.return_value.schema
          # Report that the schema does not exist yet so that a new one is created.
          schema_mock.exists.return_value = False

          db = self._create_db()

          expected_class_obj = {
              "classes": [
                  {
                      "class": self.INDEX_NAME,
                      "vectorizer": "none",
                      "properties": [
                          self._text_property("identifier"),
                          self._text_property("text"),
                          {"name": "metadata", "dataType": [self.METADATA_INDEX_NAME]},
                      ],
                  },
                  {
                      "class": self.METADATA_INDEX_NAME,
                      "vectorizer": "none",
                      "properties": [
                          self._text_property(name)
                          for name in ("data_type", "doc_id", "url", "hash", "app_id", "text")
                      ],
                  },
              ]
          }

          # The client must have been constructed exactly once.
          weaviate_mock.Client.assert_called_once()
          self.assertEqual(db.index_name, self.INDEX_NAME)
          schema_mock.create.assert_called_once_with(expected_class_obj)

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_get_or_create_db(self, weaviate_mock):
          """``_get_or_create_db`` returns the (mocked) Weaviate client."""
          client_mock = weaviate_mock.Client.return_value

          db = self._create_db()

          actual_client = db._get_or_create_db()
          self.assertEqual(actual_client, client_mock)

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_add(self, weaviate_mock):
          """``add`` configures the batch and adds one metadata and one store object per document."""
          batch_mock = weaviate_mock.Client.return_value.batch
          # Object yielded by ``with client.batch as ...``; add_data_object is
          # asserted on it below.
          batch_enter_mock = batch_mock.__enter__.return_value

          db = self._create_db()
          db.BATCH_SIZE = 1

          embeddings = [[1, 2, 3], [4, 5, 6]]
          documents = ["This is a test document.", "This is another test document."]
          metadatas = [None, None]
          ids = ["123", "456"]
          skip_embedding = True
          db.add(embeddings, documents, metadatas, ids, skip_embedding)

          batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)

          # Each document is expected as an object of the metadata class ...
          for document, embedding in zip(documents, embeddings):
              batch_enter_mock.add_data_object.assert_any_call(
                  data_object={"text": document},
                  class_name=self.METADATA_INDEX_NAME,
                  vector=embedding,
              )
          # ... and as an object of the main class, keyed by its identifier.
          for identifier, document, embedding in zip(ids, documents, embeddings):
              batch_enter_mock.add_data_object.assert_any_call(
                  data_object={"identifier": identifier, "text": document},
                  class_name=self.INDEX_NAME,
                  vector=embedding,
              )

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_query_without_where(self, weaviate_mock):
          """``query`` with an empty ``where`` applies the near-vector search directly on ``get``."""
          query_mock = weaviate_mock.Client.return_value.query
          get_result_mock = query_mock.get.return_value

          db = self._create_db()

          db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)

          query_mock.get.assert_called_once_with(self.INDEX_NAME, ["text"])
          get_result_mock.with_near_vector.assert_called_once_with(
              {"vector": ["This is a test document."]}
          )

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_query_with_where(self, weaviate_mock):
          """``query`` with a ``where`` clause adds an ``Equal`` filter before the near-vector search."""
          query_mock = weaviate_mock.Client.return_value.query
          get_result_mock = query_mock.get.return_value
          where_result_mock = get_result_mock.with_where.return_value

          db = self._create_db()

          db.query(
              input_query=["This is a test document."],
              n_results=1,
              where={"doc_id": "123"},
              skip_embedding=True,
          )

          query_mock.get.assert_called_once_with(self.INDEX_NAME, ["text"])
          get_result_mock.with_where.assert_called_once_with(
              {
                  "operator": "Equal",
                  "path": ["metadata", self.METADATA_INDEX_NAME, "doc_id"],
                  "valueText": "123",
              }
          )
          # The near-vector search is chained on the result of with_where.
          where_result_mock.with_near_vector.assert_called_once_with(
              {"vector": ["This is a test document."]}
          )

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_reset(self, weaviate_mock):
          """``reset`` deletes the index's objects with a ``Like`` filter on ``identifier``."""
          batch_mock = weaviate_mock.Client.return_value.batch

          db = self._create_db()

          db.reset()

          batch_mock.delete_objects.assert_called_once_with(
              self.INDEX_NAME,
              where={"path": ["identifier"], "operator": "Like", "valueText": ".*"},
          )

      @patch("embedchain.vectordb.weaviate.weaviate")
      def test_count(self, weaviate_mock):
          """``count`` issues an aggregate query for the index class."""
          query_mock = weaviate_mock.Client.return_value.query

          db = self._create_db()

          # Count the stored documents.
          db.count()

          query_mock.aggregate.assert_called_once_with(self.INDEX_NAME)