test_embedchain.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import os
  2. import pytest
  3. from chromadb.api.models.Collection import Collection
  4. from embedchain import App
  5. from embedchain.config import AppConfig, ChromaDbConfig
  6. from embedchain.embedchain import EmbedChain
  7. from embedchain.llm.base import BaseLlm
  8. from embedchain.memory.base import ECChatMemory
  9. os.environ["OPENAI_API_KEY"] = "test-api-key"
  10. @pytest.fixture
  11. def app_instance():
  12. config = AppConfig(log_level="DEBUG", collect_metrics=False)
  13. return App(config)
  14. def test_whole_app(app_instance, mocker):
  15. knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
  16. mocker.patch.object(EmbedChain, "add")
  17. mocker.patch.object(EmbedChain, "_retrieve_from_database")
  18. mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
  19. mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
  20. mocker.patch.object(BaseLlm, "generate_prompt")
  21. mocker.patch.object(
  22. BaseLlm,
  23. "add_history",
  24. )
  25. mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
  26. app_instance.add(knowledge, data_type="text")
  27. app_instance.query("What text did I give you?")
  28. app_instance.chat("What text did I give you?")
  29. assert BaseLlm.generate_prompt.call_count == 2
  30. app_instance.reset()
  31. def test_add_after_reset(app_instance, mocker):
  32. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  33. config = AppConfig(log_level="DEBUG", collect_metrics=False)
  34. chroma_config = {"allow_reset": True}
  35. app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
  36. # mock delete chat history
  37. mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
  38. app_instance.reset()
  39. app_instance.db.client.heartbeat()
  40. mocker.patch.object(Collection, "add")
  41. app_instance.db.collection.add(
  42. embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
  43. metadatas=[
  44. {"chapter": "3", "verse": "16"},
  45. {"chapter": "3", "verse": "5"},
  46. {"chapter": "29", "verse": "11"},
  47. ],
  48. ids=["id1", "id2", "id3"],
  49. )
  50. app_instance.reset()
  51. def test_add_with_incorrect_content(app_instance, mocker):
  52. content = [{"foo": "bar"}]
  53. with pytest.raises(TypeError):
  54. app_instance.add(content, data_type="json")