test_aws_bedrock.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import pytest
  2. from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
  3. from embedchain.config import BaseLlmConfig
  4. from embedchain.llm.aws_bedrock import AWSBedrockLlm
  5. @pytest.fixture
  6. def config(monkeypatch):
  7. monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
  8. monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
  9. monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
  10. config = BaseLlmConfig(
  11. model="amazon.titan-text-express-v1",
  12. model_kwargs={
  13. "temperature": 0.5,
  14. "topP": 1,
  15. "maxTokenCount": 1000,
  16. },
  17. )
  18. yield config
  19. monkeypatch.delenv("AWS_ACCESS_KEY_ID")
  20. monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
  21. monkeypatch.delenv("OPENAI_API_KEY")
  22. def test_get_llm_model_answer(config, mocker):
  23. mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
  24. llm = AWSBedrockLlm(config)
  25. answer = llm.get_llm_model_answer("Test query")
  26. assert answer == "Test answer"
  27. mocked_get_answer.assert_called_once_with("Test query", config)
  28. def test_get_llm_model_answer_empty_prompt(config, mocker):
  29. mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
  30. llm = AWSBedrockLlm(config)
  31. answer = llm.get_llm_model_answer("")
  32. assert answer == "Test answer"
  33. mocked_get_answer.assert_called_once_with("", config)
  34. def test_get_llm_model_answer_with_streaming(config, mocker):
  35. config.stream = True
  36. mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")
  37. llm = AWSBedrockLlm(config)
  38. llm.get_llm_model_answer("Test query")
  39. mocked_bedrock_chat.assert_called_once()
  40. callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
  41. assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)