test_llama2.py

import os

import pytest

from embedchain.llm.llama2 import Llama2Llm


@pytest.fixture
def llama2_llm():
    # Provide a dummy Replicate API token so Llama2Llm() can initialize.
    os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
    llm = Llama2Llm()
    return llm


def test_init_raises_value_error_without_api_key(mocker):
    # With no REPLICATE_API_TOKEN in the environment, construction must fail.
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        Llama2Llm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
    # Llama2Llm does not support system prompts, so setting one must raise.
    llama2_llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llama2_llm.get_llm_model_answer("prompt")


def test_get_llm_model_answer(llama2_llm, mocker):
    # Replace the Replicate client with a mock so no network request is made.
    mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
    mocked_replicate_instance = mocker.MagicMock()
    mocked_replicate.return_value = mocked_replicate_instance
    mocked_replicate_instance.return_value = "Test answer"

    llama2_llm.config.model = "test_model"
    llama2_llm.config.max_tokens = 50
    llama2_llm.config.temperature = 0.7
    llama2_llm.config.top_p = 0.8

    answer = llama2_llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    # config.max_tokens is forwarded to Replicate under the "max_length" key.
    mocked_replicate.assert_called_once_with(
        model="test_model",
        input={
            "temperature": 0.7,
            "max_length": 50,
            "top_p": 0.8,
        },
    )
    mocked_replicate_instance.assert_called_once_with("Test query")
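

# A minimal sketch (not part of the original suite): the llama2_llm fixture
# above mutates os.environ directly, so the dummy token can leak into other
# tests. pytest's built-in monkeypatch fixture sets the variable and restores
# the environment automatically after each test. The fixture name
# "llama2_llm_isolated" is hypothetical.
@pytest.fixture
def llama2_llm_isolated(monkeypatch):
    monkeypatch.setenv("REPLICATE_API_TOKEN", "test_api_token")
    return Llama2Llm()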