Pārlūkot izejas kodu

feat: enhance Azure provider handling to support multiple instances and improve configuration validation

Burak Sormageç 3 mēneši atpakaļ
vecāks
revīzija
8cd5beefd2

+ 75 - 1
chrome-extension/src/background/agent/helper.ts

@@ -103,11 +103,19 @@ function extractInstanceNameFromUrl(url: string): string | null {
   return null;
 }
 
+// Function to check if a provider ID is an Azure provider
+function isAzureProvider(providerId: string): boolean {
+  return providerId === ProviderTypeEnum.AzureOpenAI || providerId.startsWith(`${ProviderTypeEnum.AzureOpenAI}_`);
+}
+
 // create a chat model based on the agent name, the model name and provider
 export function createChatModel(providerConfig: ProviderConfig, modelConfig: ModelConfig): BaseChatModel {
   const temperature = (modelConfig.parameters?.temperature ?? 0.1) as number;
   const topP = (modelConfig.parameters?.topP ?? 0.1) as number;
 
+  // Check if the provider is an Azure provider with a custom ID (e.g. azure_openai_2)
+  const isAzure = isAzureProvider(modelConfig.provider);
+
   switch (modelConfig.provider) {
     case ProviderTypeEnum.OpenAI: {
       // Call helper without extra options
@@ -254,7 +262,73 @@ export function createChatModel(providerConfig: ProviderConfig, modelConfig: Mod
       });
     }
     default: {
-      // Handles CustomOpenAI
+      // Check if this is a custom Azure provider (azure_openai_X)
+      if (isAzure) {
+        // Validate necessary fields first
+        if (
+          !providerConfig.baseUrl ||
+          !providerConfig.azureDeploymentNames ||
+          providerConfig.azureDeploymentNames.length === 0 ||
+          !providerConfig.azureApiVersion ||
+          !providerConfig.apiKey
+        ) {
+          throw new Error(
+            'Azure configuration is incomplete. Endpoint, Deployment Name, API Version, and API Key are required. Please check settings.',
+          );
+        }
+
+        // Instead of always using the first deployment name, use the model name from modelConfig
+        // which contains the actual model selected in the UI
+        const deploymentName = modelConfig.modelName;
+
+        // Validate that the selected model exists in the configured deployments
+        if (!providerConfig.azureDeploymentNames.includes(deploymentName)) {
+          console.warn(
+            `[createChatModel] Selected deployment "${deploymentName}" not found in available deployments. ` +
+              `Available: ${JSON.stringify(providerConfig.azureDeploymentNames)}. Using the model anyway.`,
+          );
+        }
+
+        // Extract instance name from the endpoint URL
+        const instanceName = extractInstanceNameFromUrl(providerConfig.baseUrl);
+        if (!instanceName) {
+          throw new Error(
+            `Could not extract Instance Name from Azure Endpoint URL: ${providerConfig.baseUrl}. Expected format like https://<your-instance-name>.openai.azure.com/`,
+          );
+        }
+
+        // Check if the Azure deployment is using an "o" series model (GPT-4o, etc.)
+        const isOSeriesModel = isOpenAIOModel(deploymentName);
+
+        // Use AzureChatOpenAI with specific parameters
+        const args = {
+          azureOpenAIApiInstanceName: instanceName, // Derived from endpoint
+          azureOpenAIApiDeploymentName: deploymentName,
+          azureOpenAIApiKey: providerConfig.apiKey,
+          azureOpenAIApiVersion: providerConfig.azureApiVersion,
+          // For Azure, the model name should be the deployment name itself
+          model: deploymentName, // Set model = deployment name to fix Azure requests
+          // For O series models, use modelKwargs instead of temperature/topP
+          ...(isOSeriesModel
+            ? {
+                modelKwargs: {
+                  max_completion_tokens: maxTokens,
+                  // Add reasoning_effort parameter for Azure o-series models if specified
+                  ...(modelConfig.reasoningEffort ? { reasoning_effort: modelConfig.reasoningEffort } : {}),
+                },
+              }
+            : {
+                temperature,
+                topP,
+                maxTokens,
+              }),
+          // DO NOT pass baseUrl or configuration here
+        };
+        console.log('[createChatModel] Azure args (custom ID) passed to AzureChatOpenAI:', args);
+        return new AzureChatOpenAI(args);
+      }
+
+      // If not Azure, handles CustomOpenAI
       // by default, we think it's a openai-compatible provider
       // Pass undefined for extraFetchOptions for default/custom cases
       console.log('[createChatModel] Calling createOpenAIChatModel for default/custom provider');

+ 11 - 1
packages/storage/lib/settings/llmProviders.ts

@@ -44,6 +44,17 @@ const storage = createStorage<LLMKeyRecord>(
 // Helper function to determine provider type from provider name
 // Make sure to update this function if you add a new provider type
 export function getProviderTypeByProviderId(providerId: string): ProviderTypeEnum {
+  // Check if this is an Azure provider (either the main one or one with a custom ID)
+  if (providerId === ProviderTypeEnum.AzureOpenAI) {
+    return ProviderTypeEnum.AzureOpenAI;
+  }
+
+  // Handle custom Azure providers with IDs like azure_openai_2
+  if (typeof providerId === 'string' && providerId.startsWith(`${ProviderTypeEnum.AzureOpenAI}_`)) {
+    return ProviderTypeEnum.AzureOpenAI;
+  }
+
+  // Handle standard provider types
   switch (providerId) {
     case ProviderTypeEnum.OpenAI:
     case ProviderTypeEnum.Anthropic:
@@ -51,7 +62,6 @@ export function getProviderTypeByProviderId(providerId: string): ProviderTypeEnu
     case ProviderTypeEnum.Gemini:
     case ProviderTypeEnum.Grok:
     case ProviderTypeEnum.Ollama:
-    case ProviderTypeEnum.AzureOpenAI:
     case ProviderTypeEnum.OpenRouter:
       return providerId;
     default:

+ 46 - 5
pages/options/src/components/ModelSettings.tsx

@@ -906,7 +906,7 @@ export const ModelSettings = ({ isDarkMode = false }: ModelSettingsProps) => {
 
     // Sort the filtered providers
     return filteredProviders.sort(([keyA, configA], [keyB, configB]) => {
-      // First, separate newly added providers from stored providers
+      // Separate newly added providers from stored providers
       const isNewA = !providersFromStorage.has(keyA) && modifiedProviders.has(keyA);
       const isNewB = !providersFromStorage.has(keyB) && modifiedProviders.has(keyB);
 
@@ -950,10 +950,50 @@ export const ModelSettings = ({ isDarkMode = false }: ModelSettingsProps) => {
       return;
     }
 
+    // Handle Azure OpenAI specially to allow multiple instances
+    if (providerType === ProviderTypeEnum.AzureOpenAI) {
+      addAzureProvider();
+      return;
+    }
+
     // Handle built-in supported providers
     addBuiltInProvider(providerType);
   };
 
+  // New function to add Azure providers with unique IDs
+  const addAzureProvider = () => {
+    // Count existing Azure providers
+    const azureProviders = Object.keys(providers).filter(
+      key => key === ProviderTypeEnum.AzureOpenAI || key.startsWith(ProviderTypeEnum.AzureOpenAI + '_'),
+    );
+    const nextNumber = azureProviders.length + 1;
+
+    // Create unique ID
+    const providerId =
+      nextNumber === 1 ? ProviderTypeEnum.AzureOpenAI : `${ProviderTypeEnum.AzureOpenAI}_${nextNumber}`;
+
+    // Create config with appropriate name
+    const config = getDefaultProviderConfig(ProviderTypeEnum.AzureOpenAI);
+    config.name = `Azure OpenAI ${nextNumber}`;
+
+    // Add to providers
+    setProviders(prev => ({
+      ...prev,
+      [providerId]: config,
+    }));
+
+    setModifiedProviders(prev => new Set(prev).add(providerId));
+    newlyAddedProviderRef.current = providerId;
+
+    // Scroll to the newly added provider after render
+    setTimeout(() => {
+      const providerElement = document.getElementById(`provider-${providerId}`);
+      if (providerElement) {
+        providerElement.scrollIntoView({ behavior: 'smooth', block: 'center' });
+      }
+    }, 100);
+  };
+
   const getProviderForModel = (modelName: string): string => {
     for (const [provider, config] of Object.entries(providers)) {
       // Check Azure deployment names
@@ -1480,12 +1520,13 @@ export const ModelSettings = ({ isDarkMode = false }: ModelSettingsProps) => {
                 <div className="py-1">
                   {/* Map through provider types to create buttons */}
                   {Object.values(ProviderTypeEnum)
-                    // Filter out CustomOpenAI and already added providers
+                    // Allow Azure to appear multiple times, but filter out other already added providers
                     .filter(
                       type =>
-                        type !== ProviderTypeEnum.CustomOpenAI &&
-                        !providersFromStorage.has(type) &&
-                        !modifiedProviders.has(type),
+                        type === ProviderTypeEnum.AzureOpenAI || // Always show Azure
+                        (type !== ProviderTypeEnum.CustomOpenAI &&
+                          !providersFromStorage.has(type) &&
+                          !modifiedProviders.has(type)),
                     )
                     .map(type => (
                       <button