12345678910111213141516171819202122232425262728293031323334353637383940 |
- import os
- import pytest
- from embedchain.llm.llama2 import Llama2Llm
- @pytest.fixture
- def llama2_llm():
- os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
- llm = Llama2Llm()
- return llm
- def test_init_raises_value_error_without_api_key(mocker):
- mocker.patch.dict(os.environ, clear=True)
- with pytest.raises(ValueError):
- Llama2Llm()
- def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
- llama2_llm.config.system_prompt = "system_prompt"
- with pytest.raises(ValueError):
- llama2_llm.get_llm_model_answer("prompt")
- def test_get_llm_model_answer(llama2_llm, mocker):
- mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
- mocked_replicate_instance = mocker.MagicMock()
- mocked_replicate.return_value = mocked_replicate_instance
- mocked_replicate_instance.invoke.return_value = "Test answer"
- llama2_llm.config.model = "test_model"
- llama2_llm.config.max_tokens = 50
- llama2_llm.config.temperature = 0.7
- llama2_llm.config.top_p = 0.8
- answer = llama2_llm.get_llm_model_answer("Test query")
- assert answer == "Test answer"
|