# test_mistralai.py (2.2 KB)
  1. import pytest
  2. from embedchain.config import BaseLlmConfig
  3. from embedchain.llm.mistralai import MistralAILlm
  4. @pytest.fixture
  5. def mistralai_llm_config(monkeypatch):
  6. monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
  7. yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
  8. monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
  9. def test_mistralai_llm_init_missing_api_key(monkeypatch):
  10. monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
  11. with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
  12. MistralAILlm()
  13. def test_mistralai_llm_init(monkeypatch):
  14. monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
  15. llm = MistralAILlm()
  16. assert llm is not None
  17. def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
  18. def mock_get_answer(prompt, config):
  19. return "Generated Text"
  20. monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
  21. llm = MistralAILlm(config=mistralai_llm_config)
  22. result = llm.get_llm_model_answer("test prompt")
  23. assert result == "Generated Text"
  24. def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
  25. mistralai_llm_config.system_prompt = "Test system prompt"
  26. monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
  27. llm = MistralAILlm(config=mistralai_llm_config)
  28. result = llm.get_llm_model_answer("test prompt")
  29. assert result == "Generated Text"
  30. def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
  31. monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
  32. llm = MistralAILlm(config=mistralai_llm_config)
  33. result = llm.get_llm_model_answer("")
  34. assert result == "Generated Text"
  35. def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
  36. mistralai_llm_config.system_prompt = None
  37. monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
  38. llm = MistralAILlm(config=mistralai_llm_config)
  39. result = llm.get_llm_model_answer("test prompt")
  40. assert result == "Generated Text"