test_chroma_db.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # ruff: noqa: E501
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig
  6. from embedchain.vectordb.chroma_db import ChromaDB
  7. class TestChromaDbHosts(unittest.TestCase):
  8. def test_init_with_host_and_port(self):
  9. """
  10. Test if the `ChromaDB` instance is initialized with the correct host and port values.
  11. """
  12. host = "test-host"
  13. port = "1234"
  14. db = ChromaDB(host=host, port=port, embedding_fn=len)
  15. settings = db.client.get_settings()
  16. self.assertEqual(settings.chroma_server_host, host)
  17. self.assertEqual(settings.chroma_server_http_port, port)
  18. # Review this test
  19. class TestChromaDbHostsInit(unittest.TestCase):
  20. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  21. def test_init_with_host_and_port(self, mock_client):
  22. """
  23. Test if the `App` instance is initialized with the correct host and port values.
  24. """
  25. host = "test-host"
  26. port = "1234"
  27. config = AppConfig(host=host, port=port, collect_metrics=False)
  28. _app = App(config)
  29. # self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
  30. # self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
  31. class TestChromaDbHostsNone(unittest.TestCase):
  32. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  33. def test_init_with_host_and_port(self, mock_client):
  34. """
  35. Test if the `App` instance is initialized without default hosts and ports.
  36. """
  37. _app = App(config=AppConfig(collect_metrics=False))
  38. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  39. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  40. class TestChromaDbHostsLoglevel(unittest.TestCase):
  41. @patch("embedchain.vectordb.chroma_db.chromadb.Client")
  42. def test_init_with_host_and_port(self, mock_client):
  43. """
  44. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
  45. """
  46. config = AppConfig(log_level="DEBUG")
  47. _app = App(config=AppConfig(collect_metrics=False))
  48. self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
  49. self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
  50. class TestChromaDbDuplicateHandling:
  51. def test_duplicates_throw_warning(self, caplog):
  52. """
  53. Test that add duplicates throws an error.
  54. """
  55. # Start with a clean app
  56. App().reset()
  57. app = App(config=AppConfig(collect_metrics=False))
  58. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  59. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  60. assert "Insert of existing embedding ID: 0" in caplog.text
  61. assert "Add of existing embedding ID: 0" in caplog.text
  62. def test_duplicates_collections_no_warning(self, caplog):
  63. """
  64. Test that different collections can have duplicates.
  65. """
  66. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
  67. # Start with a clean app
  68. App().reset()
  69. app = App(config=AppConfig(collect_metrics=False))
  70. app.set_collection("test_collection_1")
  71. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  72. app.set_collection("test_collection_2")
  73. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  74. assert "Insert of existing embedding ID: 0" not in caplog.text # not
  75. assert "Add of existing embedding ID: 0" not in caplog.text # not
  76. class TestChromaDbCollection(unittest.TestCase):
  77. def test_init_with_default_collection(self):
  78. """
  79. Test if the `App` instance is initialized with the correct default collection name.
  80. """
  81. app = App(config=AppConfig(collect_metrics=False))
  82. self.assertEqual(app.collection.name, "embedchain_store")
  83. def test_init_with_custom_collection(self):
  84. """
  85. Test if the `App` instance is initialized with the correct custom collection name.
  86. """
  87. config = AppConfig(collection_name="test_collection", collect_metrics=False)
  88. app = App(config)
  89. self.assertEqual(app.collection.name, "test_collection")
  90. def test_set_collection(self):
  91. """
  92. Test if the `App` collection is correctly switched using the `set_collection` method.
  93. """
  94. app = App(config=AppConfig(collect_metrics=False))
  95. app.set_collection("test_collection")
  96. self.assertEqual(app.collection.name, "test_collection")
  97. def test_changes_encapsulated(self):
  98. """
  99. Test that changes to one collection do not affect the other collection
  100. """
  101. # Start with a clean app
  102. App().reset()
  103. app = App(config=AppConfig(collect_metrics=False))
  104. app.set_collection("test_collection_1")
  105. # Collection should be empty when created
  106. self.assertEqual(app.count(), 0)
  107. app.collection.add(embeddings=[0, 0, 0], ids=["0"])
  108. # After adding, should contain one item
  109. self.assertEqual(app.count(), 1)
  110. app.set_collection("test_collection_2")
  111. # New collection is empty
  112. self.assertEqual(app.count(), 0)
  113. # Adding to new collection should not effect existing collection
  114. app.collection.add(embeddings=[0, 0, 0], ids=["0"])
  115. app.set_collection("test_collection_1")
  116. # Should still be 1, not 2.
  117. self.assertEqual(app.count(), 1)
  118. def test_collections_are_persistent(self):
  119. """
  120. Test that a collection can be picked up later.
  121. """
  122. # Start with a clean app
  123. App().reset()
  124. app = App(config=AppConfig(collect_metrics=False))
  125. app.set_collection("test_collection_1")
  126. app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  127. del app
  128. app = App(config=AppConfig(collect_metrics=False))
  129. app.set_collection("test_collection_1")
  130. self.assertEqual(app.count(), 1)
  131. def test_parallel_collections(self):
  132. """
  133. Test that two apps can have different collections open in parallel.
  134. Switching the names will allow instant access to the collection of
  135. the other app.
  136. """
  137. # Start clean
  138. App().reset()
  139. # Create two apps
  140. app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
  141. app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
  142. # app2 has been created last, but adding to app1 will still write to collection 1.
  143. app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
  144. self.assertEqual(app1.count(), 1)
  145. self.assertEqual(app2.count(), 0)
  146. # Add data
  147. app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  148. app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
  149. # Swap names and test
  150. app1.set_collection("test_collection_2")
  151. self.assertEqual(app1.count(), 1)
  152. app2.set_collection("test_collection_1")
  153. self.assertEqual(app2.count(), 3)
  154. def test_ids_share_collections(self):
  155. """
  156. Different ids should still share collections.
  157. """
  158. # Start clean
  159. App().reset()
  160. # Create two apps
  161. app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
  162. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  163. # Add data
  164. app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  165. app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
  166. # Both should have the same collection
  167. self.assertEqual(app1.count(), 3)
  168. self.assertEqual(app2.count(), 3)
  169. def test_reset(self):
  170. """
  171. Resetting should hit all collections and ids.
  172. """
  173. # Start clean
  174. App().reset()
  175. # Create four apps.
  176. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
  177. app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
  178. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  179. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
  180. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
  181. # Each one of them get data
  182. app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
  183. app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
  184. app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
  185. app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
  186. # Resetting the first one should reset them all.
  187. app1.reset()
  188. # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
  189. app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
  190. app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
  191. app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
  192. # All should be empty
  193. self.assertEqual(app1.count(), 0)
  194. self.assertEqual(app2.count(), 0)
  195. self.assertEqual(app3.count(), 0)
  196. self.assertEqual(app4.count(), 0)