import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture
def config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    # Several tests below read OPENAI_API_BASE from the environment, so the fixture
    # must set it as well; the exact value is irrelevant because ChatOpenAI is mocked.
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
    config = BaseLlmConfig(
        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
    )
    yield config
    os.environ.pop("OPENAI_API_KEY")
    os.environ.pop("OPENAI_API_BASE")
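

# _get_answer is patched, so this only checks that the public get_llm_model_answer()
# delegates the query and config to it and returns its result unchanged.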
def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)
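

# Same delegation check as above, but with a custom system prompt set on the config.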
def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)
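

# An empty query string should still be passed through to _get_answer unchanged.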
def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)
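

# With stream=True, the ChatOpenAI client should be constructed with a
# StreamingStdOutCallbackHandler among its callbacks.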
def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
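

# With no system prompt, ChatOpenAI should be constructed with exactly the config values
# plus the API key and base URL taken from the environment.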
def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
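

# default_headers set on the config should be forwarded to the ChatOpenAI constructor.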
def test_get_llm_model_answer_with_special_headers(config, mocker):
    config.default_headers = {"test": "test"}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        default_headers={"test": "test"},
    )
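

# When tools are supplied, the tool dict should be converted via convert_to_openai_tool
# and the bound model's output parsed; per the parametrized cases, an empty parser result
# yields the "Input could not be mapped to the function!" message.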
@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()
    assert answer == expected