# test_embedchain.py (2.3 KB)
  1. import os
  2. import pytest
  3. from chromadb.api.models.Collection import Collection
  4. from embedchain import App
  5. from embedchain.config import AppConfig, ChromaDbConfig
  6. from embedchain.embedchain import EmbedChain
  7. from embedchain.llm.base import BaseLlm
  8. from embedchain.memory.base import ChatHistory
  9. from embedchain.vectordb.chroma import ChromaDB
  10. os.environ["OPENAI_API_KEY"] = "test-api-key"
  11. @pytest.fixture
  12. def app_instance():
  13. config = AppConfig(log_level="DEBUG", collect_metrics=False)
  14. return App(config=config)
  15. def test_whole_app(app_instance, mocker):
  16. knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
  17. mocker.patch.object(EmbedChain, "add")
  18. mocker.patch.object(EmbedChain, "_retrieve_from_database")
  19. mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
  20. mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
  21. mocker.patch.object(BaseLlm, "generate_prompt")
  22. mocker.patch.object(
  23. BaseLlm,
  24. "add_history",
  25. )
  26. mocker.patch.object(ChatHistory, "delete", autospec=True)
  27. app_instance.add(knowledge, data_type="text")
  28. app_instance.query("What text did I give you?")
  29. app_instance.chat("What text did I give you?")
  30. assert BaseLlm.generate_prompt.call_count == 2
  31. app_instance.reset()
  32. def test_add_after_reset(app_instance, mocker):
  33. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  34. config = AppConfig(log_level="DEBUG", collect_metrics=False)
  35. chroma_config = ChromaDbConfig(allow_reset=True)
  36. db = ChromaDB(config=chroma_config)
  37. app_instance = App(config=config, db=db)
  38. # mock delete chat history
  39. mocker.patch.object(ChatHistory, "delete", autospec=True)
  40. app_instance.reset()
  41. app_instance.db.client.heartbeat()
  42. mocker.patch.object(Collection, "add")
  43. app_instance.db.collection.add(
  44. embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
  45. metadatas=[
  46. {"chapter": "3", "verse": "16"},
  47. {"chapter": "3", "verse": "5"},
  48. {"chapter": "29", "verse": "11"},
  49. ],
  50. ids=["id1", "id2", "id3"],
  51. )
  52. app_instance.reset()
  53. def test_add_with_incorrect_content(app_instance, mocker):
  54. content = [{"foo": "bar"}]
  55. with pytest.raises(TypeError):
  56. app_instance.add(content, data_type="json")