# test_openai.py

import os

import httpx
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture()
def env_config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
    yield
    os.environ.pop("OPENAI_API_KEY")


@pytest.fixture
def config(env_config):
    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_client_proxies=None,
        http_async_client_proxies=None,
    )
    yield config


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_token_usage(config, mocker):
    test_config = BaseLlmConfig(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        stream=config.stream,
        system_prompt=config.system_prompt,
        model=config.model,
        token_usage=True,
    )
    mocked_get_answer = mocker.patch(
        "embedchain.llm.openai.OpenAILlm._get_answer",
        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
    )

    llm = OpenAILlm(test_config)
    answer, token_info = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    # Expected total_cost is derived from the mocked token counts and the model's per-token pricing.
    assert token_info == {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 5.5e-06,
        "cost_currency": "USD",
    }
    mocked_get_answer.assert_called_once_with("Test query", test_config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once()
    # When streaming is enabled, a StreamingStdOutCallbackHandler should be passed via the `callbacks` kwarg.
    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )


def test_get_llm_model_answer_with_special_headers(config, mocker):
    config.default_headers = {"test": "test"}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        default_headers={"test": "test"},
        http_client=None,
        http_async_client=None,
    )


def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )


@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return

    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()
    assert answer == expected


def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mock_http_client = mocker.Mock(spec=httpx.Client)
    mock_http_client_instance = mocker.Mock(spec=httpx.Client)
    mock_http_client.return_value = mock_http_client_instance
    mocker.patch("httpx.Client", new=mock_http_client)

    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_client_proxies="http://testproxy.mem0.net:8000",
    )

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=mock_http_client_instance,
        http_async_client=None,
    )
    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")


def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
    mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
    mock_http_async_client.return_value = mock_http_async_client_instance
    mocker.patch("httpx.AsyncClient", new=mock_http_async_client)

    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
    )

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=mock_http_async_client_instance,
    )
    mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})