test_chroma_db.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from chromadb.config import Settings
  5. from embedchain import App
  6. from embedchain.config import AppConfig, ChromaDbConfig
  7. from embedchain.vectordb.chroma import ChromaDB
  class TestChromaDbHosts(unittest.TestCase):
      """Checks that `ChromaDB` carries connection options from `ChromaDbConfig` into its client settings."""

      def test_init_with_host_and_port(self):
          """
          A host and port given through `ChromaDbConfig` must be reported by the client settings.
          """
          expected_host, expected_port = "test-host", "1234"

          db = ChromaDB(config=ChromaDbConfig(host=expected_host, port=expected_port))
          client_settings = db.client.get_settings()

          self.assertEqual(client_settings.chroma_server_host, expected_host)
          self.assertEqual(client_settings.chroma_server_http_port, expected_port)

      def test_init_with_basic_auth(self):
          """
          Auth entries given through `chroma_settings` must be reported by the client settings,
          alongside the configured host and port.
          """
          expected_host, expected_port = "test-host", "1234"
          auth_options = {
              "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
              "chroma_client_auth_credentials": "admin:admin",
          }

          db_config = ChromaDbConfig(host=expected_host, port=expected_port, chroma_settings=auth_options)
          db = ChromaDB(config=db_config)
          client_settings = db.client.get_settings()

          self.assertEqual(client_settings.chroma_server_host, expected_host)
          self.assertEqual(client_settings.chroma_server_http_port, expected_port)
          # Both auth values must come back exactly as they were passed in.
          self.assertEqual(
              client_settings.chroma_client_auth_provider,
              auth_options["chroma_client_auth_provider"],
          )
          self.assertEqual(
              client_settings.chroma_client_auth_credentials,
              auth_options["chroma_client_auth_credentials"],
          )
  36. # Review this test
  class TestChromaDbHostsInit(unittest.TestCase):
      """Checks the settings an `App` hands to the (patched) Chroma client when host and port are configured."""

      @patch("embedchain.vectordb.chroma.chromadb.Client")
      def test_app_init_with_host_and_port(self, mock_client):
          """
          Building an `App` with a `ChromaDbConfig` must pass the configured host and port
          to the Chroma client as its first positional argument.
          """
          expected_host, expected_port = "test-host", "1234"

          app_config = AppConfig(collect_metrics=False)
          db_config = ChromaDbConfig(host=expected_host, port=expected_port)
          # Only the construction side effect matters; the instance itself is not inspected.
          _app = App(app_config, chromadb_config=db_config)

          positional_args = mock_client.call_args[0]
          passed_settings: Settings = positional_args[0]

          self.assertEqual(passed_settings.chroma_server_host, expected_host)
          self.assertEqual(passed_settings.chroma_server_http_port, expected_port)
  class TestChromaDbHostsNone(unittest.TestCase):
      """Checks the settings an `App` hands to the (patched) Chroma client when no host or port is configured."""

      @patch("embedchain.vectordb.chroma.chromadb.Client")
      def test_init_with_host_and_port_none(self, mock_client):
          """
          An `App` built without a `ChromaDbConfig` must pass settings whose host and port are both `None`.
          """
          app_config = AppConfig(collect_metrics=False)
          # Only the construction side effect matters; the instance itself is not inspected.
          _app = App(config=app_config)

          positional_args = mock_client.call_args[0]
          passed_settings: Settings = positional_args[0]

          self.assertEqual(passed_settings.chroma_server_host, None)
          self.assertEqual(passed_settings.chroma_server_http_port, None)
  class TestChromaDbHostsLoglevel(unittest.TestCase):
      """Checks that an `App` built from a plain `AppConfig` sends no host or port to the (patched) Chroma client."""

      @patch("embedchain.vectordb.chroma.chromadb.Client")
      def test_init_with_host_and_port_log_level(self, mock_client):
          """
          An `App` whose config carries no host or port must pass `None` for both to the Chroma client.

          NOTE(review): despite the name, no log level is configured here - confirm whether one was intended.
          """
          app_config = AppConfig(collect_metrics=False)
          # Only the construction side effect matters; the instance itself is not inspected.
          _app = App(config=app_config)

          # Read the first positional argument separately for each assertion.
          host_passed = mock_client.call_args[0][0].chroma_server_host
          self.assertEqual(host_passed, None)
          port_passed = mock_client.call_args[0][0].chroma_server_http_port
          self.assertEqual(port_passed, None)
  class TestChromaDbDuplicateHandling:
      """
      Pytest-style checks (they need the `caplog` fixture) for what is logged when an id is added twice.

      The class attributes are evaluated when the class body runs; `app_with_settings` is the
      instance configured with `allow_reset: True` and is used to wipe state before each test.
      """

      chroma_settings = {"allow_reset": True}
      chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
      app_config = AppConfig(collection_name=False, collect_metrics=False)
      app_with_settings = App(config=app_config, chromadb_config=chroma_config)

      def test_duplicates_throw_warning(self, caplog):
          """
          Adding the same id twice to one collection must produce the duplicate-id log messages.
          """
          # Wipe any existing data first.
          self.app_with_settings.reset()

          fresh_app = App(config=AppConfig(collect_metrics=False))
          collection = fresh_app.db.collection
          # Same id, added twice to the same collection.
          collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          collection.add(embeddings=[[0, 0, 0]], ids=["0"])

          captured = caplog.text
          assert "Insert of existing embedding ID: 0" in captured
          assert "Add of existing embedding ID: 0" in captured

      def test_duplicates_collections_no_warning(self, caplog):
          """
          The same id in two different collections must not produce the duplicate-id log messages.
          """
          # NOTE: Kept outside TestChromaDbCollection because `unittest.TestCase` has no `caplog`.
          # Wipe any existing data first.
          self.app_with_settings.reset()

          fresh_app = App(config=AppConfig(collect_metrics=False))
          # `db.collection` is looked up again after every switch of the collection name.
          fresh_app.set_collection_name("test_collection_1")
          fresh_app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          fresh_app.set_collection_name("test_collection_2")
          fresh_app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

          assert "Insert of existing embedding ID: 0" not in caplog.text  # not
          assert "Add of existing embedding ID: 0" not in caplog.text  # not
  class TestChromaDbCollection(unittest.TestCase):
      """
      Collection handling of `App` on top of ChromaDB: naming, isolation, persistence and reset.

      The class attributes are evaluated when the class body runs; `app_with_settings` is the
      instance configured with `allow_reset: True` and is used to wipe state at the start of
      the tests that depend on counts.
      """

      chroma_settings = {"allow_reset": True}
      chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
      app_config = AppConfig(collection_name=False, collect_metrics=False)
      app_with_settings = App(config=app_config, chromadb_config=chroma_config)

      def test_init_with_default_collection(self):
          """
          Test if the `App` instance is initialized with the correct default collection name.
          """
          app = App(config=AppConfig(collect_metrics=False))

          self.assertEqual(app.db.collection.name, "embedchain_store")

      def test_init_with_custom_collection(self):
          """
          Test that a custom collection name passed via the `name=` keyword of `set_collection_name` is applied.
          """
          config = AppConfig(collect_metrics=False)
          app = App(config=config)
          app.set_collection_name(name="test_collection")

          self.assertEqual(app.db.collection.name, "test_collection")

      def test_set_collection_name(self):
          """
          Test if the `App` collection is correctly switched using the `set_collection_name` method.
          """
          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection")

          self.assertEqual(app.db.collection.name, "test_collection")

      def test_changes_encapsulated(self):
          """
          Test that changes to one collection do not affect the other collection.
          """
          # Start with a clean app
          self.app_with_settings.reset()

          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection_1")
          # Collection should be empty when created
          self.assertEqual(app.count(), 0)

          # Embeddings are passed as a list of vectors, like everywhere else in this file.
          app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          # After adding, should contain one item
          self.assertEqual(app.count(), 1)

          app.set_collection_name("test_collection_2")
          # New collection is empty
          self.assertEqual(app.count(), 0)

          # Adding to new collection should not affect existing collection
          app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          app.set_collection_name("test_collection_1")
          # Should still be 1, not 2.
          self.assertEqual(app.count(), 1)

      def test_collections_are_persistent(self):
          """
          Test that a collection can be picked up later.
          """
          # Start with a clean app
          self.app_with_settings.reset()

          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection_1")
          app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          # Drop the first instance so the data has to be found again by a new one.
          del app

          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection_1")
          self.assertEqual(app.count(), 1)

      def test_parallel_collections(self):
          """
          Test that two apps can have different collections open in parallel.
          Switching the names will allow instant access to the collection of
          the other app.
          """
          # Start clean
          self.app_with_settings.reset()

          # Create two apps
          app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
          app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))

          # app2 has been created last, but adding to app1 will still write to collection 1.
          app1.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          self.assertEqual(app1.db.count(), 1)
          self.assertEqual(app2.db.count(), 0)

          # Add data
          app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
          app2.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

          # Swap names and test
          app1.set_collection_name("test_collection_2")
          self.assertEqual(app1.count(), 1)
          app2.set_collection_name("test_collection_1")
          self.assertEqual(app2.count(), 3)

      def test_ids_share_collections(self):
          """
          Different ids should still share collections.
          """
          # Start clean
          self.app_with_settings.reset()

          # Create two apps
          app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
          app1.set_collection_name("one_collection")
          app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
          app2.set_collection_name("one_collection")

          # Add data
          app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
          app2.db.collection.add(embeddings=[[0, 0, 0]], ids=["2"])

          # Both should have the same collection
          self.assertEqual(app1.count(), 3)
          self.assertEqual(app2.count(), 3)

      def test_reset(self):
          """
          Resetting should hit all collections and ids.
          """
          # Start clean
          self.app_with_settings.reset()

          # Create four apps.
          # app1, which we are about to reset, shares a collection with app2, an id with app3,
          # and nothing with app4. Only app1 is built with the `allow_reset` Chroma config.
          app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
          app1.set_collection_name("one_collection")
          app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
          app2.set_collection_name("one_collection")
          app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
          app3.set_collection_name("three_collection")
          app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
          app4.set_collection_name("four_collection")

          # Each one of them gets data
          app1.db.collection.add(embeddings=[[0, 0, 0]], ids=["1"])
          app2.db.collection.add(embeddings=[[0, 0, 0]], ids=["2"])
          app3.db.collection.add(embeddings=[[0, 0, 0]], ids=["3"])
          app4.db.collection.add(embeddings=[[0, 0, 0]], ids=["4"])

          # Resetting the first one should reset them all.
          app1.reset()

          # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319).
          # Each replacement keeps the id and collection of the app it replaces.
          app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
          app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
          app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))

          # All should be empty
          self.assertEqual(app1.count(), 0)
          self.assertEqual(app2.count(), 0)
          self.assertEqual(app3.count(), 0)
          self.assertEqual(app4.count(), 0)