
Add model_kwargs to OpenAI call (#1402)

Pranav Puranik, 1 year ago
Commit 6ecdadfd97

+ 7 - 2
docs/api-reference/advanced/configuration.mdx

@@ -26,6 +26,9 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
+    model_kwargs:
+      response_format: 
+        type: json_object
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -83,7 +86,8 @@ cache:
       "stream": false,
       "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
-      "api_key": "sk-xxx"
+      "api_key": "sk-xxx",
+      "model_kwargs": {"response_format": {"type": "json_object"}}
     }
   },
   "vectordb": {
@@ -143,7 +147,8 @@ config = {
             'system_prompt': (
                 "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
             ),
-            'api_key': 'sk-xxx'
+            'api_key': 'sk-xxx',
+            "model_kwargs": {"response_format": {"type": "json_object"}}
         }
     },
     'vectordb': {
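
The documented model_kwargs mapping is forwarded to the OpenAI call unchanged, so any extra chat-completions parameter can go there. A minimal sketch of loading the config above, assuming embedchain's App.from_config accepts an inline config dict and that OPENAI_API_KEY is set in the environment:

    # Enable OpenAI's JSON mode through the new model_kwargs field.
    from embedchain import App

    config = {
        "llm": {
            "provider": "openai",
            "config": {
                "model": "gpt-3.5-turbo",
                "temperature": 0.5,
                # Forwarded verbatim to the underlying OpenAI call:
                "model_kwargs": {"response_format": {"type": "json_object"}},
            },
        }
    }

    app = App.from_config(config=config)
    # OpenAI's JSON mode requires the word "json" to appear in the prompt.
    print(app.query("Answer in JSON: who wrote Hamlet?"))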

+ 1 - 1
embedchain/llm/openai.py

@@ -36,7 +36,7 @@ class OpenAILlm(BaseLlm):
             "model": config.model or "gpt-3.5-turbo",
             "temperature": config.temperature,
             "max_tokens": config.max_tokens,
-            "model_kwargs": {},
+            "model_kwargs": config.model_kwargs or {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
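
The one-line change above replaces the hard-coded empty dict so that config.model_kwargs actually reaches langchain's ChatOpenAI. A minimal sketch of the resulting call, assuming the langchain_openai package and an OPENAI_API_KEY environment variable; the parameter values are illustrative only:

    import os

    from langchain_openai import ChatOpenAI

    # Roughly what OpenAILlm now constructs from its config.
    chat = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0.7,
        max_tokens=256,
        # Extra parameters land in the request body, e.g. JSON mode:
        model_kwargs={"response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
    )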

+ 17 - 0
tests/llm/test_openai.py

@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
     )
 
 
+def test_get_llm_model_answer_with_model_kwargs(config, mocker):
+    config.model_kwargs = {"response_format": {"type": "json_object"}}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+    )
+
+
 @pytest.mark.parametrize(
     "mock_return, expected",
     [
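
The added test patches embedchain.llm.openai.ChatOpenAI and asserts that the configured response_format is merged into model_kwargs alongside the existing top_p entry. Assuming a standard pytest setup for the repository, the new case can be run in isolation with pytest's keyword filter:

    pytest tests/llm/test_openai.py -k model_kwargs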