test_openai.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import os
  2. import pytest
  3. from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
  4. from embedchain.config import BaseLlmConfig
  5. from embedchain.llm.openai import OpenAILlm
  6. @pytest.fixture
  7. def config():
  8. os.environ["OPENAI_API_KEY"] = "test_api_key"
  9. config = BaseLlmConfig(
  10. temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
  11. )
  12. yield config
  13. os.environ.pop("OPENAI_API_KEY")
  14. def test_get_llm_model_answer(config, mocker):
  15. mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
  16. llm = OpenAILlm(config)
  17. answer = llm.get_llm_model_answer("Test query")
  18. assert answer == "Test answer"
  19. mocked_get_answer.assert_called_once_with("Test query", config)
  20. def test_get_llm_model_answer_with_system_prompt(config, mocker):
  21. config.system_prompt = "Custom system prompt"
  22. mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
  23. llm = OpenAILlm(config)
  24. answer = llm.get_llm_model_answer("Test query")
  25. assert answer == "Test answer"
  26. mocked_get_answer.assert_called_once_with("Test query", config)
  27. def test_get_llm_model_answer_empty_prompt(config, mocker):
  28. mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
  29. llm = OpenAILlm(config)
  30. answer = llm.get_llm_model_answer("")
  31. assert answer == "Test answer"
  32. mocked_get_answer.assert_called_once_with("", config)
  33. def test_get_llm_model_answer_with_streaming(config, mocker):
  34. config.stream = True
  35. mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
  36. llm = OpenAILlm(config)
  37. llm.get_llm_model_answer("Test query")
  38. mocked_openai_chat.assert_called_once()
  39. callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
  40. assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
  41. def test_get_llm_model_answer_without_system_prompt(config, mocker):
  42. config.system_prompt = None
  43. mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
  44. llm = OpenAILlm(config)
  45. llm.get_llm_model_answer("Test query")
  46. mocked_openai_chat.assert_called_once_with(
  47. model=config.model,
  48. temperature=config.temperature,
  49. max_tokens=config.max_tokens,
  50. model_kwargs={"top_p": config.top_p},
  51. api_key=os.environ["OPENAI_API_KEY"],
  52. )
  53. @pytest.mark.parametrize(
  54. "mock_return, expected",
  55. [
  56. ([{"test": "test"}], '{"test": "test"}'),
  57. ([], "Input could not be mapped to the function!"),
  58. ],
  59. )
  60. def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
  61. mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
  62. mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
  63. mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
  64. mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
  65. llm = OpenAILlm(config, tools={"test": "test"})
  66. answer = llm.get_llm_model_answer("Test query")
  67. mocked_openai_chat.assert_called_once_with(
  68. model=config.model,
  69. temperature=config.temperature,
  70. max_tokens=config.max_tokens,
  71. model_kwargs={"top_p": config.top_p},
  72. api_key=os.environ["OPENAI_API_KEY"],
  73. )
  74. mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
  75. mocked_json_output_tools_parser.assert_called_once()
  76. assert answer == expected