|
@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
|
|
|
)
|
|
|
|
|
|
|
|
|
def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    """Extra ``model_kwargs`` on the config must be merged into the keyword
    arguments that OpenAILlm forwards to ``ChatOpenAI`` (alongside the
    always-present ``top_p``)."""
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    # The merged kwargs: the fixed top_p plus the user-supplied entries.
    expected_model_kwargs = {
        "top_p": config.top_p,
        "response_format": {"type": "json_object"},
    }
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs=expected_model_kwargs,
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )


|
|
|
@pytest.mark.parametrize(
|
|
|
"mock_return, expected",
|
|
|
[
|