test_chroma_db.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from chromadb.config import Settings
  5. from embedchain import App
  6. from embedchain.config import AppConfig, ChromaDbConfig
  7. from embedchain.vectordb.chroma import ChromaDB
  8. class TestChromaDbHosts(unittest.TestCase):
  9. def test_init_with_host_and_port(self):
  10. """
  11. Test if the `ChromaDB` instance is initialized with the correct host and port values.
  12. """
  13. host = "test-host"
  14. port = "1234"
  15. config = ChromaDbConfig(host=host, port=port)
  16. db = ChromaDB(config=config)
  17. settings = db.client.get_settings()
  18. self.assertEqual(settings.chroma_server_host, host)
  19. self.assertEqual(settings.chroma_server_http_port, port)
  20. def test_init_with_basic_auth(self):
  21. host = "test-host"
  22. port = "1234"
  23. chroma_auth_settings = {
  24. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  25. "chroma_client_auth_credentials": "admin:admin",
  26. }
  27. config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
  28. db = ChromaDB(config=config)
  29. settings = db.client.get_settings()
  30. self.assertEqual(settings.chroma_server_host, host)
  31. self.assertEqual(settings.chroma_server_http_port, port)
  32. self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
  33. self.assertEqual(
  34. settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
  35. )
  36. # Review this test
  37. class TestChromaDbHostsInit(unittest.TestCase):
  38. @patch("embedchain.vectordb.chroma.chromadb.Client")
  39. def test_app_init_with_host_and_port(self, mock_client):
  40. """
  41. Test if the `App` instance is initialized with the correct host and port values.
  42. """
  43. host = "test-host"
  44. port = "1234"
  45. config = AppConfig(collect_metrics=False)
  46. chromadb_config = ChromaDbConfig(host=host, port=port)
  47. _app = App(config, chromadb_config=chromadb_config)
  48. called_settings: Settings = mock_client.call_args[0][0]
  49. self.assertEqual(called_settings.chroma_server_host, host)
  50. self.assertEqual(called_settings.chroma_server_http_port, port)
  51. class TestChromaDbHostsNone(unittest.TestCase):
  52. @patch("embedchain.vectordb.chroma.chromadb.Client")
  53. def test_init_with_host_and_port_none(self, mock_client):
  54. """
  55. Test if the `App` instance is initialized without default hosts and ports.
  56. """
  57. _app = App(config=AppConfig(collect_metrics=False))
  58. called_settings: Settings = mock_client.call_args[0][0]
  59. self.assertEqual(called_settings.chroma_server_host, None)
  60. self.assertEqual(called_settings.chroma_server_http_port, None)
  61. class TestChromaDbHostsLoglevel(unittest.TestCase):
  62. @patch("embedchain.vectordb.chroma.chromadb.Client")
  63. def test_init_with_host_and_port_log_level(self, mock_client):
  64. """
  65. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
  66. """
  67. _app = App(config=AppConfig(collect_metrics=False))
  68. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  69. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  70. class TestChromaDbDuplicateHandling:
  71. chroma_config = ChromaDbConfig(allow_reset=True)
  72. app_config = AppConfig(collection_name=False, collect_metrics=False)
  73. app_with_settings = App(config=app_config, chromadb_config=chroma_config)
  74. def test_duplicates_throw_warning(self, caplog):
  75. """
  76. Test that add duplicates throws an error.
  77. """
  78. # Start with a clean app
  79. self.app_with_settings.reset()
  80. app = App(config=AppConfig(collect_metrics=False))
  81. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  82. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  83. assert "Insert of existing embedding ID: 0" in caplog.text
  84. assert "Add of existing embedding ID: 0" in caplog.text
  85. def test_duplicates_collections_no_warning(self, caplog):
  86. """
  87. Test that different collections can have duplicates.
  88. """
  89. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
  90. # Start with a clean app
  91. self.app_with_settings.reset()
  92. app = App(config=AppConfig(collect_metrics=False))
  93. app.set_collection_name("test_collection_1")
  94. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  95. app.set_collection_name("test_collection_2")
  96. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  97. assert "Insert of existing embedding ID: 0" not in caplog.text # not
  98. assert "Add of existing embedding ID: 0" not in caplog.text # not
  99. class TestChromaDbCollection(unittest.TestCase):
  100. chroma_config = ChromaDbConfig(allow_reset=True)
  101. app_config = AppConfig(collection_name=False, collect_metrics=False)
  102. app_with_settings = App(config=app_config, chromadb_config=chroma_config)
  103. def test_init_with_default_collection(self):
  104. """
  105. Test if the `App` instance is initialized with the correct default collection name.
  106. """
  107. app = App(config=AppConfig(collect_metrics=False))
  108. self.assertEqual(app.db.collection.name, "embedchain_store")
  109. def test_init_with_custom_collection(self):
  110. """
  111. Test if the `App` instance is initialized with the correct custom collection name.
  112. """
  113. config = AppConfig(collect_metrics=False)
  114. app = App(config=config)
  115. app.set_collection_name(name="test_collection")
  116. self.assertEqual(app.db.collection.name, "test_collection")
  117. def test_set_collection_name(self):
  118. """
  119. Test if the `App` collection is correctly switched using the `set_collection_name` method.
  120. """
  121. app = App(config=AppConfig(collect_metrics=False))
  122. app.set_collection_name("test_collection")
  123. self.assertEqual(app.db.collection.name, "test_collection")
  124. def test_changes_encapsulated(self):
  125. """
  126. Test that changes to one collection do not affect the other collection
  127. """
  128. # Start with a clean app
  129. self.app_with_settings.reset()
  130. app = App(config=AppConfig(collect_metrics=False))
  131. app.set_collection_name("test_collection_1")
  132. # Collection should be empty when created
  133. self.assertEqual(app.db.count(), 0)
  134. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  135. # After adding, should contain one item
  136. self.assertEqual(app.db.count(), 1)
  137. app.set_collection_name("test_collection_2")
  138. # New collection is empty
  139. self.assertEqual(app.db.count(), 0)
  140. # Adding to new collection should not effect existing collection
  141. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  142. app.set_collection_name("test_collection_1")
  143. # Should still be 1, not 2.
  144. self.assertEqual(app.db.count(), 1)
  145. def test_add_with_skip_embedding(self):
  146. """
  147. Test that changes to one collection do not affect the other collection
  148. """
  149. # Start with a clean app
  150. self.app_with_settings.reset()
  151. # app = App(config=AppConfig(collect_metrics=False), db=db)
  152. # Collection should be empty when created
  153. self.assertEqual(self.app_with_settings.db.count(), 0)
  154. self.app_with_settings.db.add(
  155. embeddings=[[0, 0, 0]],
  156. documents=["document"],
  157. metadatas=[{"value": "somevalue"}],
  158. ids=["id"],
  159. skip_embedding=True,
  160. )
  161. # After adding, should contain one item
  162. self.assertEqual(self.app_with_settings.db.count(), 1)
  163. # Validate if the get utility of the database is working as expected
  164. data = self.app_with_settings.db.get(["id"], limit=1)
  165. expected_value = {
  166. "documents": ["document"],
  167. "embeddings": None,
  168. "ids": ["id"],
  169. "metadatas": [{"value": "somevalue"}],
  170. }
  171. self.assertEqual(data, expected_value)
  172. # Validate if the query utility of the database is working as expected
  173. data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
  174. expected_value = ["document"]
  175. self.assertEqual(data, expected_value)
  176. def test_collections_are_persistent(self):
  177. """
  178. Test that a collection can be picked up later.
  179. """
  180. # Start with a clean app
  181. self.app_with_settings.reset()
  182. app = App(config=AppConfig(collect_metrics=False))
  183. app.set_collection_name("test_collection_1")
  184. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  185. del app
  186. app = App(config=AppConfig(collect_metrics=False))
  187. app.set_collection_name("test_collection_1")
  188. self.assertEqual(app.db.count(), 1)
  189. def test_parallel_collections(self):
  190. """
  191. Test that two apps can have different collections open in parallel.
  192. Switching the names will allow instant access to the collection of
  193. the other app.
  194. """
  195. # Start clean
  196. self.app_with_settings.reset()
  197. # Create two apps
  198. app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
  199. app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
  200. # app2 has been created last, but adding to app1 will still write to collection 1.
  201. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  202. self.assertEqual(app1.db.count(), 1)
  203. self.assertEqual(app2.db.count(), 0)
  204. # Add data
  205. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  206. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  207. # Swap names and test
  208. app1.set_collection_name("test_collection_2")
  209. self.assertEqual(app1.count(), 1)
  210. app2.set_collection_name("test_collection_1")
  211. self.assertEqual(app2.count(), 3)
  212. def test_ids_share_collections(self):
  213. """
  214. Different ids should still share collections.
  215. """
  216. # Start clean
  217. self.app_with_settings.reset()
  218. # Create two apps
  219. app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
  220. app1.set_collection_name("one_collection")
  221. app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
  222. app2.set_collection_name("one_collection")
  223. # Add data
  224. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  225. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  226. # Both should have the same collection
  227. self.assertEqual(app1.count(), 3)
  228. self.assertEqual(app2.count(), 3)
  229. def test_reset(self):
  230. """
  231. Resetting should hit all collections and ids.
  232. """
  233. # Start clean
  234. self.app_with_settings.reset()
  235. # Create four apps.
  236. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
  237. app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
  238. app1.set_collection_name("one_collection")
  239. app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
  240. app2.set_collection_name("one_collection")
  241. app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
  242. app3.set_collection_name("three_collection")
  243. app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
  244. app4.set_collection_name("four_collection")
  245. # Each one of them get data
  246. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  247. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  248. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  249. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  250. # Resetting the first one should reset them all.
  251. app1.reset()
  252. # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
  253. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  254. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
  255. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
  256. # All should be empty
  257. self.assertEqual(app1.count(), 0)
  258. self.assertEqual(app2.count(), 0)
  259. self.assertEqual(app3.count(), 0)
  260. self.assertEqual(app4.count(), 0)