فهرست منبع

Add model_kwargs to OpenAI call (#1402)

Pranav Puranik 1 سال پیش
والد
کامیت
6ecdadfd97
3فایلهای تغییر یافته به همراه25 افزوده شده و 3 حذف شده
  1. 7 2
      docs/api-reference/advanced/configuration.mdx
  2. 1 1
      embedchain/llm/openai.py
  3. 17 0
      tests/llm/test_openai.py

+ 7 - 2
docs/api-reference/advanced/configuration.mdx

@@ -26,6 +26,9 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
+    model_kwargs:
+      response_format: 
+        type: json_object
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -83,7 +86,8 @@ cache:
       "stream": false,
       "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
-      "api_key": "sk-xxx"
+      "api_key": "sk-xxx",
+      "model_kwargs": {"response_format": {"type": "json_object"}}
     }
   },
   "vectordb": {
@@ -143,7 +147,8 @@ config = {
             'system_prompt': (
                 "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
             ),
-            'api_key': 'sk-xxx'
+            'api_key': 'sk-xxx',
+            "model_kwargs": {"response_format": {"type": "json_object"}}
         }
     },
     'vectordb': {

+ 1 - 1
embedchain/llm/openai.py

@@ -36,7 +36,7 @@ class OpenAILlm(BaseLlm):
             "model": config.model or "gpt-3.5-turbo",
             "temperature": config.temperature,
             "max_tokens": config.max_tokens,
-            "model_kwargs": {},
+            "model_kwargs": config.model_kwargs or {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)

+ 17 - 0
tests/llm/test_openai.py

@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
     )
 
 
+def test_get_llm_model_answer_with_model_kwargs(config, mocker):
+    config.model_kwargs = {"response_format": {"type": "json_object"}}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+    )
+
+
 @pytest.mark.parametrize(
     "mock_return, expected",
     [