# test_chroma_db.py (11 KB)

# ruff: noqa: E501
import unittest
from unittest.mock import patch

from chromadb.config import Settings

from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma_db import ChromaDB


  8. class TestChromaDbHosts(unittest.TestCase):
  9. def test_init_with_host_and_port(self):
  10. """
  11. Test if the `ChromaDB` instance is initialized with the correct host and port values.
  12. """
  13. host = "test-host"
  14. port = "1234"
  15. config = ChromaDbConfig(host=host, port=port)
  16. db = ChromaDB(config=config)
  17. settings = db.client.get_settings()
  18. self.assertEqual(settings.chroma_server_host, host)
  19. self.assertEqual(settings.chroma_server_http_port, port)
  20. def test_init_with_basic_auth(self):
  21. host = "test-host"
  22. port = "1234"
  23. chroma_auth_settings = {
  24. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  25. "chroma_client_auth_credentials": "admin:admin",
  26. }
  27. config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
  28. db = ChromaDB(config=config)
  29. settings = db.client.get_settings()
  30. self.assertEqual(settings.chroma_server_host, host)
  31. self.assertEqual(settings.chroma_server_http_port, port)
  32. self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
  33. self.assertEqual(
  34. settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
  35. )
  36. # Review this test
  37. class TestChromaDbHostsInit(unittest.TestCase):
  38. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  39. def test_app_init_with_host_and_port(self, mock_client):
  40. """
  41. Test if the `App` instance is initialized with the correct host and port values.
  42. """
  43. host = "test-host"
  44. port = "1234"
  45. config = AppConfig(collect_metrics=False)
  46. chromadb_config = ChromaDbConfig(host=host, port=port)
  47. _app = App(config, chromadb_config=chromadb_config)
  48. called_settings: Settings = mock_client.call_args[0][0]
  49. self.assertEqual(called_settings.chroma_server_host, host)
  50. self.assertEqual(called_settings.chroma_server_http_port, port)
  51. class TestChromaDbHostsNone(unittest.TestCase):
  52. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  53. def test_init_with_host_and_port_none(self, mock_client):
  54. """
  55. Test if the `App` instance is initialized without default hosts and ports.
  56. """
  57. _app = App(config=AppConfig(collect_metrics=False))
  58. called_settings: Settings = mock_client.call_args[0][0]
  59. self.assertEqual(called_settings.chroma_server_host, None)
  60. self.assertEqual(called_settings.chroma_server_http_port, None)
  61. class TestChromaDbHostsLoglevel(unittest.TestCase):
  62. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  63. def test_init_with_host_and_port_log_level(self, mock_client):
  64. """
  65. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
  66. """
  67. _app = App(config=AppConfig(collect_metrics=False))
  68. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  69. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  70. class TestChromaDbDuplicateHandling:
  71. chroma_settings = {"allow_reset": True}
  72. chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
  73. app_config = AppConfig(collection_name=False, collect_metrics=False)
  74. app_with_settings = App(config=app_config, chromadb_config=chroma_config)
  75. def test_duplicates_throw_warning(self, caplog):
  76. """
  77. Test that add duplicates throws an error.
  78. """
  79. # Start with a clean app
  80. self.app_with_settings.reset()
  81. app = App(config=AppConfig(collect_metrics=False))
  82. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  83. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  84. assert "Insert of existing embedding ID: 0" in caplog.text
  85. assert "Add of existing embedding ID: 0" in caplog.text
  86. def test_duplicates_collections_no_warning(self, caplog):
  87. """
  88. Test that different collections can have duplicates.
  89. """
  90. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
  91. # Start with a clean app
  92. self.app_with_settings.reset()
  93. app = App(config=AppConfig(collect_metrics=False))
  94. app.set_collection_name("test_collection_1")
  95. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  96. app.set_collection_name("test_collection_2")
  97. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  98. assert "Insert of existing embedding ID: 0" not in caplog.text # not
  99. assert "Add of existing embedding ID: 0" not in caplog.text # not
  100. class TestChromaDbCollection(unittest.TestCase):
  101. chroma_settings = {"allow_reset": True}
  102. chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
  103. app_config = AppConfig(collection_name=False, collect_metrics=False)
  104. app_with_settings = App(config=app_config, chromadb_config=chroma_config)
  105. def test_init_with_default_collection(self):
  106. """
  107. Test if the `App` instance is initialized with the correct default collection name.
  108. """
  109. app = App(config=AppConfig(collect_metrics=False))
  110. self.assertEqual(app.db.collection.name, "embedchain_store")
  111. def test_init_with_custom_collection(self):
  112. """
  113. Test if the `App` instance is initialized with the correct custom collection name.
  114. """
  115. config = AppConfig(collect_metrics=False)
  116. app = App(config=config)
  117. app.set_collection_name(name="test_collection")
  118. self.assertEqual(app.db.collection.name, "test_collection")
  119. def test_set_collection_name(self):
  120. """
  121. Test if the `App` collection is correctly switched using the `set_collection_name` method.
  122. """
  123. app = App(config=AppConfig(collect_metrics=False))
  124. app.set_collection_name("test_collection")
  125. self.assertEqual(app.db.collection.name, "test_collection")
  126. def test_changes_encapsulated(self):
  127. """
  128. Test that changes to one collection do not affect the other collection
  129. """
  130. # Start with a clean app
  131. self.app_with_settings.reset()
  132. app = App(config=AppConfig(collect_metrics=False))
  133. app.set_collection_name("test_collection_1")
  134. # Collection should be empty when created
  135. self.assertEqual(app.count(), 0)
  136. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  137. # After adding, should contain one item
  138. self.assertEqual(app.count(), 1)
  139. app.set_collection_name("test_collection_2")
  140. # New collection is empty
  141. self.assertEqual(app.count(), 0)
  142. # Adding to new collection should not effect existing collection
  143. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  144. app.set_collection_name("test_collection_1")
  145. # Should still be 1, not 2.
  146. self.assertEqual(app.count(), 1)
  147. def test_collections_are_persistent(self):
  148. """
  149. Test that a collection can be picked up later.
  150. """
  151. # Start with a clean app
  152. self.app_with_settings.reset()
  153. app = App(config=AppConfig(collect_metrics=False))
  154. app.set_collection_name("test_collection_1")
  155. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  156. del app
  157. app = App(config=AppConfig(collect_metrics=False))
  158. app.set_collection_name("test_collection_1")
  159. self.assertEqual(app.count(), 1)
  160. def test_parallel_collections(self):
  161. """
  162. Test that two apps can have different collections open in parallel.
  163. Switching the names will allow instant access to the collection of
  164. the other app.
  165. """
  166. # Start clean
  167. self.app_with_settings.reset()
  168. # Create two apps
  169. app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
  170. app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
  171. # app2 has been created last, but adding to app1 will still write to collection 1.
  172. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  173. self.assertEqual(app1.db.count(), 1)
  174. self.assertEqual(app2.db.count(), 0)
  175. # Add data
  176. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  177. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  178. # Swap names and test
  179. app1.set_collection_name("test_collection_2")
  180. self.assertEqual(app1.count(), 1)
  181. app2.set_collection_name("test_collection_1")
  182. self.assertEqual(app2.count(), 3)
  183. def test_ids_share_collections(self):
  184. """
  185. Different ids should still share collections.
  186. """
  187. # Start clean
  188. self.app_with_settings.reset()
  189. # Create two apps
  190. app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
  191. app1.set_collection_name("one_collection")
  192. app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
  193. app2.set_collection_name("one_collection")
  194. # Add data
  195. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  196. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  197. # Both should have the same collection
  198. self.assertEqual(app1.count(), 3)
  199. self.assertEqual(app2.count(), 3)
  200. def test_reset(self):
  201. """
  202. Resetting should hit all collections and ids.
  203. """
  204. # Start clean
  205. self.app_with_settings.reset()
  206. # Create four apps.
  207. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
  208. app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
  209. app1.set_collection_name("one_collection")
  210. app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
  211. app2.set_collection_name("one_collection")
  212. app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
  213. app3.set_collection_name("three_collection")
  214. app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
  215. app4.set_collection_name("four_collection")
  216. # Each one of them get data
  217. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  218. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  219. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  220. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  221. # Resetting the first one should reset them all.
  222. app1.reset()
  223. # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
  224. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  225. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
  226. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
  227. # All should be empty
  228. self.assertEqual(app1.count(), 0)
  229. self.assertEqual(app2.count(), 0)
  230. self.assertEqual(app3.count(), 0)
  231. self.assertEqual(app4.count(), 0)