# test_chroma_db.py

  1. import os
  2. import shutil
  3. from unittest.mock import patch
  4. import pytest
  5. from chromadb.config import Settings
  6. from embedchain import App
  7. from embedchain.config import AppConfig, ChromaDbConfig
  8. from embedchain.vectordb.chroma import ChromaDB
  9. os.environ["OPENAI_API_KEY"] = "test-api-key"
  10. @pytest.fixture
  11. def chroma_db():
  12. return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
  13. @pytest.fixture
  14. def app_with_settings():
  15. chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
  16. chroma_db = ChromaDB(config=chroma_config)
  17. app_config = AppConfig(collect_metrics=False)
  18. return App(config=app_config, db=chroma_db)
  19. @pytest.fixture(scope="session", autouse=True)
  20. def cleanup_db():
  21. yield
  22. try:
  23. shutil.rmtree("test-db")
  24. except OSError as e:
  25. print("Error: %s - %s." % (e.filename, e.strerror))
  26. @pytest.mark.skip(reason="ChromaDB client needs to be mocked")
  27. def test_chroma_db_init_with_host_and_port(chroma_db):
  28. settings = chroma_db.client.get_settings()
  29. assert settings.chroma_server_host == "test-host"
  30. assert settings.chroma_server_http_port == "1234"
  31. @pytest.mark.skip(reason="ChromaDB client needs to be mocked")
  32. def test_chroma_db_init_with_basic_auth():
  33. chroma_config = {
  34. "host": "test-host",
  35. "port": "1234",
  36. "chroma_settings": {
  37. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  38. "chroma_client_auth_credentials": "admin:admin",
  39. },
  40. }
  41. db = ChromaDB(config=ChromaDbConfig(**chroma_config))
  42. settings = db.client.get_settings()
  43. assert settings.chroma_server_host == "test-host"
  44. assert settings.chroma_server_http_port == "1234"
  45. assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
  46. assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
  47. @patch("embedchain.vectordb.chroma.chromadb.Client")
  48. def test_app_init_with_host_and_port(mock_client):
  49. host = "test-host"
  50. port = "1234"
  51. config = AppConfig(collect_metrics=False)
  52. db_config = ChromaDbConfig(host=host, port=port)
  53. db = ChromaDB(config=db_config)
  54. _app = App(config=config, db=db)
  55. called_settings: Settings = mock_client.call_args[0][0]
  56. assert called_settings.chroma_server_host == host
  57. assert called_settings.chroma_server_http_port == port
  58. @patch("embedchain.vectordb.chroma.chromadb.Client")
  59. def test_app_init_with_host_and_port_none(mock_client):
  60. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  61. _app = App(config=AppConfig(collect_metrics=False), db=db)
  62. called_settings: Settings = mock_client.call_args[0][0]
  63. assert called_settings.chroma_server_host is None
  64. assert called_settings.chroma_server_http_port is None
  65. @pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work")
  66. def test_chroma_db_duplicates_throw_warning(caplog):
  67. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  68. app = App(config=AppConfig(collect_metrics=False), db=db)
  69. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  70. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  71. assert "Insert of existing embedding ID: 0" in caplog.text
  72. assert "Add of existing embedding ID: 0" in caplog.text
  73. app.db.reset()
  74. def test_chroma_db_duplicates_collections_no_warning(caplog):
  75. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  76. app = App(config=AppConfig(collect_metrics=False), db=db)
  77. app.set_collection_name("test_collection_1")
  78. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  79. app.set_collection_name("test_collection_2")
  80. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  81. assert "Insert of existing embedding ID: 0" not in caplog.text
  82. assert "Add of existing embedding ID: 0" not in caplog.text
  83. app.db.reset()
  84. app.set_collection_name("test_collection_1")
  85. app.db.reset()
  86. def test_chroma_db_collection_init_with_default_collection():
  87. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  88. app = App(config=AppConfig(collect_metrics=False), db=db)
  89. assert app.db.collection.name == "embedchain_store"
  90. def test_chroma_db_collection_init_with_custom_collection():
  91. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  92. app = App(config=AppConfig(collect_metrics=False), db=db)
  93. app.set_collection_name(name="test_collection")
  94. assert app.db.collection.name == "test_collection"
  95. def test_chroma_db_collection_set_collection_name():
  96. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  97. app = App(config=AppConfig(collect_metrics=False), db=db)
  98. app.set_collection_name("test_collection")
  99. assert app.db.collection.name == "test_collection"
  100. def test_chroma_db_collection_changes_encapsulated():
  101. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  102. app = App(config=AppConfig(collect_metrics=False), db=db)
  103. app.set_collection_name("test_collection_1")
  104. assert app.db.count() == 0
  105. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  106. assert app.db.count() == 1
  107. app.set_collection_name("test_collection_2")
  108. assert app.db.count() == 0
  109. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  110. app.set_collection_name("test_collection_1")
  111. assert app.db.count() == 1
  112. app.db.reset()
  113. app.set_collection_name("test_collection_2")
  114. app.db.reset()
  115. def test_chroma_db_collection_collections_are_persistent():
  116. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  117. app = App(config=AppConfig(collect_metrics=False), db=db)
  118. app.set_collection_name("test_collection_1")
  119. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  120. del app
  121. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  122. app = App(config=AppConfig(collect_metrics=False), db=db)
  123. app.set_collection_name("test_collection_1")
  124. assert app.db.count() == 1
  125. app.db.reset()
  126. def test_chroma_db_collection_parallel_collections():
  127. db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
  128. app1 = App(
  129. config=AppConfig(collect_metrics=False),
  130. db=db1,
  131. )
  132. db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
  133. app2 = App(
  134. config=AppConfig(collect_metrics=False),
  135. db=db2,
  136. )
  137. # cleanup if any previous tests failed or were interrupted
  138. app1.db.reset()
  139. app2.db.reset()
  140. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  141. assert app1.db.count() == 1
  142. assert app2.db.count() == 0
  143. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  144. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  145. app1.set_collection_name("test_collection_2")
  146. assert app1.db.count() == 1
  147. app2.set_collection_name("test_collection_1")
  148. assert app2.db.count() == 3
  149. # cleanup
  150. app1.db.reset()
  151. app2.db.reset()
  152. def test_chroma_db_collection_ids_share_collections():
  153. db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  154. app1 = App(config=AppConfig(collect_metrics=False), db=db1)
  155. app1.set_collection_name("one_collection")
  156. db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  157. app2 = App(config=AppConfig(collect_metrics=False), db=db2)
  158. app2.set_collection_name("one_collection")
  159. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  160. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  161. assert app1.db.count() == 3
  162. assert app2.db.count() == 3
  163. # cleanup
  164. app1.db.reset()
  165. app2.db.reset()
  166. def test_chroma_db_collection_reset():
  167. db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  168. app1 = App(config=AppConfig(collect_metrics=False), db=db1)
  169. app1.set_collection_name("one_collection")
  170. db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  171. app2 = App(config=AppConfig(collect_metrics=False), db=db2)
  172. app2.set_collection_name("two_collection")
  173. db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  174. app3 = App(config=AppConfig(collect_metrics=False), db=db3)
  175. app3.set_collection_name("three_collection")
  176. db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  177. app4 = App(config=AppConfig(collect_metrics=False), db=db4)
  178. app4.set_collection_name("four_collection")
  179. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  180. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  181. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  182. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  183. app1.db.reset()
  184. assert app1.db.count() == 0
  185. assert app2.db.count() == 1
  186. assert app3.db.count() == 1
  187. assert app4.db.count() == 1
  188. # cleanup
  189. app2.db.reset()
  190. app3.db.reset()
  191. app4.db.reset()