瀏覽代碼

refactor storage of model settings

alexchenzl 5 月之前
父節點
當前提交
3e0a71a6f1

+ 19 - 6
chrome-extension/src/background/agent/helper.ts

@@ -1,13 +1,18 @@
-import { type ProviderConfig, LLMProviderEnum, AgentNameEnum } from '@extension/storage';
+import { type ProviderConfig, AgentNameEnum } from '@extension/storage';
 import { ChatOpenAI } from '@langchain/openai';
 import { ChatAnthropic } from '@langchain/anthropic';
 import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 
+// Provider constants
+const OPENAI_PROVIDER = 'openai';
+const ANTHROPIC_PROVIDER = 'anthropic';
+const GEMINI_PROVIDER = 'gemini';
+
 // create a chat model based on the agent name, the model name and provider
 export function createChatModel(
   agentName: string,
-  providerName: LLMProviderEnum,
+  providerName: string,
   providerConfig: ProviderConfig,
   modelName: string,
 ): BaseChatModel {
@@ -16,11 +21,19 @@ export function createChatModel(
   let temperature = 0;
   let topP = 0.001;
   switch (providerName) {
-    case LLMProviderEnum.OpenAI: {
+    case OPENAI_PROVIDER: {
       if (agentName === AgentNameEnum.Planner) {
         temperature = 0.02;
       }
-      const args: any = {
+      const args: {
+        model: string;
+        apiKey: string;
+        configuration: Record<string, unknown>;
+        modelKwargs?: { max_completion_tokens: number };
+        topP?: number;
+        temperature?: number;
+        maxTokens?: number;
+      } = {
         model: modelName,
         apiKey: providerConfig.apiKey,
         configuration: {},
@@ -43,7 +56,7 @@ export function createChatModel(
       }
       return new ChatOpenAI(args);
     }
-    case LLMProviderEnum.Anthropic: {
+    case ANTHROPIC_PROVIDER: {
       temperature = 0.1;
       topP = 0.1;
       const args = {
@@ -61,7 +74,7 @@ export function createChatModel(
       }
       return new ChatAnthropic(args);
     }
-    case LLMProviderEnum.Gemini: {
+    case GEMINI_PROVIDER: {
       temperature = 0.5;
       topP = 0.8;
       const args = {

+ 5 - 4
packages/storage/lib/settings/agentModels.ts

@@ -1,11 +1,11 @@
 import { StorageEnum } from '../base/enums';
 import { createStorage } from '../base/base';
 import type { BaseStorage } from '../base/types';
-import { type AgentNameEnum, type LLMProviderEnum, llmProviderModelNames } from './types';
+import { type AgentNameEnum, llmProviderModelNames } from './types';
 
 // Interface for a single model configuration
 export interface ModelConfig {
-  provider: LLMProviderEnum;
+  provider: string;
   modelName: string;
 }
 
@@ -37,8 +37,9 @@ function validateModelConfig(config: ModelConfig) {
     throw new Error('Provider and model name must be specified');
   }
 
-  const validModels = llmProviderModelNames[config.provider];
-  if (!validModels.includes(config.modelName)) {
+  // Check if the provider exists in our predefined providers
+  const validModels = llmProviderModelNames[config.provider as keyof typeof llmProviderModelNames];
+  if (!validModels || !validModels.includes(config.modelName)) {
     throw new Error(`Invalid model "${config.modelName}" for provider "${config.provider}"`);
   }
 }

+ 80 - 20
packages/storage/lib/settings/llmProviders.ts

@@ -1,65 +1,125 @@
 import { StorageEnum } from '../base/enums';
 import { createStorage } from '../base/base';
 import type { BaseStorage } from '../base/types';
-import type { LLMProviderEnum } from './types';
+import { llmProviderModelNames, ProviderTypeEnum, OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GEMINI_PROVIDER } from './types';
 
 // Interface for a single provider configuration
 export interface ProviderConfig {
-  apiKey: string;
-  baseUrl?: string;
+  name?: string; // Display name in the options
+  type?: ProviderTypeEnum; // Help to decide which LangChain ChatModel package to use
+  apiKey: string; // Must be provided, but may be empty for local models
+  baseUrl?: string; // Optional base URL if provided
+  modelNames?: string[]; // Chosen model names, if not provided use hardcoded names from llmProviderModelNames
 }
 
 // Interface for storing multiple LLM provider configurations
 export interface LLMKeyRecord {
-  providers: Record<LLMProviderEnum, ProviderConfig>;
+  providers: Record<string, ProviderConfig>;
 }
 
 export type LLMProviderStorage = BaseStorage<LLMKeyRecord> & {
-  setProvider: (provider: LLMProviderEnum, config: ProviderConfig) => Promise<void>;
-  getProvider: (provider: LLMProviderEnum) => Promise<ProviderConfig | undefined>;
-  removeProvider: (provider: LLMProviderEnum) => Promise<void>;
-  hasProvider: (provider: LLMProviderEnum) => Promise<boolean>;
-  getConfiguredProviders: () => Promise<LLMProviderEnum[]>;
-  getAllProviders: () => Promise<Record<LLMProviderEnum, ProviderConfig>>;
+  setProvider: (provider: string, config: ProviderConfig) => Promise<void>;
+  getProvider: (provider: string) => Promise<ProviderConfig | undefined>;
+  removeProvider: (provider: string) => Promise<void>;
+  hasProvider: (provider: string) => Promise<boolean>;
+  getConfiguredProviders: () => Promise<string[]>;
+  getAllProviders: () => Promise<Record<string, ProviderConfig>>;
 };
 
 const storage = createStorage<LLMKeyRecord>(
   'llm-api-keys',
-  { providers: {} as Record<LLMProviderEnum, ProviderConfig> },
+  { providers: {} },
   {
     storageEnum: StorageEnum.Local,
     liveUpdate: true,
   },
 );
 
+// Helper function to determine provider type from provider name
+function getProviderTypeFromName(provider: string): ProviderTypeEnum {
+  switch (provider) {
+    case OPENAI_PROVIDER:
+      return ProviderTypeEnum.OpenAI;
+    case ANTHROPIC_PROVIDER:
+      return ProviderTypeEnum.Anthropic;
+    case GEMINI_PROVIDER:
+      return ProviderTypeEnum.Gemini;
+    default:
+      return ProviderTypeEnum.CustomOpenAI;
+  }
+}
+
+// Helper function to get display name from provider name
+function getDisplayNameFromProvider(provider: string): string {
+  switch (provider) {
+    case OPENAI_PROVIDER:
+      return 'OpenAI';
+    case ANTHROPIC_PROVIDER:
+      return 'Anthropic';
+    case GEMINI_PROVIDER:
+      return 'Gemini';
+    default:
+      return provider; // Use the provider string as display name for custom providers
+  }
+}
+
 export const llmProviderStore: LLMProviderStorage = {
   ...storage,
-  async setProvider(provider: LLMProviderEnum, config: ProviderConfig) {
+  async setProvider(provider: string, config: ProviderConfig) {
     if (!provider) {
       throw new Error('Provider name cannot be empty');
     }
-    if (!config.apiKey) {
-      throw new Error('API key cannot be empty');
+
+    if (config.apiKey === undefined) {
+      throw new Error('API key must be provided (can be empty for local models)');
+    }
+
+    if (!config.modelNames) {
+      throw new Error('Model names must be provided');
     }
+
+    // Ensure backward compatibility by filling in missing fields
+    const completeConfig: ProviderConfig = {
+      ...config,
+      name: config.name || getDisplayNameFromProvider(provider),
+      type: config.type || getProviderTypeFromName(provider),
+      modelNames: config.modelNames,
+    };
+
     const current = (await storage.get()) || { providers: {} };
     await storage.set({
       providers: {
         ...current.providers,
-        [provider]: config,
+        [provider]: completeConfig,
       },
     });
   },
-  async getProvider(provider: LLMProviderEnum) {
+  async getProvider(provider: string) {
     const data = (await storage.get()) || { providers: {} };
-    return data.providers[provider];
+    const config = data.providers[provider];
+
+    // If we have a config but it's missing some fields, fill them in
+    if (config) {
+      if (!config.name) {
+        config.name = getDisplayNameFromProvider(provider);
+      }
+      if (!config.type) {
+        config.type = getProviderTypeFromName(provider);
+      }
+      if (!config.modelNames) {
+        config.modelNames = llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
+      }
+    }
+
+    return config;
   },
-  async removeProvider(provider: LLMProviderEnum) {
+  async removeProvider(provider: string) {
     const current = (await storage.get()) || { providers: {} };
     const newProviders = { ...current.providers };
     delete newProviders[provider];
     await storage.set({ providers: newProviders });
   },
-  async hasProvider(provider: LLMProviderEnum) {
+  async hasProvider(provider: string) {
     const data = (await storage.get()) || { providers: {} };
     return provider in data.providers;
   },
@@ -74,7 +134,7 @@ export const llmProviderStore: LLMProviderStorage = {
     }
 
     console.log('Configured providers:', data.providers);
-    return Object.keys(data.providers) as LLMProviderEnum[];
+    return Object.keys(data.providers);
   },
   async getAllProviders() {
     const data = await storage.get();

+ 12 - 15
packages/storage/lib/settings/types.ts

@@ -4,30 +4,27 @@ export enum AgentNameEnum {
   Validator = 'validator',
 }
 
-// Enum for supported LLM providers
-export enum LLMProviderEnum {
+// String literal constants for supported LLM providers
+export const OPENAI_PROVIDER = 'openai';
+export const ANTHROPIC_PROVIDER = 'anthropic';
+export const GEMINI_PROVIDER = 'gemini';
+
+// Provider type for determining which LangChain ChatModel package to use
+export enum ProviderTypeEnum {
   OpenAI = 'openai',
   Anthropic = 'anthropic',
   Gemini = 'gemini',
+  CustomOpenAI = 'custom_openai',
 }
 
 export const llmProviderModelNames = {
-  [LLMProviderEnum.OpenAI]: ['gpt-4o', 'gpt-4o-mini', 'o1', 'o1-mini', 'o3-mini', 'deepseek-r1'],
-  [LLMProviderEnum.Anthropic]: ['claude-3-7-sonnet-latest', 'claude-3-5-haiku-latest'],
-  [LLMProviderEnum.Gemini]: [
+  [OPENAI_PROVIDER]: ['gpt-4o', 'gpt-4o-mini', 'o1', 'o1-mini', 'o3-mini', 'deepseek-r1'],
+  [ANTHROPIC_PROVIDER]: ['claude-3-7-sonnet-latest', 'claude-3-5-sonnet-latest', 'claude-3-5-haiku-latest'],
+  [GEMINI_PROVIDER]: [
     'gemini-2.0-flash',
     'gemini-2.0-flash-lite',
     'gemini-2.0-pro-exp-02-05',
     // 'gemini-2.0-flash-thinking-exp-01-21', // TODO: not support function calling for now
   ],
+  // Custom OpenAI providers don't have predefined models as they are user-defined
 };
-
-/**
- * Creates a mapping of LLM model names to their corresponding providers.
- *
- * This function takes the llmProviderModelNames object and converts it into a new object
- * where each model name is mapped to its corresponding provider.
- */
-export const llmModelNamesToProvider = Object.fromEntries(
-  Object.entries(llmProviderModelNames).flatMap(([provider, models]) => models.map(model => [model, provider])),
-);

+ 68 - 58
pages/options/src/components/ModelSettings.tsx

@@ -1,18 +1,17 @@
 import { useEffect, useState } from 'react';
 import { Button } from '@extension/ui';
-import {
-  llmProviderStore,
-  agentModelStore,
-  AgentNameEnum,
-  LLMProviderEnum,
-  llmProviderModelNames,
-} from '@extension/storage';
+import { llmProviderStore, agentModelStore, AgentNameEnum, llmProviderModelNames } from '@extension/storage';
+
+// Provider constants
+const OPENAI_PROVIDER = 'openai';
+const ANTHROPIC_PROVIDER = 'anthropic';
+const GEMINI_PROVIDER = 'gemini';
 
 export const ModelSettings = () => {
-  const [apiKeys, setApiKeys] = useState<Record<LLMProviderEnum, { apiKey: string; baseUrl?: string }>>(
-    {} as Record<LLMProviderEnum, { apiKey: string; baseUrl?: string }>,
+  const [apiKeys, setApiKeys] = useState<Record<string, { apiKey: string; baseUrl?: string }>>(
+    {} as Record<string, { apiKey: string; baseUrl?: string }>,
   );
-  const [modifiedProviders, setModifiedProviders] = useState<Set<LLMProviderEnum>>(new Set());
+  const [modifiedProviders, setModifiedProviders] = useState<Set<string>>(new Set());
   const [selectedModels, setSelectedModels] = useState<Record<AgentNameEnum, string>>({
     [AgentNameEnum.Navigator]: '',
     [AgentNameEnum.Planner]: '',
@@ -24,21 +23,25 @@ export const ModelSettings = () => {
       try {
         const providers = await llmProviderStore.getConfiguredProviders();
 
-        const keys: Record<LLMProviderEnum, { apiKey: string; baseUrl?: string }> = {} as Record<
-          LLMProviderEnum,
+        const keys: Record<string, { apiKey: string; baseUrl?: string }> = {} as Record<
+          string,
           { apiKey: string; baseUrl?: string }
         >;
 
         for (const provider of providers) {
           const config = await llmProviderStore.getProvider(provider);
+          console.log('config', config);
           if (config) {
-            keys[provider] = config;
+            keys[provider] = {
+              apiKey: config.apiKey,
+              baseUrl: config.baseUrl,
+            };
           }
         }
         setApiKeys(keys);
       } catch (error) {
         console.error('Error loading API keys:', error);
-        setApiKeys({} as Record<LLMProviderEnum, { apiKey: string; baseUrl?: string }>);
+        setApiKeys({} as Record<string, { apiKey: string; baseUrl?: string }>);
       }
     };
 
@@ -70,7 +73,7 @@ export const ModelSettings = () => {
     loadAgentModels();
   }, []);
 
-  const handleApiKeyChange = (provider: LLMProviderEnum, apiKey: string, baseUrl?: string) => {
+  const handleApiKeyChange = (provider: string, apiKey: string, baseUrl?: string) => {
     setModifiedProviders(prev => new Set(prev).add(provider));
     setApiKeys(prev => ({
       ...prev,
@@ -81,9 +84,14 @@ export const ModelSettings = () => {
     }));
   };
 
-  const handleSave = async (provider: LLMProviderEnum) => {
+  const handleSave = async (provider: string) => {
     try {
-      await llmProviderStore.setProvider(provider, apiKeys[provider]);
+      // The provider store will handle filling in the missing fields
+      await llmProviderStore.setProvider(provider, {
+        apiKey: apiKeys[provider].apiKey,
+        baseUrl: apiKeys[provider].baseUrl,
+      });
+
       setModifiedProviders(prev => {
         const next = new Set(prev);
         next.delete(provider);
@@ -94,7 +102,7 @@ export const ModelSettings = () => {
     }
   };
 
-  const handleDelete = async (provider: LLMProviderEnum) => {
+  const handleDelete = async (provider: string) => {
     try {
       await llmProviderStore.removeProvider(provider);
       setApiKeys(prev => {
@@ -107,7 +115,7 @@ export const ModelSettings = () => {
     }
   };
 
-  const getButtonProps = (provider: LLMProviderEnum) => {
+  const getButtonProps = (provider: string) => {
     const hasStoredKey = Boolean(apiKeys[provider]?.apiKey);
     const isModified = modifiedProviders.has(provider);
     const hasInput = Boolean(apiKeys[provider]?.apiKey?.trim());
@@ -129,11 +137,14 @@ export const ModelSettings = () => {
 
   const getAvailableModels = () => {
     const models: string[] = [];
-    Object.entries(apiKeys).forEach(([provider, config]) => {
+
+    for (const [provider, config] of Object.entries(apiKeys)) {
       if (config.apiKey) {
-        models.push(...(llmProviderModelNames[provider as LLMProviderEnum] || []));
+        const providerModels = llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
+        models.push(...providerModels);
       }
-    });
+    }
+
     return models.length ? models : [''];
   };
 
@@ -146,10 +157,10 @@ export const ModelSettings = () => {
     try {
       if (model) {
         // Determine provider from model name
-        let provider: LLMProviderEnum | undefined;
+        let provider: string | undefined;
         for (const [providerKey, models] of Object.entries(llmProviderModelNames)) {
           if (models.includes(model)) {
-            provider = providerKey as LLMProviderEnum;
+            provider = providerKey;
             break;
           }
         }
@@ -219,61 +230,59 @@ export const ModelSettings = () => {
             <div className="flex items-center justify-between">
               <h3 className="text-lg font-medium text-gray-700">OpenAI</h3>
               <Button
-                {...getButtonProps(LLMProviderEnum.OpenAI)}
-                size="sm"
+                variant={getButtonProps(OPENAI_PROVIDER).variant}
+                disabled={getButtonProps(OPENAI_PROVIDER).disabled}
                 onClick={() =>
-                  apiKeys[LLMProviderEnum.OpenAI]?.apiKey && !modifiedProviders.has(LLMProviderEnum.OpenAI)
-                    ? handleDelete(LLMProviderEnum.OpenAI)
-                    : handleSave(LLMProviderEnum.OpenAI)
-                }
-              />
+                  apiKeys[OPENAI_PROVIDER]?.apiKey && !modifiedProviders.has(OPENAI_PROVIDER)
+                    ? handleDelete(OPENAI_PROVIDER)
+                    : handleSave(OPENAI_PROVIDER)
+                }>
+                {getButtonProps(OPENAI_PROVIDER).children}
+              </Button>
             </div>
             <div className="space-y-3">
               <input
                 type="password"
                 placeholder="OpenAI API key"
-                value={apiKeys[LLMProviderEnum.OpenAI]?.apiKey || ''}
-                onChange={e => handleApiKeyChange(LLMProviderEnum.OpenAI, e.target.value)}
+                value={apiKeys[OPENAI_PROVIDER]?.apiKey || ''}
+                onChange={e => handleApiKeyChange(OPENAI_PROVIDER, e.target.value)}
                 className="w-full p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
               />
               <input
                 type="text"
                 placeholder="Custom Base URL (Optional)"
-                value={apiKeys[LLMProviderEnum.OpenAI]?.baseUrl || ''}
+                value={apiKeys[OPENAI_PROVIDER]?.baseUrl || ''}
                 onChange={e =>
-                  handleApiKeyChange(
-                    LLMProviderEnum.OpenAI,
-                    apiKeys[LLMProviderEnum.OpenAI]?.apiKey || '',
-                    e.target.value,
-                  )
+                  handleApiKeyChange(OPENAI_PROVIDER, apiKeys[OPENAI_PROVIDER]?.apiKey || '', e.target.value)
                 }
                 className="w-full p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
               />
             </div>
           </div>
 
-          <div className="border-t border-gray-200"></div>
+          <div className="border-t border-gray-200" />
 
           {/* Anthropic Section */}
           <div className="space-y-4">
             <div className="flex items-center justify-between">
               <h3 className="text-lg font-medium text-gray-700">Anthropic</h3>
               <Button
-                {...getButtonProps(LLMProviderEnum.Anthropic)}
-                size="sm"
+                variant={getButtonProps(ANTHROPIC_PROVIDER).variant}
+                disabled={getButtonProps(ANTHROPIC_PROVIDER).disabled}
                 onClick={() =>
-                  apiKeys[LLMProviderEnum.Anthropic]?.apiKey && !modifiedProviders.has(LLMProviderEnum.Anthropic)
-                    ? handleDelete(LLMProviderEnum.Anthropic)
-                    : handleSave(LLMProviderEnum.Anthropic)
-                }
-              />
+                  apiKeys[ANTHROPIC_PROVIDER]?.apiKey && !modifiedProviders.has(ANTHROPIC_PROVIDER)
+                    ? handleDelete(ANTHROPIC_PROVIDER)
+                    : handleSave(ANTHROPIC_PROVIDER)
+                }>
+                {getButtonProps(ANTHROPIC_PROVIDER).children}
+              </Button>
             </div>
             <div className="space-y-3">
               <input
                 type="password"
                 placeholder="Anthropic API key"
-                value={apiKeys[LLMProviderEnum.Anthropic]?.apiKey || ''}
-                onChange={e => handleApiKeyChange(LLMProviderEnum.Anthropic, e.target.value)}
+                value={apiKeys[ANTHROPIC_PROVIDER]?.apiKey || ''}
+                onChange={e => handleApiKeyChange(ANTHROPIC_PROVIDER, e.target.value)}
                 className="w-full p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
               />
             </div>
@@ -286,21 +295,22 @@ export const ModelSettings = () => {
             <div className="flex items-center justify-between">
               <h3 className="text-lg font-medium text-gray-700">Gemini</h3>
               <Button
-                {...getButtonProps(LLMProviderEnum.Gemini)}
-                size="sm"
+                variant={getButtonProps(GEMINI_PROVIDER).variant}
+                disabled={getButtonProps(GEMINI_PROVIDER).disabled}
                 onClick={() =>
-                  apiKeys[LLMProviderEnum.Gemini]?.apiKey && !modifiedProviders.has(LLMProviderEnum.Gemini)
-                    ? handleDelete(LLMProviderEnum.Gemini)
-                    : handleSave(LLMProviderEnum.Gemini)
-                }
-              />
+                  apiKeys[GEMINI_PROVIDER]?.apiKey && !modifiedProviders.has(GEMINI_PROVIDER)
+                    ? handleDelete(GEMINI_PROVIDER)
+                    : handleSave(GEMINI_PROVIDER)
+                }>
+                {getButtonProps(GEMINI_PROVIDER).children}
+              </Button>
             </div>
             <div className="space-y-3">
               <input
                 type="password"
                 placeholder="Gemini API key"
-                value={apiKeys[LLMProviderEnum.Gemini]?.apiKey || ''}
-                onChange={e => handleApiKeyChange(LLMProviderEnum.Gemini, e.target.value)}
+                value={apiKeys[GEMINI_PROVIDER]?.apiKey || ''}
+                onChange={e => handleApiKeyChange(GEMINI_PROVIDER, e.target.value)}
                 className="w-full p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
               />
             </div>