Procházet zdrojové kódy

[Feature]: Add support for azure openai model (#372)

Deshraj Yadav před 2 roky
rodič
revize
8c91b75b98

+ 2 - 0
docs/advanced/app_types.mdx

@@ -85,11 +85,13 @@ app = CustomApp(config)
     - ANTHPROPIC
     - VERTEX_AI
     - GPT4ALL
+    - AZURE_OPENAI
 - Following embedding functions are available for an embedding function
     - OPENAI
     - HUGGING_FACE
     - VERTEX_AI
     - GPT4ALL
+    - AZURE_OPENAI
 
 
 ### PersonApp

+ 24 - 0
embedchain/apps/CustomApp.py

@@ -64,6 +64,9 @@ class CustomApp(EmbedChain):
             if self.provider == Providers.GPT4ALL:
                 return self.open_source_app._get_gpt4all_answer(prompt, config)
 
+            if self.provider == Providers.AZURE_OPENAI:
+                return CustomApp._get_azure_openai_answer(prompt, config)
+
         except ImportError as e:
             raise ImportError(e.msg) from None
 
@@ -113,6 +116,27 @@ class CustomApp(EmbedChain):
 
         return chat(messages).content
 
+    @staticmethod
+    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
+        from langchain.chat_models import AzureChatOpenAI
+
+        logging.info(vars(config))
+
+        chat = AzureChatOpenAI(
+            deployment_name="td2",
+            model_name=config.model or "text-davinci-002",
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            streaming=config.stream,
+        )
+
+        if config.top_p and config.top_p != 1:
+            logging.warning("Config option `top_p` is not supported by this model.")
+
+        messages = CustomApp._get_messages(prompt)
+
+        return chat(messages).content
+
     @staticmethod
     def _get_messages(prompt: str) -> List[BaseMessage]:
         from langchain.schema import HumanMessage, SystemMessage

+ 1 - 0
embedchain/models/Providers.py

@@ -6,3 +6,4 @@ class Providers(Enum):
     ANTHROPHIC = "ANTHPROPIC"
     VERTEX_AI = "VERTEX_AI"
     GPT4ALL = "GPT4ALL"
+    AZURE_OPENAI = "AZURE_OPENAI"

+ 2 - 1
tests/vectordb/test_chroma_db.py

@@ -3,10 +3,11 @@
 import unittest
 from unittest.mock import patch
 
+from chromadb.config import Settings
+
 from embedchain import App
 from embedchain.config import AppConfig
 from embedchain.vectordb.chroma_db import ChromaDB, chromadb
-from chromadb.config import Settings
 
 
 class TestChromaDbHosts(unittest.TestCase):