# test_embedchain.py

  1. import os
  2. import pytest
  3. from chromadb.api.models.Collection import Collection
  4. from embedchain import App
  5. from embedchain.config import AppConfig, ChromaDbConfig
  6. from embedchain.embedchain import EmbedChain
  7. from embedchain.llm.base import BaseLlm
  8. os.environ["OPENAI_API_KEY"] = "test-api-key"
  9. @pytest.fixture
  10. def app_instance():
  11. config = AppConfig(log_level="DEBUG", collect_metrics=False)
  12. return App(config)
  13. def test_whole_app(app_instance, mocker):
  14. knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
  15. mocker.patch.object(EmbedChain, "add")
  16. mocker.patch.object(EmbedChain, "retrieve_from_database")
  17. mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
  18. mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
  19. mocker.patch.object(BaseLlm, "generate_prompt")
  20. app_instance.add(knowledge, data_type="text")
  21. app_instance.query("What text did I give you?")
  22. app_instance.chat("What text did I give you?")
  23. assert BaseLlm.generate_prompt.call_count == 2
  24. app_instance.reset()
  25. def test_add_after_reset(app_instance, mocker):
  26. mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
  27. config = AppConfig(log_level="DEBUG", collect_metrics=False)
  28. chroma_config = {"allow_reset": True}
  29. app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
  30. app_instance.reset()
  31. app_instance.db.client.heartbeat()
  32. mocker.patch.object(Collection, "add")
  33. app_instance.db.collection.add(
  34. embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
  35. metadatas=[
  36. {"chapter": "3", "verse": "16"},
  37. {"chapter": "3", "verse": "5"},
  38. {"chapter": "29", "verse": "11"},
  39. ],
  40. ids=["id1", "id2", "id3"],
  41. )
  42. app_instance.reset()
  43. def test_add_with_incorrect_content(app_instance, mocker):
  44. content = [{"foo": "bar"}]
  45. with pytest.raises(ValueError):
  46. app_instance.add(content, data_type="json")