test_chroma_db.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from chromadb.config import Settings
  5. from embedchain import App
  6. from embedchain.config import AppConfig, ChromaDbConfig
  7. from embedchain.vectordb.chroma import ChromaDB
  class TestChromaDbHosts(unittest.TestCase):
      """Checks that connection options given to `ChromaDbConfig` show up in the client settings."""

      def test_init_with_host_and_port(self):
          """
          A `ChromaDB` built from a host/port config should report both values in its client settings.
          """
          expected_host, expected_port = "test-host", "1234"

          db = ChromaDB(config=ChromaDbConfig(host=expected_host, port=expected_port))
          client_settings = db.client.get_settings()

          self.assertEqual(client_settings.chroma_server_host, expected_host)
          self.assertEqual(client_settings.chroma_server_http_port, expected_port)

      def test_init_with_basic_auth(self):
          """
          Extra `chroma_settings` (auth provider and credentials) should be forwarded to the client settings
          together with host and port.
          """
          expected_host, expected_port = "test-host", "1234"
          auth_settings = {
              "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
              "chroma_client_auth_credentials": "admin:admin",
          }

          db = ChromaDB(
              config=ChromaDbConfig(host=expected_host, port=expected_port, chroma_settings=auth_settings)
          )
          client_settings = db.client.get_settings()

          self.assertEqual(client_settings.chroma_server_host, expected_host)
          self.assertEqual(client_settings.chroma_server_http_port, expected_port)
          # Every auth option must come back under the attribute of the same name.
          for option_name, option_value in auth_settings.items():
              self.assertEqual(getattr(client_settings, option_name), option_value)
  42. # Review this test
  class TestChromaDbHostsInit(unittest.TestCase):
      """Checks the settings an `App` hands to the (mocked) Chroma client constructor."""

      @patch("embedchain.vectordb.chroma.chromadb.Client")
      def test_app_init_with_host_and_port(self, mock_client):
          """
          Creating an `App` with a host/port database config should call the Chroma client with those values.
          """
          expected_host, expected_port = "test-host", "1234"

          _app = App(
              AppConfig(collect_metrics=False),
              db_config=ChromaDbConfig(host=expected_host, port=expected_port),
          )

          # The settings object is the first positional argument of the recorded client call.
          positional_args, _keyword_args = mock_client.call_args
          called_settings: Settings = positional_args[0]

          self.assertEqual(called_settings.chroma_server_host, expected_host)
          self.assertEqual(called_settings.chroma_server_http_port, expected_port)
  class TestChromaDbHostsNone(unittest.TestCase):
      """Checks the Chroma client settings when no database host or port is configured."""

      @patch("embedchain.vectordb.chroma.chromadb.Client")
      def test_init_with_host_and_port_none(self, mock_client):
          """
          Test if the `App` instance is initialized without default hosts and ports.
          """
          _app = App(config=AppConfig(collect_metrics=False))

          # The settings object is the first positional argument of the recorded client call.
          called_settings: Settings = mock_client.call_args[0][0]

          # `assertIsNone` checks identity, so a value that merely compares equal to None cannot pass.
          self.assertIsNone(called_settings.chroma_server_host)
          self.assertIsNone(called_settings.chroma_server_http_port)
  class TestChromaDbHostsLoglevel(unittest.TestCase):
      """Checks the Chroma client settings for an `App` created from a config without host or port."""

      @patch("embedchain.vectordb.chroma.chromadb.Client")
      def test_init_with_host_and_port_log_level(self, mock_client):
          """
          Test if an `App` instance created from a config that sets no host or port passes neither
          to the Chroma client.
          """
          # NOTE(review): despite the test name, no log level is configured or asserted here - confirm intent.
          _app = App(config=AppConfig(collect_metrics=False))

          # Look the recorded settings up once instead of re-indexing `call_args` for every assertion.
          called_settings = mock_client.call_args[0][0]

          self.assertIsNone(called_settings.chroma_server_host)
          self.assertIsNone(called_settings.chroma_server_http_port)
  class TestChromaDbDuplicateHandling:
      """Duplicate-ID handling, written as a plain pytest class so the `caplog` fixture is available."""

      # Shared app whose database config sets `allow_reset=True`; each test calls its `reset()` first.
      chroma_config = ChromaDbConfig(allow_reset=True)
      app_config = AppConfig(collection_name=False, collect_metrics=False)
      app_with_settings = App(config=app_config, db_config=chroma_config)

      def test_duplicates_throw_warning(self, caplog):
          """
          Adding the same embedding ID twice to one collection should leave both duplicate messages in the log.
          """
          # Start with a clean app
          self.app_with_settings.reset()

          app = App(config=AppConfig(collect_metrics=False))
          # Same ID added two times in a row.
          for _ in range(2):
              app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

          for expected_message in (
              "Insert of existing embedding ID: 0",
              "Add of existing embedding ID: 0",
          ):
              assert expected_message in caplog.text

      def test_duplicates_collections_no_warning(self, caplog):
          """
          The same embedding ID added to two different collections should not produce duplicate messages.
          """
          # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
          # Start with a clean app
          self.app_with_settings.reset()

          app = App(config=AppConfig(collect_metrics=False))
          # One add per collection, always with the same ID.
          for collection_name in ("test_collection_1", "test_collection_2"):
              app.set_collection_name(collection_name)
              app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

          # Neither duplicate message may appear.
          for unexpected_message in (
              "Insert of existing embedding ID: 0",
              "Add of existing embedding ID: 0",
          ):
              assert unexpected_message not in caplog.text
  class TestChromaDbCollection(unittest.TestCase):
      """Collection handling of `App` on top of the Chroma vector database."""

      # Shared app whose database config sets `allow_reset=True`; tests call its `reset()` before they start.
      chroma_config = ChromaDbConfig(allow_reset=True)
      app_config = AppConfig(collection_name=False, collect_metrics=False)
      app_with_settings = App(config=app_config, db_config=chroma_config)

      def test_init_with_default_collection(self):
          """
          Test if the `App` instance is initialized with the correct default collection name.
          """
          app = App(config=AppConfig(collect_metrics=False))

          self.assertEqual(app.db.collection.name, "embedchain_store")

      def test_init_with_custom_collection(self):
          """
          Test if the `App` instance is initialized with the correct custom collection name.
          """
          config = AppConfig(collect_metrics=False)
          app = App(config=config)
          app.set_collection_name(name="test_collection")

          self.assertEqual(app.db.collection.name, "test_collection")

      def test_set_collection_name(self):
          """
          Test if the `App` collection is correctly switched using the `set_collection_name` method.
          """
          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection")

          self.assertEqual(app.db.collection.name, "test_collection")

      def test_changes_encapsulated(self):
          """
          Test that changes to one collection do not affect the other collection.
          """
          # Start with a clean app
          self.app_with_settings.reset()

          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection_1")
          # Collection should be empty when created
          self.assertEqual(app.db.count(), 0)

          # Embeddings are passed as a list of vectors, the same form the other tests in this file use.
          app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          # After adding, should contain one item
          self.assertEqual(app.db.count(), 1)

          app.set_collection_name("test_collection_2")
          # New collection is empty
          self.assertEqual(app.db.count(), 0)

          # Adding to the new collection should not affect the existing collection
          app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          app.set_collection_name("test_collection_1")
          # Should still be 1, not 2.
          self.assertEqual(app.db.count(), 1)

      def test_add_with_skip_embedding(self):
          """
          Test that a document added with `skip_embedding=True` and an explicit embedding
          is counted and can be read back through `get` and `query`.
          """
          # Start with a clean app
          self.app_with_settings.reset()
          db = self.app_with_settings.db

          # Collection should be empty when created
          self.assertEqual(db.count(), 0)

          db.add(
              embeddings=[[0, 0, 0]],
              documents=["document"],
              metadatas=[{"value": "somevalue"}],
              ids=["id"],
              skip_embedding=True,
          )
          # After adding, should contain one item
          self.assertEqual(db.count(), 1)

          # Validate if the get utility of the database is working as expected
          data = db.get(["id"], limit=1)
          expected_value = {
              "documents": ["document"],
              "embeddings": None,
              "ids": ["id"],
              "metadatas": [{"value": "somevalue"}],
          }
          self.assertEqual(data, expected_value)

          # Validate if the query utility of the database is working as expected
          data = db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
          expected_value = ["document"]
          self.assertEqual(data, expected_value)

      def test_collections_are_persistent(self):
          """
          Test that a collection can be picked up later.
          """
          # Start with a clean app
          self.app_with_settings.reset()

          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection_1")
          app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          del app

          # A fresh app pointed at the same collection name should see the item added above.
          app = App(config=AppConfig(collect_metrics=False))
          app.set_collection_name("test_collection_1")
          self.assertEqual(app.db.count(), 1)

      def test_parallel_collections(self):
          """
          Test that two apps can have different collections open in parallel.
          Switching the names will allow instant access to the collection of
          the other app.
          """
          # Start clean
          self.app_with_settings.reset()

          # Create two apps
          app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
          app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))

          # app2 has been created last, but adding to app1 will still write to collection 1.
          app1.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
          self.assertEqual(app1.db.count(), 1)
          self.assertEqual(app2.db.count(), 0)

          # Add data
          app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
          app2.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

          # Swap names and test
          app1.set_collection_name("test_collection_2")
          self.assertEqual(app1.count(), 1)
          app2.set_collection_name("test_collection_1")
          self.assertEqual(app2.count(), 3)

      def test_ids_share_collections(self):
          """
          Different ids should still share collections.
          """
          # Start clean
          self.app_with_settings.reset()

          # Create two apps
          app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
          app1.set_collection_name("one_collection")
          app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
          app2.set_collection_name("one_collection")

          # Add data
          app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
          app2.db.collection.add(embeddings=[[0, 0, 0]], ids=["2"])

          # Both should have the same collection
          self.assertEqual(app1.count(), 3)
          self.assertEqual(app2.count(), 3)

      def test_reset(self):
          """
          Resetting should hit all collections and ids.
          """
          # Start clean
          self.app_with_settings.reset()

          # Create four apps.
          # app1, which we are about to reset, shares a collection with app2, an id with app3, and nothing with app4.
          app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config)
          app1.set_collection_name("one_collection")
          app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
          app2.set_collection_name("one_collection")
          app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
          app3.set_collection_name("three_collection")
          app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
          app4.set_collection_name("four_collection")

          # Each one of them gets data
          app1.db.collection.add(embeddings=[[0, 0, 0]], ids=["1"])
          app2.db.collection.add(embeddings=[[0, 0, 0]], ids=["2"])
          app3.db.collection.add(embeddings=[[0, 0, 0]], ids=["3"])
          app4.db.collection.add(embeddings=[[0, 0, 0]], ids=["4"])

          # Resetting the first one should reset them all.
          app1.reset()

          # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
          # Each app is recreated with the same id it was given above.
          app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
          app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
          app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))

          # All should be empty
          self.assertEqual(app1.count(), 0)
          self.assertEqual(app2.count(), 0)
          self.assertEqual(app3.count(), 0)
          self.assertEqual(app4.count(), 0)