test_chroma_db.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from chromadb.config import Settings
  5. from embedchain import App
  6. from embedchain.config import AppConfig
  7. from embedchain.vectordb.chroma_db import ChromaDB, chromadb
  8. class TestChromaDbHosts(unittest.TestCase):
  9. def test_init_with_host_and_port(self):
  10. """
  11. Test if the `ChromaDB` instance is initialized with the correct host and port values.
  12. """
  13. host = "test-host"
  14. port = "1234"
  15. with patch.object(chromadb, "HttpClient") as mock_client:
  16. _db = ChromaDB(host=host, port=port, embedding_fn=len)
  17. expected_settings = Settings(
  18. chroma_server_host="test-host",
  19. chroma_server_http_port="1234",
  20. )
  21. mock_client.assert_called_once_with(expected_settings)
  22. # Review this test
  23. class TestChromaDbHostsInit(unittest.TestCase):
  24. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  25. def test_init_with_host_and_port(self, mock_client):
  26. """
  27. Test if the `App` instance is initialized with the correct host and port values.
  28. """
  29. host = "test-host"
  30. port = "1234"
  31. config = AppConfig(host=host, port=port, collect_metrics=False)
  32. _app = App(config)
  33. # self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
  34. # self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
  35. class TestChromaDbHostsNone(unittest.TestCase):
  36. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  37. def test_init_with_host_and_port(self, mock_client):
  38. """
  39. Test if the `App` instance is initialized without default hosts and ports.
  40. """
  41. _app = App(config=AppConfig(collect_metrics=False))
  42. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  43. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  44. class TestChromaDbHostsLoglevel(unittest.TestCase):
  45. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  46. def test_init_with_host_and_port(self, mock_client):
  47. """
  48. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
  49. """
  50. config = AppConfig(log_level="DEBUG")
  51. _app = App(config=AppConfig(collect_metrics=False))
  52. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  53. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  54. class TestChromaDbDuplicateHandling:
  55. def test_duplicates_throw_warning(self, caplog):
  56. """
  57. Test that add duplicates throws an error.
  58. """
  59. # Start with a clean app
  60. App().reset()
  61. app = App(config=AppConfig(collect_metrics=False))
  62. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  63. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  64. assert "Insert of existing embedding ID: 0" in caplog.text
  65. assert "Add of existing embedding ID: 0" in caplog.text
  66. def test_duplicates_collections_no_warning(self, caplog):
  67. """
  68. Test that different collections can have duplicates.
  69. """
  70. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
  71. # Start with a clean app
  72. App().reset()
  73. app = App(config=AppConfig(collect_metrics=False))
  74. app.set_collection("test_collection_1")
  75. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  76. app.set_collection("test_collection_2")
  77. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  78. assert "Insert of existing embedding ID: 0" not in caplog.text # not
  79. assert "Add of existing embedding ID: 0" not in caplog.text # not
  80. class TestChromaDbCollection(unittest.TestCase):
  81. def test_init_with_default_collection(self):
  82. """
  83. Test if the `App` instance is initialized with the correct default collection name.
  84. """
  85. app = App(config=AppConfig(collect_metrics=False))
  86. self.assertEqual(app.collection.name, "embedchain_store")
  87. def test_init_with_custom_collection(self):
  88. """
  89. Test if the `App` instance is initialized with the correct custom collection name.
  90. """
  91. config = AppConfig(collection_name="test_collection", collect_metrics=False)
  92. app = App(config)
  93. self.assertEqual(app.collection.name, "test_collection")
  94. def test_set_collection(self):
  95. """
  96. Test if the `App` collection is correctly switched using the `set_collection` method.
  97. """
  98. app = App(config=AppConfig(collect_metrics=False))
  99. app.set_collection("test_collection")
  100. self.assertEqual(app.collection.name, "test_collection")
  101. def test_changes_encapsulated(self):
  102. """
  103. Test that changes to one collection do not affect the other collection
  104. """
  105. # Start with a clean app
  106. App().reset()
  107. app = App(config=AppConfig(collect_metrics=False))
  108. app.set_collection("test_collection_1")
  109. # Collection should be empty when created
  110. self.assertEqual(app.count(), 0)
  111. app.collection.add(embeddings=[0, 0, 0], ids=["0"])
  112. # After adding, should contain one item
  113. self.assertEqual(app.count(), 1)
  114. app.set_collection("test_collection_2")
  115. # New collection is empty
  116. self.assertEqual(app.count(), 0)
  117. # Adding to new collection should not effect existing collection
  118. app.collection.add(embeddings=[0, 0, 0], ids=["0"])
  119. app.set_collection("test_collection_1")
  120. # Should still be 1, not 2.
  121. self.assertEqual(app.count(), 1)
  122. def test_collections_are_persistent(self):
  123. """
  124. Test that a collection can be picked up later.
  125. """
  126. # Start with a clean app
  127. App().reset()
  128. app = App(config=AppConfig(collect_metrics=False))
  129. app.set_collection("test_collection_1")
  130. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  131. del app
  132. app = App(config=AppConfig(collect_metrics=False))
  133. app.set_collection("test_collection_1")
  134. self.assertEqual(app.count(), 1)
  135. def test_parallel_collections(self):
  136. """
  137. Test that two apps can have different collections open in parallel.
  138. Switching the names will allow instant access to the collection of
  139. the other app.
  140. """
  141. # Start clean
  142. App().reset()
  143. # Create two apps
  144. app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
  145. app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
  146. # app2 has been created last, but adding to app1 will still write to collection 1.
  147. app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
  148. self.assertEqual(app1.count(), 1)
  149. self.assertEqual(app2.count(), 0)
  150. # Add data
  151. app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  152. app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
  153. # Swap names and test
  154. app1.set_collection("test_collection_2")
  155. self.assertEqual(app1.count(), 1)
  156. app2.set_collection("test_collection_1")
  157. self.assertEqual(app2.count(), 3)
  158. def test_ids_share_collections(self):
  159. """
  160. Different ids should still share collections.
  161. """
  162. # Start clean
  163. App().reset()
  164. # Create two apps
  165. app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
  166. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  167. # Add data
  168. app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  169. app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
  170. # Both should have the same collection
  171. self.assertEqual(app1.count(), 3)
  172. self.assertEqual(app2.count(), 3)
  173. def test_reset(self):
  174. """
  175. Resetting should hit all collections and ids.
  176. """
  177. # Start clean
  178. App().reset()
  179. # Create four apps.
  180. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
  181. app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
  182. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  183. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
  184. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
  185. # Each one of them get data
  186. app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
  187. app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
  188. app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
  189. app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
  190. # Resetting the first one should reset them all.
  191. app1.reset()
  192. # Reinstantiate them
  193. app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
  194. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  195. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
  196. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
  197. # All should be empty
  198. self.assertEqual(app1.count(), 0)
  199. self.assertEqual(app2.count(), 0)
  200. self.assertEqual(app3.count(), 0)
  201. self.assertEqual(app4.count(), 0)