# test_openai.py

import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture
def config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
    config = BaseLlmConfig(
        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
    )
    yield config
    os.environ.pop("OPENAI_API_KEY")
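    # NOTE: only OPENAI_API_KEY is popped on teardown; OPENAI_API_BASE stays set,
    # which the tests below rely on when asserting the base_url passed to ChatOpenAI.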


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)
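    # _get_answer is patched, so no real OpenAI request is made; the test only verifies the delegation.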


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
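    # Assumption checked above: when config.stream is True, OpenAILlm passes a
    # StreamingStdOutCallbackHandler to ChatOpenAI via the `callbacks` keyword argument.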


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )


def test_get_llm_model_answer_with_special_headers(config, mocker):
    config.default_headers = {"test": "test"}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        default_headers={"test": "test"},
    )
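    # Unlike the previous test, this call is expected to include default_headers,
    # forwarded from config.default_headers.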


@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return

    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()
    assert answer == expected
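    # Assumption encoded in the parametrize cases: a non-empty tool-call result is serialized to a
    # JSON string, while an empty parser result falls back to
    # "Input could not be mapped to the function!".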