Azure OpenAI: support a configurable api_version and fix env var docs (#1425)

Pranav Puranik 1 year ago
parent
commit
4547d870af

+ 3 - 1
docs/api-reference/advanced/configuration.mdx

@@ -29,6 +29,7 @@ llm:
     model_kwargs:
       response_format: 
         type: json_object
+    api_version: 2024-02-01
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -87,7 +88,8 @@ cache:
       "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
       "api_key": "sk-xxx",
-      "model_kwargs": {"response_format": {"type": "json_object"}}
+      "model_kwargs": {"response_format": {"type": "json_object"}},
+      "api_version": "2024-02-01"
     }
   },
   "vectordb": {

+ 2 - 2
docs/components/llms.mdx

@@ -193,8 +193,8 @@ import os
 from embedchain import App
 
 os.environ["OPENAI_API_TYPE"] = "azure"
-os.environ["OPENAI_API_BASE"] = "https://xxx.openai.azure.com/"
-os.environ["OPENAI_API_KEY"] = "xxx"
+os.environ["AZURE_OPENAI_ENDPOINT"] = "https://xxx.openai.azure.com/"
+os.environ["AZURE_OPENAI_KEY"] = "xxx"
 os.environ["OPENAI_API_VERSION"] = "xxx"
 
 app = App.from_config(config_path="config.yaml")
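
A possible follow-up to the snippet above, assuming python-dotenv is installed (an assumption, not part of the diff): keep the credentials in a .env file rather than hard-coding them.

    from dotenv import load_dotenv
    from embedchain import App

    # Expects OPENAI_API_TYPE, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_KEY and
    # OPENAI_API_VERSION to be defined in a local .env file.
    load_dotenv()

    app = App.from_config(config_path="config.yaml")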

+ 2 - 0
embedchain/config/llm/base.py

@@ -103,6 +103,7 @@ class BaseLlmConfig(BaseConfig):
         http_async_client: Optional[Any] = None,
         local: Optional[bool] = False,
         default_headers: Optional[Mapping[str, str]] = None,
+        api_version: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -185,6 +186,7 @@ class BaseLlmConfig(BaseConfig):
         self.local = local
         self.default_headers = default_headers
         self.online = online
+        self.api_version = api_version
 
         if isinstance(prompt, str):
             prompt = Template(prompt)
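
The observable effect of the new parameter, as a minimal sketch (assuming BaseLlmConfig is imported from embedchain.config, as the tests below do):

    from embedchain.config import BaseLlmConfig

    # api_version defaults to None and is stored on the instance verbatim.
    cfg = BaseLlmConfig(deployment_name="my-deployment", api_version="2024-02-01")
    assert cfg.api_version == "2024-02-01"

    legacy = BaseLlmConfig(deployment_name="my-deployment")
    assert legacy.api_version is None  # AzureOpenAILlm then falls back to "2023-05-15"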

+ 1 - 1
embedchain/llm/azure_openai.py

@@ -25,7 +25,7 @@ class AzureOpenAILlm(BaseLlm):
 
         chat = AzureChatOpenAI(
             deployment_name=config.deployment_name,
-            openai_api_version="2023-05-15",
+            openai_api_version=str(config.api_version) if config.api_version else "2023-05-15",
             model_name=config.model or "gpt-3.5-turbo",
             temperature=config.temperature,
             max_tokens=config.max_tokens,
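
The conditional keeps the old default while normalizing any configured value to a string. The same expression in isolation (resolve_api_version is a hypothetical helper, named here only for illustration):

    import datetime

    def resolve_api_version(api_version):
        # Mirrors the diff: cast to str when set, otherwise keep the old default.
        return str(api_version) if api_version else "2023-05-15"

    assert resolve_api_version(None) == "2023-05-15"
    assert resolve_api_version("2024-02-01") == "2024-02-01"
    # YAML parses an unquoted 2024-02-01 into datetime.date, hence the str() cast.
    assert resolve_api_version(datetime.date(2024, 2, 1)) == "2024-02-01"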

+ 2 - 0
embedchain/utils/misc.py

@@ -1,3 +1,4 @@
+import datetime
 import itertools
 import json
 import logging
@@ -439,6 +440,7 @@ def validate_config(config_data):
                     Optional("local"): bool,
                     Optional("base_url"): str,
                     Optional("default_headers"): dict,
+                    Optional("api_version"): Or(str, datetime.date)
                 },
             },
             Optional("vectordb"): {

+ 24 - 0
tests/llm/test_azure_openai.py

@@ -64,3 +64,27 @@ def test_when_no_deployment_name_provided():
     with pytest.raises(ValueError):
         llm = AzureOpenAILlm(config)
         llm.get_llm_model_answer("Test Prompt")
+
+def test_with_api_version():
+    config = BaseLlmConfig(
+        deployment_name="azure_deployment",
+        temperature=0.7,
+        model="gpt-3.5-turbo",
+        max_tokens=50,
+        system_prompt="System Prompt",
+        api_version="2024-02-01",
+    )
+
+    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
+        llm = AzureOpenAILlm(config)
+        llm.get_llm_model_answer("Test Prompt")
+
+        mock_chat.assert_called_once_with(
+            deployment_name="azure_deployment",
+            openai_api_version="2024-02-01",
+            model_name="gpt-3.5-turbo",
+            temperature=0.7,
+            max_tokens=50,
+            streaming=False,
+        )