test_openai.py

import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


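# Shared fixture: sets dummy OpenAI environment variables, yields a BaseLlmConfig,
# and removes the API key again after each test.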
@pytest.fixture
def config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
    config = BaseLlmConfig(
        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
    )
    yield config
    os.environ.pop("OPENAI_API_KEY")


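# The next three tests stub OpenAILlm._get_answer, so no real API request is made.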
def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


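# With stream=True, a StreamingStdOutCallbackHandler should be among the callbacks
# passed to ChatOpenAI.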
def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


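# The remaining tests patch ChatOpenAI directly and assert the exact keyword
# arguments used to construct it.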
def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )


def test_get_llm_model_answer_with_special_headers(config, mocker):
    config.default_headers = {"test": "test"}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        default_headers={"test": "test"},
    )


def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )


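# Tool-calling path: parametrized over a successful tool call (serialized to JSON)
# and an empty result, which falls back to the "could not be mapped" message.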
@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()
    assert answer == expected