Selaa lähdekoodia

fix: Pass deployment name as param for azure api (#406)

Taranjeet Singh 2 vuotta sitten
vanhempi
commit
1f0f0c93b7

+ 6 - 2
embedchain/apps/CustomApp.py

@@ -118,9 +118,13 @@ class CustomApp(EmbedChain):
     def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
         from langchain.chat_models import AzureChatOpenAI
 
+        if not config.deployment_name:
+            raise ValueError("Deployment name must be provided for Azure OpenAI")
+
         chat = AzureChatOpenAI(
-            deployment_name="td2",
-            model_name=config.model or "text-davinci-002",
+            deployment_name=config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=config.model or "gpt-3.5-turbo",
             temperature=config.temperature,
             max_tokens=config.max_tokens,
             streaming=config.stream,

+ 2 - 0
embedchain/config/ChatConfig.py

@@ -33,6 +33,7 @@ class ChatConfig(QueryConfig):
         max_tokens=None,
         top_p=None,
         stream: bool = False,
+        deployment_name=None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -68,6 +69,7 @@ class ChatConfig(QueryConfig):
             top_p=top_p,
             history=[0],
             stream=stream,
+            deployment_name=deployment_name,
         )
 
     def set_history(self, history):

+ 2 - 0
embedchain/config/QueryConfig.py

@@ -62,6 +62,7 @@ class QueryConfig(BaseConfig):
         top_p=None,
         history=None,
         stream: bool = False,
+        deployment_name=None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -106,6 +107,7 @@ class QueryConfig(BaseConfig):
         self.max_tokens = max_tokens if max_tokens else 1000
         self.model = model
         self.top_p = top_p if top_p else 1
+        self.deployment_name = deployment_name
 
         if self.validate_template(template):
             self.template = template

+ 10 - 3
embedchain/config/apps/CustomAppConfig.py

@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
         provider: Providers = None,
         model=None,
         open_source_app_config=None,
+        deployment_name=None,
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
 
         super().__init__(
             log_level=log_level,
-            embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
+            embedding_fn=CustomAppConfig.embedding_function(
+                embedding_function=embedding_fn, model=embedding_fn_model,
+                deployment_name=deployment_name
+            ),
             db=db,
             host=host,
             port=port,
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
         return embed_function
 
     @staticmethod
-    def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
+    def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
         if not isinstance(embedding_function, EmbeddingFunctions):
             raise ValueError(
                 f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}"  # noqa: E501
@@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig):
             if model:
                 embeddings = OpenAIEmbeddings(model=model)
             else:
-                embeddings = OpenAIEmbeddings()
+                if deployment_name:
+                    embeddings = OpenAIEmbeddings(deployment=deployment_name)
+                else:
+                    embeddings = OpenAIEmbeddings()
             return CustomAppConfig.langchain_default_concept(embeddings)
 
         elif embedding_function == EmbeddingFunctions.HUGGING_FACE: