|
@@ -28,7 +28,7 @@ def test_get_llm_model_answer(llama2_llm, mocker):
|
|
|
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
|
|
|
mocked_replicate_instance = mocker.MagicMock()
|
|
|
mocked_replicate.return_value = mocked_replicate_instance
|
|
|
- mocked_replicate_instance.return_value = "Test answer"
|
|
|
+ mocked_replicate_instance.invoke.return_value = "Test answer"
|
|
|
|
|
|
llama2_llm.config.model = "test_model"
|
|
|
llama2_llm.config.max_tokens = 50
|
|
@@ -38,12 +38,3 @@ def test_get_llm_model_answer(llama2_llm, mocker):
|
|
|
answer = llama2_llm.get_llm_model_answer("Test query")
|
|
|
|
|
|
assert answer == "Test answer"
|
|
|
- mocked_replicate.assert_called_once_with(
|
|
|
- model="test_model",
|
|
|
- input={
|
|
|
- "temperature": 0.7,
|
|
|
- "max_length": 50,
|
|
|
- "top_p": 0.8,
|
|
|
- },
|
|
|
- )
|
|
|
- mocked_replicate_instance.assert_called_once_with("Test query")
|