浏览代码

Add model_kwargs to OpenAI call (#1402)

Pranav Puranik 1 年之前
父节点
当前提交
6ecdadfd97
共有 3 个文件被更改,包括 25 次插入3 次删除
  1. 7 2
      docs/api-reference/advanced/configuration.mdx
  2. 1 1
      embedchain/llm/openai.py
  3. 17 0
      tests/llm/test_openai.py

+ 7 - 2
docs/api-reference/advanced/configuration.mdx

@@ -26,6 +26,9 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
+    model_kwargs:
+      response_format: 
+        type: json_object
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -83,7 +86,8 @@ cache:
       "stream": false,
       "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
-      "api_key": "sk-xxx"
+      "api_key": "sk-xxx",
+      "model_kwargs": {"response_format": {"type": "json_object"}}
     }
   },
   "vectordb": {
@@ -143,7 +147,8 @@ config = {
             'system_prompt': (
                 "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
             ),
-            'api_key': 'sk-xxx'
+            'api_key': 'sk-xxx',
+            "model_kwargs": {"response_format": {"type": "json_object"}}
         }
     },
     'vectordb': {

+ 1 - 1
embedchain/llm/openai.py

@@ -36,7 +36,7 @@ class OpenAILlm(BaseLlm):
             "model": config.model or "gpt-3.5-turbo",
             "temperature": config.temperature,
             "max_tokens": config.max_tokens,
-            "model_kwargs": {},
+            "model_kwargs": config.model_kwargs or {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)

+ 17 - 0
tests/llm/test_openai.py

@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
     )
 
 
+def test_get_llm_model_answer_with_model_kwargs(config, mocker):
+    config.model_kwargs = {"response_format": {"type": "json_object"}}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+    )
+
+
 @pytest.mark.parametrize(
     "mock_return, expected",
     [