import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.aws_bedrock import AWSBedrockLlm

@pytest.fixture
def config(monkeypatch):
    # Dummy credentials so the LLM can be constructed without real AWS/OpenAI keys.
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    config = BaseLlmConfig(
        model="amazon.titan-text-express-v1",
        model_kwargs={
            "temperature": 0.5,
            "topP": 1,
            "maxTokenCount": 1000,
        },
    )
    yield config
    monkeypatch.delenv("AWS_ACCESS_KEY_ID")
    monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
    monkeypatch.delenv("OPENAI_API_KEY")

def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")

    llm = AWSBedrockLlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)

def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")

    llm = AWSBedrockLlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)

def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")

    llm = AWSBedrockLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_bedrock_chat.assert_called_once()

    # When streaming is enabled, a StreamingStdOutCallbackHandler should be passed
    # to Bedrock via the `callbacks` keyword argument.
    callbacks = [call[1]["callbacks"] for call in mocked_bedrock_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)