test_chroma_db.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. import os
  2. import shutil
  3. from unittest.mock import patch
  4. import pytest
  5. from chromadb.config import Settings
  6. from embedchain import App
  7. from embedchain.config import AppConfig, ChromaDbConfig
  8. from embedchain.vectordb.chroma import ChromaDB
  9. os.environ["OPENAI_API_KEY"] = "test-api-key"
  10. @pytest.fixture
  11. def chroma_db():
  12. return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
  13. @pytest.fixture
  14. def app_with_settings():
  15. chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
  16. app_config = AppConfig(collect_metrics=False)
  17. return App(config=app_config, db_config=chroma_config)
  18. @pytest.fixture(scope="session", autouse=True)
  19. def cleanup_db():
  20. yield
  21. try:
  22. shutil.rmtree("test-db")
  23. except OSError as e:
  24. print("Error: %s - %s." % (e.filename, e.strerror))
  25. @pytest.mark.skip(reason="ChromaDB client needs to be mocked")
  26. def test_chroma_db_init_with_host_and_port(chroma_db):
  27. settings = chroma_db.client.get_settings()
  28. assert settings.chroma_server_host == "test-host"
  29. assert settings.chroma_server_http_port == "1234"
  30. @pytest.mark.skip(reason="ChromaDB client needs to be mocked")
  31. def test_chroma_db_init_with_basic_auth():
  32. chroma_config = {
  33. "host": "test-host",
  34. "port": "1234",
  35. "chroma_settings": {
  36. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  37. "chroma_client_auth_credentials": "admin:admin",
  38. },
  39. }
  40. db = ChromaDB(config=ChromaDbConfig(**chroma_config))
  41. settings = db.client.get_settings()
  42. assert settings.chroma_server_host == "test-host"
  43. assert settings.chroma_server_http_port == "1234"
  44. assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
  45. assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
  46. @patch("embedchain.vectordb.chroma.chromadb.Client")
  47. def test_app_init_with_host_and_port(mock_client):
  48. host = "test-host"
  49. port = "1234"
  50. config = AppConfig(collect_metrics=False)
  51. db_config = ChromaDbConfig(host=host, port=port)
  52. _app = App(config, db_config=db_config)
  53. called_settings: Settings = mock_client.call_args[0][0]
  54. assert called_settings.chroma_server_host == host
  55. assert called_settings.chroma_server_http_port == port
  56. @patch("embedchain.vectordb.chroma.chromadb.Client")
  57. def test_app_init_with_host_and_port_none(mock_client):
  58. _app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  59. called_settings: Settings = mock_client.call_args[0][0]
  60. assert called_settings.chroma_server_host is None
  61. assert called_settings.chroma_server_http_port is None
  62. def test_chroma_db_duplicates_throw_warning(caplog):
  63. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  64. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  65. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  66. assert "Insert of existing embedding ID: 0" in caplog.text
  67. assert "Add of existing embedding ID: 0" in caplog.text
  68. app.db.reset()
  69. def test_chroma_db_duplicates_collections_no_warning(caplog):
  70. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  71. app.set_collection_name("test_collection_1")
  72. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  73. app.set_collection_name("test_collection_2")
  74. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  75. assert "Insert of existing embedding ID: 0" not in caplog.text
  76. assert "Add of existing embedding ID: 0" not in caplog.text
  77. app.db.reset()
  78. app.set_collection_name("test_collection_1")
  79. app.db.reset()
  80. def test_chroma_db_collection_init_with_default_collection():
  81. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  82. assert app.db.collection.name == "embedchain_store"
  83. def test_chroma_db_collection_init_with_custom_collection():
  84. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  85. app.set_collection_name(name="test_collection")
  86. assert app.db.collection.name == "test_collection"
  87. def test_chroma_db_collection_set_collection_name():
  88. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  89. app.set_collection_name("test_collection")
  90. assert app.db.collection.name == "test_collection"
  91. def test_chroma_db_collection_changes_encapsulated():
  92. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  93. app.set_collection_name("test_collection_1")
  94. assert app.db.count() == 0
  95. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  96. assert app.db.count() == 1
  97. app.set_collection_name("test_collection_2")
  98. assert app.db.count() == 0
  99. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  100. app.set_collection_name("test_collection_1")
  101. assert app.db.count() == 1
  102. app.db.reset()
  103. app.set_collection_name("test_collection_2")
  104. app.db.reset()
  105. def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
  106. # Start with a clean app
  107. app_with_settings.db.reset()
  108. assert app_with_settings.db.count() == 0
  109. app_with_settings.db.add(
  110. embeddings=[[0, 0, 0]],
  111. documents=["document"],
  112. metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
  113. ids=["id"],
  114. skip_embedding=True,
  115. )
  116. assert app_with_settings.db.count() == 1
  117. data = app_with_settings.db.get(["id"], limit=1)
  118. expected_value = {
  119. "documents": ["document"],
  120. "embeddings": None,
  121. "ids": ["id"],
  122. "metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
  123. "data": None,
  124. "uris": None,
  125. }
  126. assert data == expected_value
  127. data_without_citations = app_with_settings.db.query(
  128. input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
  129. )
  130. expected_value_without_citations = ["document"]
  131. assert data_without_citations == expected_value_without_citations
  132. app_with_settings.db.reset()
  133. def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
  134. # Start with a clean app
  135. app_with_settings.db.reset()
  136. assert app_with_settings.db.count() == 0
  137. with pytest.raises(ValueError):
  138. app_with_settings.db.add(
  139. embeddings=[[0, 0, 0]],
  140. documents=["document", "document2"],
  141. metadatas=[{"value": "somevalue"}],
  142. ids=["id"],
  143. skip_embedding=True,
  144. )
  145. assert app_with_settings.db.count() == 0
  146. with pytest.raises(ValueError):
  147. app_with_settings.db.add(
  148. embeddings=None,
  149. documents=["document", "document2"],
  150. metadatas=[{"value": "somevalue"}],
  151. ids=["id"],
  152. skip_embedding=True,
  153. )
  154. assert app_with_settings.db.count() == 0
  155. app_with_settings.db.reset()
  156. def test_chroma_db_collection_collections_are_persistent():
  157. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  158. app.set_collection_name("test_collection_1")
  159. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  160. del app
  161. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  162. app.set_collection_name("test_collection_1")
  163. assert app.db.count() == 1
  164. app.db.reset()
  165. def test_chroma_db_collection_parallel_collections():
  166. app1 = App(
  167. AppConfig(collection_name="test_collection_1", collect_metrics=False),
  168. db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
  169. )
  170. app2 = App(
  171. AppConfig(collection_name="test_collection_2", collect_metrics=False),
  172. db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
  173. )
  174. # cleanup if any previous tests failed or were interrupted
  175. app1.db.reset()
  176. app2.db.reset()
  177. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  178. assert app1.db.count() == 1
  179. assert app2.db.count() == 0
  180. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  181. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  182. app1.set_collection_name("test_collection_2")
  183. assert app1.db.count() == 1
  184. app2.set_collection_name("test_collection_1")
  185. assert app2.db.count() == 3
  186. # cleanup
  187. app1.db.reset()
  188. app2.db.reset()
  189. def test_chroma_db_collection_ids_share_collections():
  190. app1 = App(
  191. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  192. )
  193. app1.set_collection_name("one_collection")
  194. app2 = App(
  195. AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  196. )
  197. app2.set_collection_name("one_collection")
  198. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  199. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  200. assert app1.db.count() == 3
  201. assert app2.db.count() == 3
  202. # cleanup
  203. app1.db.reset()
  204. app2.db.reset()
  205. def test_chroma_db_collection_reset():
  206. app1 = App(
  207. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  208. )
  209. app1.set_collection_name("one_collection")
  210. app2 = App(
  211. AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  212. )
  213. app2.set_collection_name("two_collection")
  214. app3 = App(
  215. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  216. )
  217. app3.set_collection_name("three_collection")
  218. app4 = App(
  219. AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  220. )
  221. app4.set_collection_name("four_collection")
  222. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  223. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  224. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  225. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  226. app1.db.reset()
  227. assert app1.db.count() == 0
  228. assert app2.db.count() == 1
  229. assert app3.db.count() == 1
  230. assert app4.db.count() == 1
  231. # cleanup
  232. app2.db.reset()
  233. app3.db.reset()
  234. app4.db.reset()
  235. def test_chroma_db_collection_query(app_with_settings):
  236. app_with_settings.db.reset()
  237. assert app_with_settings.db.count() == 0
  238. app_with_settings.db.add(
  239. embeddings=[[0, 0, 0]],
  240. documents=["document"],
  241. metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
  242. ids=["id"],
  243. skip_embedding=True,
  244. )
  245. assert app_with_settings.db.count() == 1
  246. app_with_settings.db.add(
  247. embeddings=[[0, 1, 0]],
  248. documents=["document2"],
  249. metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
  250. ids=["id2"],
  251. skip_embedding=True,
  252. )
  253. assert app_with_settings.db.count() == 2
  254. data_without_citations = app_with_settings.db.query(
  255. input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
  256. )
  257. expected_value_without_citations = ["document", "document2"]
  258. assert data_without_citations == expected_value_without_citations
  259. data_with_citations = app_with_settings.db.query(
  260. input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
  261. )
  262. expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
  263. assert data_with_citations == expected_value_with_citations
  264. app_with_settings.db.reset()