@@ -64,6 +64,9 @@ class CustomApp(EmbedChain):
             if self.provider == Providers.GPT4ALL:
                 return self.open_source_app._get_gpt4all_answer(prompt, config)
 
+            if self.provider == Providers.AZURE_OPENAI:
+                return CustomApp._get_azure_openai_answer(prompt, config)
+
         except ImportError as e:
             raise ImportError(e.msg) from None
 
@@ -113,6 +116,27 @@ class CustomApp(EmbedChain):
 
         return chat(messages).content
 
+    @staticmethod
+    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
+        from langchain.chat_models import AzureChatOpenAI
+
+        logging.info(vars(config))
+
+        chat = AzureChatOpenAI(
+            deployment_name="td2",
+            model_name=config.model or "text-davinci-002",
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            streaming=config.stream,
+        )
+
+        if config.top_p and config.top_p != 1:
+            logging.warning("Config option `top_p` is not supported by this model.")
+
+        messages = CustomApp._get_messages(prompt)
+
+        return chat(messages).content
+
     @staticmethod
     def _get_messages(prompt: str) -> List[BaseMessage]:
         from langchain.schema import HumanMessage, SystemMessage
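
For reference, a minimal sketch of exercising the new branch. The import paths and the credential variable names follow embedchain's and LangChain's conventions around the time of this change, and every concrete value below is a placeholder rather than something this diff defines:

import os

from embedchain import CustomApp  # assumed top-level export
from embedchain.config import ChatConfig

# Placeholder Azure credentials; LangChain's AzureChatOpenAI reads these
# from the environment.
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://<your-resource>.openai.azure.com/"
os.environ["OPENAI_API_KEY"] = "<your-azure-openai-key>"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"

# The helper added above is a staticmethod, so it can be called directly
# with a default ChatConfig for a quick smoke test.
answer = CustomApp._get_azure_openai_answer("Hello, Azure!", ChatConfig())
print(answer)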
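
Note that `deployment_name` is pinned to "td2" in the new helper. If that ever needs to vary per environment, one possible follow-up is to read it from a variable; `AZURE_DEPLOYMENT_NAME` below is purely illustrative and is not defined by this diff or by embedchain:

import os

from langchain.chat_models import AzureChatOpenAI

# Illustrative only: fall back to the hardcoded "td2" from this diff when
# the hypothetical AZURE_DEPLOYMENT_NAME variable is unset.
chat = AzureChatOpenAI(
    deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME", "td2"),
    model_name="text-davinci-002",  # same default the helper above uses
    temperature=0.7,
    max_tokens=1000,
    streaming=False,
)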