test_chroma_db.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import os
  2. import shutil
  3. from unittest.mock import patch
  4. import pytest
  5. from chromadb.config import Settings
  6. from embedchain import App
  7. from embedchain.config import AppConfig, ChromaDbConfig
  8. from embedchain.vectordb.chroma import ChromaDB
  9. os.environ["OPENAI_API_KEY"] = "test-api-key"
  10. @pytest.fixture
  11. def chroma_db():
  12. return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
  13. @pytest.fixture
  14. def app_with_settings():
  15. chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
  16. chroma_db = ChromaDB(config=chroma_config)
  17. app_config = AppConfig(collect_metrics=False)
  18. return App(config=app_config, db=chroma_db)
  19. @pytest.fixture(scope="session", autouse=True)
  20. def cleanup_db():
  21. yield
  22. try:
  23. shutil.rmtree("test-db")
  24. except OSError as e:
  25. print("Error: %s - %s." % (e.filename, e.strerror))
  26. @patch("embedchain.vectordb.chroma.chromadb.Client")
  27. def test_chroma_db_init_with_host_and_port(mock_client):
  28. chroma_db = ChromaDB(config=ChromaDbConfig(host="test-host", port="1234")) # noqa
  29. called_settings: Settings = mock_client.call_args[0][0]
  30. assert called_settings.chroma_server_host == "test-host"
  31. assert called_settings.chroma_server_http_port == "1234"
  32. @patch("embedchain.vectordb.chroma.chromadb.Client")
  33. def test_chroma_db_init_with_basic_auth(mock_client):
  34. chroma_config = {
  35. "host": "test-host",
  36. "port": "1234",
  37. "chroma_settings": {
  38. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  39. "chroma_client_auth_credentials": "admin:admin",
  40. },
  41. }
  42. ChromaDB(config=ChromaDbConfig(**chroma_config))
  43. called_settings: Settings = mock_client.call_args[0][0]
  44. assert called_settings.chroma_server_host == "test-host"
  45. assert called_settings.chroma_server_http_port == "1234"
  46. assert (
  47. called_settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
  48. )
  49. assert (
  50. called_settings.chroma_client_auth_credentials
  51. == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
  52. )
  53. @patch("embedchain.vectordb.chroma.chromadb.Client")
  54. def test_app_init_with_host_and_port(mock_client):
  55. host = "test-host"
  56. port = "1234"
  57. config = AppConfig(collect_metrics=False)
  58. db_config = ChromaDbConfig(host=host, port=port)
  59. db = ChromaDB(config=db_config)
  60. _app = App(config=config, db=db)
  61. called_settings: Settings = mock_client.call_args[0][0]
  62. assert called_settings.chroma_server_host == host
  63. assert called_settings.chroma_server_http_port == port
  64. @patch("embedchain.vectordb.chroma.chromadb.Client")
  65. def test_app_init_with_host_and_port_none(mock_client):
  66. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  67. _app = App(config=AppConfig(collect_metrics=False), db=db)
  68. called_settings: Settings = mock_client.call_args[0][0]
  69. assert called_settings.chroma_server_host is None
  70. assert called_settings.chroma_server_http_port is None
  71. def test_chroma_db_duplicates_throw_warning(caplog):
  72. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  73. app = App(config=AppConfig(collect_metrics=False), db=db)
  74. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  75. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  76. assert "Insert of existing embedding ID: 0" in caplog.text
  77. assert "Add of existing embedding ID: 0" in caplog.text
  78. app.db.reset()
  79. def test_chroma_db_duplicates_collections_no_warning(caplog):
  80. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  81. app = App(config=AppConfig(collect_metrics=False), db=db)
  82. app.set_collection_name("test_collection_1")
  83. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  84. app.set_collection_name("test_collection_2")
  85. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  86. assert "Insert of existing embedding ID: 0" not in caplog.text
  87. assert "Add of existing embedding ID: 0" not in caplog.text
  88. app.db.reset()
  89. app.set_collection_name("test_collection_1")
  90. app.db.reset()
  91. def test_chroma_db_collection_init_with_default_collection():
  92. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  93. app = App(config=AppConfig(collect_metrics=False), db=db)
  94. assert app.db.collection.name == "embedchain_store"
  95. def test_chroma_db_collection_init_with_custom_collection():
  96. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  97. app = App(config=AppConfig(collect_metrics=False), db=db)
  98. app.set_collection_name(name="test_collection")
  99. assert app.db.collection.name == "test_collection"
  100. def test_chroma_db_collection_set_collection_name():
  101. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  102. app = App(config=AppConfig(collect_metrics=False), db=db)
  103. app.set_collection_name("test_collection")
  104. assert app.db.collection.name == "test_collection"
  105. def test_chroma_db_collection_changes_encapsulated():
  106. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  107. app = App(config=AppConfig(collect_metrics=False), db=db)
  108. app.set_collection_name("test_collection_1")
  109. assert app.db.count() == 0
  110. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  111. assert app.db.count() == 1
  112. app.set_collection_name("test_collection_2")
  113. assert app.db.count() == 0
  114. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  115. app.set_collection_name("test_collection_1")
  116. assert app.db.count() == 1
  117. app.db.reset()
  118. app.set_collection_name("test_collection_2")
  119. app.db.reset()
  120. def test_chroma_db_collection_collections_are_persistent():
  121. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  122. app = App(config=AppConfig(collect_metrics=False), db=db)
  123. app.set_collection_name("test_collection_1")
  124. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  125. del app
  126. db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  127. app = App(config=AppConfig(collect_metrics=False), db=db)
  128. app.set_collection_name("test_collection_1")
  129. assert app.db.count() == 1
  130. app.db.reset()
  131. def test_chroma_db_collection_parallel_collections():
  132. db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
  133. app1 = App(
  134. config=AppConfig(collect_metrics=False),
  135. db=db1,
  136. )
  137. db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
  138. app2 = App(
  139. config=AppConfig(collect_metrics=False),
  140. db=db2,
  141. )
  142. # cleanup if any previous tests failed or were interrupted
  143. app1.db.reset()
  144. app2.db.reset()
  145. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  146. assert app1.db.count() == 1
  147. assert app2.db.count() == 0
  148. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  149. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  150. app1.set_collection_name("test_collection_2")
  151. assert app1.db.count() == 1
  152. app2.set_collection_name("test_collection_1")
  153. assert app2.db.count() == 3
  154. # cleanup
  155. app1.db.reset()
  156. app2.db.reset()
  157. def test_chroma_db_collection_ids_share_collections():
  158. db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  159. app1 = App(config=AppConfig(collect_metrics=False), db=db1)
  160. app1.set_collection_name("one_collection")
  161. db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  162. app2 = App(config=AppConfig(collect_metrics=False), db=db2)
  163. app2.set_collection_name("one_collection")
  164. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  165. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  166. assert app1.db.count() == 3
  167. assert app2.db.count() == 3
  168. # cleanup
  169. app1.db.reset()
  170. app2.db.reset()
  171. def test_chroma_db_collection_reset():
  172. db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  173. app1 = App(config=AppConfig(collect_metrics=False), db=db1)
  174. app1.set_collection_name("one_collection")
  175. db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  176. app2 = App(config=AppConfig(collect_metrics=False), db=db2)
  177. app2.set_collection_name("two_collection")
  178. db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  179. app3 = App(config=AppConfig(collect_metrics=False), db=db3)
  180. app3.set_collection_name("three_collection")
  181. db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  182. app4 = App(config=AppConfig(collect_metrics=False), db=db4)
  183. app4.set_collection_name("four_collection")
  184. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  185. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  186. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  187. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  188. app1.db.reset()
  189. assert app1.db.count() == 0
  190. assert app2.db.count() == 1
  191. assert app3.db.count() == 1
  192. assert app4.db.count() == 1
  193. # cleanup
  194. app2.db.reset()
  195. app3.db.reset()
  196. app4.db.reset()