Browse Source

Support supplying custom headers to OpenAI requests (#1356)

Niv Hertz 1 year ago
parent
commit
797dea1dca
4 changed files with 26 additions and 3 deletions
  1. 5 1
      embedchain/config/llm/base.py
  2. 3 2
      embedchain/llm/openai.py
  3. 1 0
      embedchain/utils/misc.py
  4. 17 0
      tests/llm/test_openai.py

+ 5 - 1
embedchain/config/llm/base.py

@@ -1,7 +1,7 @@
 import logging
 import re
 from string import Template
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -99,6 +99,7 @@ class BaseLlmConfig(BaseConfig):
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,
+        default_headers: Optional[Mapping[str, str]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -144,6 +145,8 @@ class BaseLlmConfig(BaseConfig):
         :type query_type: Optional[str], optional
         :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
         :type local: Optional[bool], optional
+        :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
+        :type default_headers: Optional[Mapping[str, str]], optional
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
@@ -173,6 +176,7 @@ class BaseLlmConfig(BaseConfig):
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local
+        self.default_headers = default_headers
 
         if isinstance(prompt, str):
             prompt = Template(prompt)

+ 3 - 2
embedchain/llm/openai.py

@@ -42,6 +42,8 @@ class OpenAILlm(BaseLlm):
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.default_headers:
+            kwargs["default_headers"] = config.default_headers
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             chat = ChatOpenAI(
@@ -65,8 +67,7 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

+ 1 - 0
embedchain/utils/misc.py

@@ -431,6 +431,7 @@ def validate_config(config_data):
                     Optional("model_kwargs"): dict,
                     Optional("local"): bool,
                     Optional("base_url"): str,
+                    Optional("default_headers"): dict,
                 },
             },
             Optional("vectordb"): {

+ 17 - 0
tests/llm/test_openai.py

@@ -76,6 +76,23 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
         base_url=os.environ["OPENAI_API_BASE"],
     )
 
+def test_get_llm_model_answer_with_special_headers(config, mocker):
+    config.default_headers = {'test': 'test'}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        default_headers={'test': 'test'}
+    )
+
 
 @pytest.mark.parametrize(
     "mock_return, expected",