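"""Unit tests for embedchain's OpenAILlm wrapper.

Everything network-facing is mocked: either OpenAILlm._get_answer or
langchain's ChatOpenAI is patched, so the tests only verify that config
values (temperature, max_tokens, top_p, headers, proxies, tools) are
forwarded to ChatOpenAI as expected. No request ever reaches OpenAI.
"""
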
import os

import httpx
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


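# Fake credentials are enough here: ChatOpenAI and _get_answer are patched in
# every test, so the values below are only ever read, never sent anywhere.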
@pytest.fixture()
def env_config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
    yield
    # Remove both variables so the fake values do not leak into other tests.
    os.environ.pop("OPENAI_API_KEY")
    os.environ.pop("OPENAI_API_BASE", None)


@pytest.fixture
def config(env_config):
    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_client_proxies=None,
        http_async_client_proxies=None,
    )
    yield config


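# The first three tests patch OpenAILlm._get_answer itself, so they only
# verify that get_llm_model_answer forwards the prompt and config unchanged.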
def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


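# The remaining tests patch embedchain.llm.openai.ChatOpenAI and assert on
# the keyword arguments its constructor receives.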
def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once()
    # Collect the `callbacks` kwarg from every ChatOpenAI call and check that
    # a StreamingStdOutCallbackHandler was passed when streaming is enabled.
    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )


def test_get_llm_model_answer_with_special_headers(config, mocker):
    config.default_headers = {"test": "test"}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        default_headers={"test": "test"},
        http_client=None,
        http_async_client=None,
    )


def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    # User-supplied model_kwargs must be merged with the top_p setting,
    # not replace it.
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )


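# Tool calling exercises a different code path: the model is bound to the
# converted tool and piped into a JsonOutputToolsParser, so all three
# collaborators are patched below.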
@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    # Stub the chained call ChatOpenAI(...).bind(...).pipe(...).invoke(...)
    # so the tool-calling path returns a canned parser result.
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()
    assert answer == expected


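# The proxy tests patch httpx.Client / httpx.AsyncClient so we can assert
# that OpenAILlm builds the client from *_proxies in the config and hands
# the resulting instance to ChatOpenAI.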
def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mock_http_client = mocker.Mock(spec=httpx.Client)
    mock_http_client_instance = mocker.Mock(spec=httpx.Client)
    mock_http_client.return_value = mock_http_client_instance
    mocker.patch("httpx.Client", new=mock_http_client)
    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_client_proxies="http://testproxy.mem0.net:8000",
    )
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=mock_http_client_instance,
        http_async_client=None,
    )
    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")


def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
    mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
    mock_http_async_client.return_value = mock_http_async_client_instance
    mocker.patch("httpx.AsyncClient", new=mock_http_async_client)
    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
    )
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=mock_http_async_client_instance,
    )
    mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
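
# To run just this module (the path is illustrative; adjust to your checkout):
#   pytest tests/llm/test_openai.py -v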