test_chroma_db.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig, CustomAppConfig
  6. from embedchain.models import EmbeddingFunctions, Providers
  7. from embedchain.vectordb.chroma_db import ChromaDB
  8. class TestChromaDbHosts(unittest.TestCase):
  9. def test_init_with_host_and_port(self):
  10. """
  11. Test if the `ChromaDB` instance is initialized with the correct host and port values.
  12. """
  13. host = "test-host"
  14. port = "1234"
  15. db = ChromaDB(host=host, port=port, embedding_fn=len)
  16. settings = db.client.get_settings()
  17. self.assertEqual(settings.chroma_server_host, host)
  18. self.assertEqual(settings.chroma_server_http_port, port)
  19. def test_init_with_basic_auth(self):
  20. host = "test-host"
  21. port = "1234"
  22. chroma_auth_settings = {
  23. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  24. "chroma_client_auth_credentials": "admin:admin",
  25. }
  26. db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
  27. settings = db.client.get_settings()
  28. self.assertEqual(settings.chroma_server_host, host)
  29. self.assertEqual(settings.chroma_server_http_port, port)
  30. self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
  31. self.assertEqual(
  32. settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
  33. )
  34. # Review this test
  35. class TestChromaDbHostsInit(unittest.TestCase):
  36. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  37. def test_init_with_host_and_port(self, mock_client):
  38. """
  39. Test if the `App` instance is initialized with the correct host and port values.
  40. """
  41. host = "test-host"
  42. port = "1234"
  43. config = AppConfig(host=host, port=port, collect_metrics=False)
  44. _app = App(config)
  45. # self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
  46. # self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
  47. class TestChromaDbHostsNone(unittest.TestCase):
  48. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  49. def test_init_with_host_and_port(self, mock_client):
  50. """
  51. Test if the `App` instance is initialized without default hosts and ports.
  52. """
  53. _app = App(config=AppConfig(collect_metrics=False))
  54. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  55. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  56. class TestChromaDbHostsLoglevel(unittest.TestCase):
  57. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  58. def test_init_with_host_and_port(self, mock_client):
  59. """
  60. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
  61. """
  62. config = AppConfig(log_level="DEBUG")
  63. _app = App(config=AppConfig(collect_metrics=False))
  64. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  65. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  66. class TestChromaDbDuplicateHandling:
  67. app_with_settings = App(
  68. CustomAppConfig(
  69. provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
  70. )
  71. )
  72. def test_duplicates_throw_warning(self, caplog):
  73. """
  74. Test that add duplicates throws an error.
  75. """
  76. # Start with a clean app
  77. self.app_with_settings.reset()
  78. app = App(config=AppConfig(collect_metrics=False))
  79. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  80. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  81. assert "Insert of existing embedding ID: 0" in caplog.text
  82. assert "Add of existing embedding ID: 0" in caplog.text
  83. def test_duplicates_collections_no_warning(self, caplog):
  84. """
  85. Test that different collections can have duplicates.
  86. """
  87. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
  88. # Start with a clean app
  89. self.app_with_settings.reset()
  90. app = App(config=AppConfig(collect_metrics=False))
  91. app.set_collection("test_collection_1")
  92. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  93. app.set_collection("test_collection_2")
  94. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  95. assert "Insert of existing embedding ID: 0" not in caplog.text # not
  96. assert "Add of existing embedding ID: 0" not in caplog.text # not
  97. class TestChromaDbCollection(unittest.TestCase):
  98. app_with_settings = App(
  99. CustomAppConfig(
  100. provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
  101. )
  102. )
  103. def test_init_with_default_collection(self):
  104. """
  105. Test if the `App` instance is initialized with the correct default collection name.
  106. """
  107. app = App(config=AppConfig(collect_metrics=False))
  108. self.assertEqual(app.collection.name, "embedchain_store")
  109. def test_init_with_custom_collection(self):
  110. """
  111. Test if the `App` instance is initialized with the correct custom collection name.
  112. """
  113. config = AppConfig(collection_name="test_collection", collect_metrics=False)
  114. app = App(config)
  115. self.assertEqual(app.collection.name, "test_collection")
  116. def test_set_collection(self):
  117. """
  118. Test if the `App` collection is correctly switched using the `set_collection` method.
  119. """
  120. app = App(config=AppConfig(collect_metrics=False))
  121. app.set_collection("test_collection")
  122. self.assertEqual(app.collection.name, "test_collection")
  123. def test_changes_encapsulated(self):
  124. """
  125. Test that changes to one collection do not affect the other collection
  126. """
  127. # Start with a clean app
  128. self.app_with_settings.reset()
  129. app = App(config=AppConfig(collect_metrics=False))
  130. app.set_collection("test_collection_1")
  131. # Collection should be empty when created
  132. self.assertEqual(app.count(), 0)
  133. app.collection.add(embeddings=[0, 0, 0], ids=["0"])
  134. # After adding, should contain one item
  135. self.assertEqual(app.count(), 1)
  136. app.set_collection("test_collection_2")
  137. # New collection is empty
  138. self.assertEqual(app.count(), 0)
  139. # Adding to new collection should not effect existing collection
  140. app.collection.add(embeddings=[0, 0, 0], ids=["0"])
  141. app.set_collection("test_collection_1")
  142. # Should still be 1, not 2.
  143. self.assertEqual(app.count(), 1)
  144. def test_collections_are_persistent(self):
  145. """
  146. Test that a collection can be picked up later.
  147. """
  148. # Start with a clean app
  149. self.app_with_settings.reset()
  150. app = App(config=AppConfig(collect_metrics=False))
  151. app.set_collection("test_collection_1")
  152. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  153. del app
  154. app = App(config=AppConfig(collect_metrics=False))
  155. app.set_collection("test_collection_1")
  156. self.assertEqual(app.count(), 1)
  157. def test_parallel_collections(self):
  158. """
  159. Test that two apps can have different collections open in parallel.
  160. Switching the names will allow instant access to the collection of
  161. the other app.
  162. """
  163. # Start clean
  164. self.app_with_settings.reset()
  165. # Create two apps
  166. app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
  167. app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
  168. # app2 has been created last, but adding to app1 will still write to collection 1.
  169. app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
  170. self.assertEqual(app1.count(), 1)
  171. self.assertEqual(app2.count(), 0)
  172. # Add data
  173. app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  174. app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
  175. # Swap names and test
  176. app1.set_collection("test_collection_2")
  177. self.assertEqual(app1.count(), 1)
  178. app2.set_collection("test_collection_1")
  179. self.assertEqual(app2.count(), 3)
  180. def test_ids_share_collections(self):
  181. """
  182. Different ids should still share collections.
  183. """
  184. # Start clean
  185. self.app_with_settings.reset()
  186. # Create two apps
  187. app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
  188. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  189. # Add data
  190. app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  191. app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
  192. # Both should have the same collection
  193. self.assertEqual(app1.count(), 3)
  194. self.assertEqual(app2.count(), 3)
  195. def test_reset(self):
  196. """
  197. Resetting should hit all collections and ids.
  198. """
  199. # Start clean
  200. self.app_with_settings.reset()
  201. # Create four apps.
  202. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
  203. app1 = App(
  204. CustomAppConfig(
  205. collection_name="one_collection",
  206. id="new_app_id_1",
  207. collect_metrics=False,
  208. provider=Providers.OPENAI,
  209. embedding_fn=EmbeddingFunctions.OPENAI,
  210. chroma_settings={"allow_reset": True},
  211. )
  212. )
  213. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  214. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
  215. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
  216. # Each one of them get data
  217. app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
  218. app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
  219. app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
  220. app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
  221. # Resetting the first one should reset them all.
  222. app1.reset()
  223. # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
  224. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  225. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
  226. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
  227. # All should be empty
  228. self.assertEqual(app1.count(), 0)
  229. self.assertEqual(app2.count(), 0)
  230. self.assertEqual(app3.count(), 0)
  231. self.assertEqual(app4.count(), 0)