Forráskód Böngészése

#1155: Add support for OpenAI-compatible endpoint in LLM and Embed (#1197)

Joe 1 éve
szülő
commit
11f4ce8fb6

+ 2 - 0
embedchain/config/embedder/base.py

@@ -11,6 +11,7 @@ class BaseEmbedderConfig:
         deployment_name: Optional[str] = None,
         vector_dimension: Optional[int] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ):
         """
         Initialize a new instance of an embedder config class.
@@ -24,3 +25,4 @@ class BaseEmbedderConfig:
         self.deployment_name = deployment_name
         self.vector_dimension = vector_dimension
         self.api_key = api_key
+        self.api_base = api_base

+ 2 - 0
embedchain/config/llm/base.py

@@ -93,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
         query_type: Optional[str] = None,
         callbacks: Optional[list] = None,
         api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,
@@ -167,6 +168,7 @@ class BaseLlmConfig(BaseConfig):
         self.query_type = query_type
         self.callbacks = callbacks
         self.api_key = api_key
+        self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local

+ 2 - 0
embedchain/embedder/openai.py

@@ -17,6 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
             self.config.model = "text-embedding-ada-002"
 
         api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
+        api_base = self.config.api_base or os.environ["OPENAI_API_BASE"]
 
         if self.config.deployment_name:
             embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
@@ -28,6 +29,7 @@ class OpenAIEmbedder(BaseEmbedder):
                 )  # noqa:E501
             embedding_fn = OpenAIEmbeddingFunction(
                 api_key=api_key,
+                api_base=api_base,
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
                 model_name=self.config.model,
             )

+ 9 - 2
embedchain/llm/openai.py

@@ -39,13 +39,20 @@ class OpenAILlm(BaseLlm):
             "model_kwargs": {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
+        base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            chat = ChatOpenAI(
+                **kwargs,
+                streaming=config.stream,
+                callbacks=callbacks,
+                api_key=api_key,
+                base_url=base_url,
+            )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
 

+ 2 - 0
embedchain/utils/misc.py

@@ -424,6 +424,7 @@ def validate_config(config_data):
                     Optional("where"): dict,
                     Optional("query_type"): str,
                     Optional("api_key"): str,
+                    Optional("base_url"): str,
                     Optional("endpoint"): str,
                     Optional("model_kwargs"): dict,
                     Optional("local"): bool,
@@ -451,6 +452,7 @@ def validate_config(config_data):
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
                     Optional("api_key"): str,
+                    Optional("api_base"): str,
                     Optional("title"): str,
                     Optional("task_type"): str,
                     Optional("vector_dimension"): int,