test_openai.py

"""Unit tests for embedchain.llm.openai.OpenAILlm."""

import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture
def config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    # The base URL is read back by the tests that assert on the ChatOpenAI kwargs below.
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
    config = BaseLlmConfig(
        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
    )
    yield config
    os.environ.pop("OPENAI_API_KEY")
    os.environ.pop("OPENAI_API_BASE")


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once()
    # Each call_args is an (args, kwargs) pair; collect the "callbacks" kwarg from every call.
    callback_lists = [call_args[1]["callbacks"] for call_args in mocked_openai_chat.call_args_list]
    assert any(isinstance(callbacks[0], StreamingStdOutCallbackHandler) for callbacks in callback_lists)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )


@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return

    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()
    assert answer == expected