test_chroma_db.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from chromadb.config import Settings
  5. from embedchain import App
  6. from embedchain.config import AppConfig, ChromaDbConfig
  7. from embedchain.models import EmbeddingFunctions, Providers
  8. from embedchain.vectordb.chroma_db import ChromaDB
  9. class TestChromaDbHosts(unittest.TestCase):
  10. def test_init_with_host_and_port(self):
  11. """
  12. Test if the `ChromaDB` instance is initialized with the correct host and port values.
  13. """
  14. host = "test-host"
  15. port = "1234"
  16. config = ChromaDbConfig(host=host, port=port)
  17. db = ChromaDB(config=config)
  18. settings = db.client.get_settings()
  19. self.assertEqual(settings.chroma_server_host, host)
  20. self.assertEqual(settings.chroma_server_http_port, port)
  21. def test_init_with_basic_auth(self):
  22. host = "test-host"
  23. port = "1234"
  24. chroma_auth_settings = {
  25. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  26. "chroma_client_auth_credentials": "admin:admin",
  27. }
  28. config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
  29. db = ChromaDB(config=config)
  30. settings = db.client.get_settings()
  31. self.assertEqual(settings.chroma_server_host, host)
  32. self.assertEqual(settings.chroma_server_http_port, port)
  33. self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
  34. self.assertEqual(
  35. settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
  36. )
# TODO(review): the test class below has been flagged for review.
  38. class TestChromaDbHostsInit(unittest.TestCase):
  39. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  40. def test_app_init_with_host_and_port(self, mock_client):
  41. """
  42. Test if the `App` instance is initialized with the correct host and port values.
  43. """
  44. host = "test-host"
  45. port = "1234"
  46. config = AppConfig(collect_metrics=False)
  47. chromadb_config = ChromaDbConfig(host=host, port=port)
  48. _app = App(config, chromadb_config=chromadb_config)
  49. called_settings: Settings = mock_client.call_args[0][0]
  50. self.assertEqual(called_settings.chroma_server_host, host)
  51. self.assertEqual(called_settings.chroma_server_http_port, port)
  52. class TestChromaDbHostsNone(unittest.TestCase):
  53. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  54. def test_init_with_host_and_port_none(self, mock_client):
  55. """
  56. Test if the `App` instance is initialized without default hosts and ports.
  57. """
  58. _app = App(config=AppConfig(collect_metrics=False))
  59. called_settings: Settings = mock_client.call_args[0][0]
  60. self.assertEqual(called_settings.chroma_server_host, None)
  61. self.assertEqual(called_settings.chroma_server_http_port, None)
  62. class TestChromaDbHostsLoglevel(unittest.TestCase):
  63. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  64. def test_init_with_host_and_port_log_level(self, mock_client):
  65. """
  66. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
  67. """
  68. config = AppConfig(log_level="DEBUG")
  69. _app = App(config=AppConfig(collect_metrics=False))
  70. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  71. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  72. class TestChromaDbDuplicateHandling:
  73. chroma_settings = {"allow_reset": True}
  74. chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
  75. app_config = AppConfig(collection_name=False, collect_metrics=False)
  76. app_with_settings = App(config=app_config, chromadb_config=chroma_config)
  77. def test_duplicates_throw_warning(self, caplog):
  78. """
  79. Test that add duplicates throws an error.
  80. """
  81. # Start with a clean app
  82. self.app_with_settings.reset()
  83. app = App(config=AppConfig(collect_metrics=False))
  84. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  85. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  86. assert "Insert of existing embedding ID: 0" in caplog.text
  87. assert "Add of existing embedding ID: 0" in caplog.text
  88. def test_duplicates_collections_no_warning(self, caplog):
  89. """
  90. Test that different collections can have duplicates.
  91. """
  92. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
  93. # Start with a clean app
  94. self.app_with_settings.reset()
  95. app = App(config=AppConfig(collect_metrics=False))
  96. app.set_collection("test_collection_1")
  97. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  98. app.set_collection("test_collection_2")
  99. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  100. assert "Insert of existing embedding ID: 0" not in caplog.text # not
  101. assert "Add of existing embedding ID: 0" not in caplog.text # not
  102. class TestChromaDbCollection(unittest.TestCase):
  103. chroma_settings = {"allow_reset": True}
  104. chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
  105. app_config = AppConfig(collection_name=False, collect_metrics=False)
  106. app_with_settings = App(config=app_config, chromadb_config=chroma_config)
  107. def test_init_with_default_collection(self):
  108. """
  109. Test if the `App` instance is initialized with the correct default collection name.
  110. """
  111. app = App(config=AppConfig(collect_metrics=False))
  112. self.assertEqual(app.db.collection.name, "embedchain_store")
  113. def test_init_with_custom_collection(self):
  114. """
  115. Test if the `App` instance is initialized with the correct custom collection name.
  116. """
  117. config = AppConfig(collect_metrics=False)
  118. app = App(config=config)
  119. app.set_collection(collection_name="test_collection")
  120. self.assertEqual(app.db.collection.name, "test_collection")
  121. def test_set_collection(self):
  122. """
  123. Test if the `App` collection is correctly switched using the `set_collection` method.
  124. """
  125. app = App(config=AppConfig(collect_metrics=False))
  126. app.set_collection("test_collection")
  127. self.assertEqual(app.db.collection.name, "test_collection")
  128. def test_changes_encapsulated(self):
  129. """
  130. Test that changes to one collection do not affect the other collection
  131. """
  132. # Start with a clean app
  133. self.app_with_settings.reset()
  134. app = App(config=AppConfig(collect_metrics=False))
  135. app.set_collection("test_collection_1")
  136. # Collection should be empty when created
  137. self.assertEqual(app.count(), 0)
  138. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  139. # After adding, should contain one item
  140. self.assertEqual(app.count(), 1)
  141. app.set_collection("test_collection_2")
  142. # New collection is empty
  143. self.assertEqual(app.count(), 0)
  144. # Adding to new collection should not effect existing collection
  145. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  146. app.set_collection("test_collection_1")
  147. # Should still be 1, not 2.
  148. self.assertEqual(app.count(), 1)
  149. def test_collections_are_persistent(self):
  150. """
  151. Test that a collection can be picked up later.
  152. """
  153. # Start with a clean app
  154. self.app_with_settings.reset()
  155. app = App(config=AppConfig(collect_metrics=False))
  156. app.set_collection("test_collection_1")
  157. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  158. del app
  159. app = App(config=AppConfig(collect_metrics=False))
  160. app.set_collection("test_collection_1")
  161. self.assertEqual(app.count(), 1)
  162. def test_parallel_collections(self):
  163. """
  164. Test that two apps can have different collections open in parallel.
  165. Switching the names will allow instant access to the collection of
  166. the other app.
  167. """
  168. # Start clean
  169. self.app_with_settings.reset()
  170. # Create two apps
  171. app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
  172. app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
  173. # app2 has been created last, but adding to app1 will still write to collection 1.
  174. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  175. self.assertEqual(app1.db.count(), 1)
  176. self.assertEqual(app2.db.count(), 0)
  177. # Add data
  178. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  179. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  180. # Swap names and test
  181. app1.set_collection("test_collection_2")
  182. self.assertEqual(app1.count(), 1)
  183. app2.set_collection("test_collection_1")
  184. self.assertEqual(app2.count(), 3)
  185. def test_ids_share_collections(self):
  186. """
  187. Different ids should still share collections.
  188. """
  189. # Start clean
  190. self.app_with_settings.reset()
  191. # Create two apps
  192. app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
  193. app1.set_collection("one_collection")
  194. app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
  195. app2.set_collection("one_collection")
  196. # Add data
  197. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  198. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  199. # Both should have the same collection
  200. self.assertEqual(app1.count(), 3)
  201. self.assertEqual(app2.count(), 3)
  202. def test_reset(self):
  203. """
  204. Resetting should hit all collections and ids.
  205. """
  206. # Start clean
  207. self.app_with_settings.reset()
  208. # Create four apps.
  209. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
  210. app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
  211. app1.set_collection("one_collection")
  212. app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
  213. app2.set_collection("one_collection")
  214. app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
  215. app3.set_collection("three_collection")
  216. app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
  217. app4.set_collection("four_collection")
  218. # Each one of them get data
  219. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  220. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  221. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  222. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  223. # Resetting the first one should reset them all.
  224. app1.reset()
  225. # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
  226. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  227. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
  228. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
  229. # All should be empty
  230. self.assertEqual(app1.count(), 0)
  231. self.assertEqual(app2.count(), 0)
  232. self.assertEqual(app3.count(), 0)
  233. self.assertEqual(app4.count(), 0)