test_mistralai.py

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.mistralai import MistralAILlm


@pytest.fixture
def mistralai_llm_config(monkeypatch):
    # Provide a fake API key for the duration of the test and clean it up afterwards.
    monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
    yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)


def test_mistralai_llm_init_missing_api_key(monkeypatch):
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
    with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
        MistralAILlm()


def test_mistralai_llm_init(monkeypatch):
    monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
    llm = MistralAILlm()
    assert llm is not None


def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
    # Patch _get_answer so no network call is made.
    def mock_get_answer(self, prompt, config):
        return "Generated Text"

    monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("test prompt")
    assert result == "Generated Text"


def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
    mistralai_llm_config.system_prompt = "Test system prompt"
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("test prompt")
    assert result == "Generated Text"


def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("")
    assert result == "Generated Text"


def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
    mistralai_llm_config.system_prompt = None
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("test prompt")
    assert result == "Generated Text"


def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
    test_config = BaseLlmConfig(
        temperature=mistralai_llm_config.temperature,
        max_tokens=mistralai_llm_config.max_tokens,
        top_p=mistralai_llm_config.top_p,
        model=mistralai_llm_config.model,
        token_usage=True,
    )
    # With token_usage enabled, the patched _get_answer returns the text plus raw token counts.
    monkeypatch.setattr(
        MistralAILlm,
        "_get_answer",
        lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
    )
    llm = MistralAILlm(test_config)
    answer, token_info = llm.get_llm_model_answer("Test query")

    assert answer == "Generated Text"
    # total_tokens and total_cost are derived from the returned counts and the model's per-token pricing.
    assert token_info == {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 7.5e-07,
        "cost_currency": "USD",
    }