test_chroma_db.py 10 KB


  1. import os
  2. import shutil
  3. import pytest
  4. from unittest.mock import patch
  5. from chromadb.config import Settings
  6. from embedchain import App
  7. from embedchain.config import AppConfig, ChromaDbConfig
  8. from embedchain.vectordb.chroma import ChromaDB
  9. os.environ["OPENAI_API_KEY"] = "test-api-key"  # placeholder key set at import time, before any fixture or test runs; presumably lets App() be built without real credentials - confirm
  10. @pytest.fixture
  11. def chroma_db():
  12. return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
  13. @pytest.fixture
  14. def app_with_settings():
  15. chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
  16. app_config = AppConfig(collect_metrics=False)
  17. return App(config=app_config, db_config=chroma_config)
  18. @pytest.fixture(scope="session", autouse=True)
  19. def cleanup_db():
  20. yield
  21. try:
  22. shutil.rmtree("test-db")
  23. except OSError as e:
  24. print("Error: %s - %s." % (e.filename, e.strerror))
  25. def test_chroma_db_init_with_host_and_port(chroma_db):
  26. settings = chroma_db.client.get_settings()
  27. assert settings.chroma_server_host == "test-host"
  28. assert settings.chroma_server_http_port == "1234"
  29. def test_chroma_db_init_with_basic_auth():
  30. chroma_config = {
  31. "host": "test-host",
  32. "port": "1234",
  33. "chroma_settings": {
  34. "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
  35. "chroma_client_auth_credentials": "admin:admin",
  36. },
  37. }
  38. db = ChromaDB(config=ChromaDbConfig(**chroma_config))
  39. settings = db.client.get_settings()
  40. assert settings.chroma_server_host == "test-host"
  41. assert settings.chroma_server_http_port == "1234"
  42. assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
  43. assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
  44. @patch("embedchain.vectordb.chroma.chromadb.Client")
  45. def test_app_init_with_host_and_port(mock_client):
  46. host = "test-host"
  47. port = "1234"
  48. config = AppConfig(collect_metrics=False)
  49. db_config = ChromaDbConfig(host=host, port=port)
  50. _app = App(config, db_config=db_config)
  51. called_settings: Settings = mock_client.call_args[0][0]
  52. assert called_settings.chroma_server_host == host
  53. assert called_settings.chroma_server_http_port == port
  54. @patch("embedchain.vectordb.chroma.chromadb.Client")
  55. def test_app_init_with_host_and_port_none(mock_client):
  56. _app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  57. called_settings: Settings = mock_client.call_args[0][0]
  58. assert called_settings.chroma_server_host is None
  59. assert called_settings.chroma_server_http_port is None
  60. def test_chroma_db_duplicates_throw_warning(caplog):
  61. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  62. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  63. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  64. assert "Insert of existing embedding ID: 0" in caplog.text
  65. assert "Add of existing embedding ID: 0" in caplog.text
  66. app.db.reset()
  67. def test_chroma_db_duplicates_collections_no_warning(caplog):
  68. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  69. app.set_collection_name("test_collection_1")
  70. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  71. app.set_collection_name("test_collection_2")
  72. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  73. assert "Insert of existing embedding ID: 0" not in caplog.text
  74. assert "Add of existing embedding ID: 0" not in caplog.text
  75. app.db.reset()
  76. app.set_collection_name("test_collection_1")
  77. app.db.reset()
  78. def test_chroma_db_collection_init_with_default_collection():
  79. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  80. assert app.db.collection.name == "embedchain_store"
  81. def test_chroma_db_collection_init_with_custom_collection():
  82. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  83. app.set_collection_name(name="test_collection")
  84. assert app.db.collection.name == "test_collection"
  85. def test_chroma_db_collection_set_collection_name():
  86. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  87. app.set_collection_name("test_collection")
  88. assert app.db.collection.name == "test_collection"
  89. def test_chroma_db_collection_changes_encapsulated():
  90. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  91. app.set_collection_name("test_collection_1")
  92. assert app.db.count() == 0
  93. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  94. assert app.db.count() == 1
  95. app.set_collection_name("test_collection_2")
  96. assert app.db.count() == 0
  97. app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  98. app.set_collection_name("test_collection_1")
  99. assert app.db.count() == 1
  100. app.db.reset()
  101. app.set_collection_name("test_collection_2")
  102. app.db.reset()
  103. def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
  104. # Start with a clean app
  105. app_with_settings.db.reset()
  106. assert app_with_settings.db.count() == 0
  107. app_with_settings.db.add(
  108. embeddings=[[0, 0, 0]],
  109. documents=["document"],
  110. metadatas=[{"value": "somevalue"}],
  111. ids=["id"],
  112. skip_embedding=True,
  113. )
  114. assert app_with_settings.db.count() == 1
  115. data = app_with_settings.db.get(["id"], limit=1)
  116. expected_value = {
  117. "documents": ["document"],
  118. "embeddings": None,
  119. "ids": ["id"],
  120. "metadatas": [{"value": "somevalue"}],
  121. }
  122. assert data == expected_value
  123. data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
  124. expected_value = ["document"]
  125. assert data == expected_value
  126. app_with_settings.db.reset()
  127. def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
  128. # Start with a clean app
  129. app_with_settings.db.reset()
  130. assert app_with_settings.db.count() == 0
  131. with pytest.raises(ValueError):
  132. app_with_settings.db.add(
  133. embeddings=[[0, 0, 0]],
  134. documents=["document", "document2"],
  135. metadatas=[{"value": "somevalue"}],
  136. ids=["id"],
  137. skip_embedding=True,
  138. )
  139. assert app_with_settings.db.count() == 0
  140. with pytest.raises(ValueError):
  141. app_with_settings.db.add(
  142. embeddings=None,
  143. documents=["document", "document2"],
  144. metadatas=[{"value": "somevalue"}],
  145. ids=["id"],
  146. skip_embedding=True,
  147. )
  148. assert app_with_settings.db.count() == 0
  149. app_with_settings.db.reset()
  150. def test_chroma_db_collection_collections_are_persistent():
  151. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  152. app.set_collection_name("test_collection_1")
  153. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
  154. del app
  155. app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
  156. app.set_collection_name("test_collection_1")
  157. assert app.db.count() == 1
  158. app.db.reset()
  159. def test_chroma_db_collection_parallel_collections():
  160. app1 = App(
  161. AppConfig(collection_name="test_collection_1", collect_metrics=False),
  162. db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
  163. )
  164. app2 = App(
  165. AppConfig(collection_name="test_collection_2", collect_metrics=False),
  166. db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
  167. )
  168. # cleanup if any previous tests failed or were interrupted
  169. app1.db.reset()
  170. app2.db.reset()
  171. app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  172. assert app1.db.count() == 1
  173. assert app2.db.count() == 0
  174. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
  175. app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
  176. app1.set_collection_name("test_collection_2")
  177. assert app1.db.count() == 1
  178. app2.set_collection_name("test_collection_1")
  179. assert app2.db.count() == 3
  180. # cleanup
  181. app1.db.reset()
  182. app2.db.reset()
  183. def test_chroma_db_collection_ids_share_collections():
  184. app1 = App(
  185. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  186. )
  187. app1.set_collection_name("one_collection")
  188. app2 = App(
  189. AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  190. )
  191. app2.set_collection_name("one_collection")
  192. app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
  193. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  194. assert app1.db.count() == 3
  195. assert app2.db.count() == 3
  196. # cleanup
  197. app1.db.reset()
  198. app2.db.reset()
  199. def test_chroma_db_collection_reset():
  200. app1 = App(
  201. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  202. )
  203. app1.set_collection_name("one_collection")
  204. app2 = App(
  205. AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  206. )
  207. app2.set_collection_name("two_collection")
  208. app3 = App(
  209. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  210. )
  211. app3.set_collection_name("three_collection")
  212. app4 = App(
  213. AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
  214. )
  215. app4.set_collection_name("four_collection")
  216. app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
  217. app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
  218. app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
  219. app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
  220. app1.db.reset()
  221. assert app1.db.count() == 0
  222. assert app2.db.count() == 1
  223. assert app3.db.count() == 1
  224. assert app4.db.count() == 1
  225. # cleanup
  226. app2.db.reset()
  227. app3.db.reset()
  228. app4.db.reset()