refactor to allow configuring temperature and topP for every chosen model manually

alexchenzl, 5 months ago
commit f8b599fd96

chrome-extension/src/background/agent/helper.ts (+27 -78)

@@ -1,56 +1,34 @@
-import { type ProviderConfig, AgentNameEnum, OLLAMA_PROVIDER } from '@extension/storage';
+import { type ProviderConfig, type ModelConfig, ProviderTypeEnum } from '@extension/storage';
 import { ChatOpenAI } from '@langchain/openai';
 import { ChatAnthropic } from '@langchain/anthropic';
 import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatOllama } from '@langchain/ollama';
 
-// Provider constants
-const OPENAI_PROVIDER = 'openai';
-const ANTHROPIC_PROVIDER = 'anthropic';
-const GEMINI_PROVIDER = 'gemini';
-
 // create a chat model based on the agent name, the model name and provider
-export function createChatModel(
-  agentName: string,
-  providerName: string,
-  providerConfig: ProviderConfig,
-  modelName: string,
-): BaseChatModel {
-  const maxTokens = 2000;
-  const maxCompletionTokens = 5000;
-  let temperature = 0;
-  let topP = 0.001;
+export function createChatModel(providerConfig: ProviderConfig, modelConfig: ModelConfig): BaseChatModel {
+  const maxTokens = 1024 * 4;
+  const maxCompletionTokens = 1024 * 4;
+  const temperature = (modelConfig.parameters?.temperature ?? 0.1) as number;
+  const topP = (modelConfig.parameters?.topP ?? 0.1) as number;
 
-  console.log('providerName', providerName);
-  console.log('providerConfig', providerConfig);
+  console.log('modelConfig', modelConfig);
 
-  switch (providerName) {
-    case OPENAI_PROVIDER: {
-      if (agentName === AgentNameEnum.Planner) {
-        temperature = 0.02;
-      }
+  switch (providerConfig.type) {
+    case ProviderTypeEnum.OpenAI: {
       const args: {
         model: string;
         apiKey: string;
-        configuration: Record<string, unknown>;
         modelKwargs?: { max_completion_tokens: number };
         topP?: number;
         temperature?: number;
         maxTokens?: number;
       } = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
-        configuration: {},
       };
-      if (providerConfig.baseUrl) {
-        args.configuration = {
-          baseURL: providerConfig.baseUrl,
-        };
-      }
-
       // O series models have different parameters
-      if (modelName.startsWith('o')) {
+      if (modelConfig.modelName.startsWith('o')) {
         args.modelKwargs = {
           max_completion_tokens: maxCompletionTokens,
         };
@@ -61,11 +39,9 @@ export function createChatModel(
       }
       return new ChatOpenAI(args);
     }
-    case ANTHROPIC_PROVIDER: {
-      temperature = 0.1;
-      topP = 0.1;
+    case ProviderTypeEnum.Anthropic: {
       const args = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
         maxTokens,
         temperature,
@@ -79,21 +55,16 @@ export function createChatModel(
       }
       return new ChatAnthropic(args);
     }
-    case GEMINI_PROVIDER: {
-      temperature = 0.5;
-      topP = 0.8;
+    case ProviderTypeEnum.Gemini: {
       const args = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
         temperature,
         topP,
       };
       return new ChatGoogleGenerativeAI(args);
     }
-    case OLLAMA_PROVIDER: {
-      if (agentName === AgentNameEnum.Planner) {
-        temperature = 0.02;
-      }
+    case ProviderTypeEnum.Ollama: {
       const args: {
         model: string;
         apiKey?: string;
@@ -106,58 +77,36 @@ export function createChatModel(
           num_ctx: number;
         };
       } = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
         baseUrl: providerConfig.baseUrl ?? 'http://localhost:11434',
+        topP,
+        temperature,
+        maxTokens,
         options: {
           num_ctx: 128000,
         },
       };
-
-      // O series models have different parameters
-      if (modelName.startsWith('o')) {
-        args.modelKwargs = {
-          max_completion_tokens: maxCompletionTokens,
-        };
-      } else {
-        args.topP = topP;
-        args.temperature = temperature;
-        args.maxTokens = maxTokens;
-      }
       return new ChatOllama(args);
     }
     default: {
-      if (agentName === AgentNameEnum.Planner) {
-        temperature = 0.02;
-      }
       const args: {
         model: string;
         apiKey: string;
         configuration: Record<string, unknown>;
-        modelKwargs?: { max_completion_tokens: number };
         topP?: number;
         temperature?: number;
         maxTokens?: number;
       } = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
-        configuration: {},
-      };
-
-      args.configuration = {
-        baseURL: providerConfig.baseUrl,
+        configuration: {
+          baseURL: providerConfig.baseUrl,
+        },
+        topP,
+        temperature,
+        maxTokens,
       };
-
-      // O series models have different parameters
-      if (modelName.startsWith('o')) {
-        args.modelKwargs = {
-          max_completion_tokens: maxCompletionTokens,
-        };
-      } else {
-        args.topP = topP;
-        args.temperature = temperature;
-        args.maxTokens = maxTokens;
-      }
       return new ChatOpenAI(args);
     }
   }
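
To illustrate the new two-argument signature, here is a minimal sketch of a call site with hypothetical config values (the extension loads the real ones from storage), assuming `ProviderTypeEnum`, `ProviderConfig`, and `ModelConfig` are all exported from `@extension/storage` as the imports in this diff suggest:

```ts
import { createChatModel } from './agent/helper';
import { ProviderTypeEnum, type ProviderConfig, type ModelConfig } from '@extension/storage';

// Hypothetical configs for illustration only.
const providerConfig = {
  type: ProviderTypeEnum.Anthropic,
  apiKey: 'sk-ant-...',
} as ProviderConfig;

const modelConfig: ModelConfig = {
  provider: 'anthropic',
  modelName: 'claude-3-5-sonnet-latest',
  // Sampling parameters now travel with the chosen model rather than the agent name.
  parameters: { temperature: 0.1, topP: 0.1 },
};

// temperature and topP are read from modelConfig.parameters, each falling back to 0.1.
const llm = createChatModel(providerConfig, modelConfig);
```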

chrome-extension/src/background/index.ts (+3 -18)

@@ -181,33 +181,18 @@ async function setupExecutor(taskId: string, task: string, browserContext: Brows
   if (!navigatorModel) {
     throw new Error('Please choose a model for the navigator in the settings first');
   }
-  const navigatorLLM = createChatModel(
-    AgentNameEnum.Navigator,
-    navigatorModel.provider,
-    providers[navigatorModel.provider],
-    navigatorModel.modelName,
-  );
+  const navigatorLLM = createChatModel(providers[navigatorModel.provider], navigatorModel);
 
   let plannerLLM = null;
   const plannerModel = agentModels[AgentNameEnum.Planner];
   if (plannerModel) {
-    plannerLLM = createChatModel(
-      AgentNameEnum.Planner,
-      plannerModel.provider,
-      providers[plannerModel.provider],
-      plannerModel.modelName,
-    );
+    plannerLLM = createChatModel(providers[plannerModel.provider], plannerModel);
   }
 
   let validatorLLM = null;
   const validatorModel = agentModels[AgentNameEnum.Validator];
   if (validatorModel) {
-    validatorLLM = createChatModel(
-      AgentNameEnum.Validator,
-      validatorModel.provider,
-      providers[validatorModel.provider],
-      validatorModel.modelName,
-    );
+    validatorLLM = createChatModel(providers[validatorModel.provider], validatorModel);
   }
 
   const generalSettings = await generalSettingsStore.getSettings();
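
Since each stored agent model is now a complete `ModelConfig`, the three call sites share one shape. A hypothetical helper (not part of this commit) that captures the pattern:

```ts
import { createChatModel } from './agent/helper';
import type { ProviderConfig, ModelConfig } from '@extension/storage';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';

// Sketch only: each stored agent model carries its own provider key and
// parameters, so it can be passed straight through with its provider config.
function createAgentLLM(providers: Record<string, ProviderConfig>, model?: ModelConfig): BaseChatModel | null {
  return model ? createChatModel(providers[model.provider], model) : null;
}
```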

packages/storage/lib/settings/agentModels.ts (+30 -2)

@@ -2,11 +2,14 @@ import { StorageEnum } from '../base/enums';
 import { createStorage } from '../base/base';
 import type { BaseStorage } from '../base/types';
 import type { AgentNameEnum } from './types';
+import { llmProviderParameters } from './types';
 
 // Interface for a single model configuration
 export interface ModelConfig {
+  // providerId, the key of the provider in the llmProviderStore, not the provider name
   provider: string;
   modelName: string;
+  parameters?: Record<string, unknown>;
 }
 
 // Interface for storing multiple agent model configurations
@@ -38,20 +41,45 @@ function validateModelConfig(config: ModelConfig) {
   }
 }
 
+function getModelParameters(agent: AgentNameEnum, provider: string): Record<string, unknown> {
+  const providerParams = llmProviderParameters[provider as keyof typeof llmProviderParameters]?.[agent];
+  return providerParams ?? { temperature: 0.1, topP: 0.1 };
+}
+
 export const agentModelStore: AgentModelStorage = {
   ...storage,
   setAgentModel: async (agent: AgentNameEnum, config: ModelConfig) => {
     validateModelConfig(config);
+    // Merge default parameters with provided parameters
+    const defaultParams = getModelParameters(agent, config.provider);
+    const mergedConfig = {
+      ...config,
+      parameters: {
+        ...defaultParams,
+        ...config.parameters,
+      },
+    };
     await storage.set(current => ({
       agents: {
         ...current.agents,
-        [agent]: config,
+        [agent]: mergedConfig,
       },
     }));
   },
   getAgentModel: async (agent: AgentNameEnum) => {
     const data = await storage.get();
-    return data.agents[agent];
+    const config = data.agents[agent];
+    if (!config) return undefined;
+
+    // Merge default parameters with stored parameters
+    const defaultParams = getModelParameters(agent, config.provider);
+    return {
+      ...config,
+      parameters: {
+        ...defaultParams,
+        ...config.parameters,
+      },
+    };
   },
   resetAgentModel: async (agent: AgentNameEnum) => {
     await storage.set(current => {
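
A quick sketch of the merge semantics with illustrative values, assuming `agentModelStore` and `AgentNameEnum` are re-exported from `@extension/storage`: an explicitly stored key wins, while unset keys fall back to the provider/agent defaults.

```ts
import { agentModelStore, AgentNameEnum } from '@extension/storage';

// Override only temperature; topP is filled in from the Anthropic planner defaults.
await agentModelStore.setAgentModel(AgentNameEnum.Planner, {
  provider: 'anthropic',
  modelName: 'claude-3-5-sonnet-latest',
  parameters: { temperature: 0.3 },
});

// Anthropic planner defaults are { temperature: 0.1, topP: 0.1 }, so the merged
// result is { temperature: 0.3, topP: 0.1 }: defaults are spread first, then
// the stored parameters on top.
const config = await agentModelStore.getAgentModel(AgentNameEnum.Planner);
console.log(config?.parameters);
```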

packages/storage/lib/settings/types.ts (+61 -0)

@@ -19,6 +19,7 @@ export enum ProviderTypeEnum {
   CustomOpenAI = 'custom_openai',
 }
 
+// Default model names for each built-in provider
 export const llmProviderModelNames = {
   [OPENAI_PROVIDER]: ['gpt-4o', 'gpt-4o-mini', 'o1', 'o1-mini', 'o3-mini'],
   [ANTHROPIC_PROVIDER]: ['claude-3-7-sonnet-latest', 'claude-3-5-sonnet-latest', 'claude-3-5-haiku-latest'],
@@ -31,3 +32,63 @@ export const llmProviderModelNames = {
   [OLLAMA_PROVIDER]: [],
   // Custom OpenAI providers don't have predefined models as they are user-defined
 };
+
+// Default parameters for each agent per provider
+export const llmProviderParameters = {
+  [OPENAI_PROVIDER]: {
+    [AgentNameEnum.Planner]: {
+      temperature: 0.01,
+      topP: 0.001,
+    },
+    [AgentNameEnum.Navigator]: {
+      temperature: 0.01,
+      topP: 0.001,
+    },
+    [AgentNameEnum.Validator]: {
+      temperature: 0.01,
+      topP: 0.001,
+    },
+  },
+  [ANTHROPIC_PROVIDER]: {
+    [AgentNameEnum.Planner]: {
+      temperature: 0.1,
+      topP: 0.1,
+    },
+    [AgentNameEnum.Navigator]: {
+      temperature: 0.1,
+      topP: 0.1,
+    },
+    [AgentNameEnum.Validator]: {
+      temperature: 0.05,
+      topP: 0.1,
+    },
+  },
+  [GEMINI_PROVIDER]: {
+    [AgentNameEnum.Planner]: {
+      temperature: 0.5,
+      topP: 0.8,
+    },
+    [AgentNameEnum.Navigator]: {
+      temperature: 0.5,
+      topP: 0.8,
+    },
+    [AgentNameEnum.Validator]: {
+      temperature: 0.1,
+      topP: 0.8,
+    },
+  },
+  [OLLAMA_PROVIDER]: {
+    [AgentNameEnum.Planner]: {
+      temperature: 0.02,
+      topP: 0.001,
+    },
+    [AgentNameEnum.Navigator]: {
+      temperature: 0.01,
+      topP: 0.001,
+    },
+    [AgentNameEnum.Validator]: {
+      temperature: 0.01,
+      topP: 0.001,
+    },
+  },
+};
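
Finally, a minimal sketch of reading the new defaults table directly, again assuming `llmProviderParameters` and `AgentNameEnum` are re-exported from `@extension/storage`; the `'gemini'` key matches the `GEMINI_PROVIDER` constant that helper.ts previously hard-coded:

```ts
import { llmProviderParameters, AgentNameEnum } from '@extension/storage';

// Built-in provider: Gemini's Validator defaults are { temperature: 0.1, topP: 0.8 }.
const geminiValidator = llmProviderParameters.gemini[AgentNameEnum.Validator];

// Unknown providers (e.g. user-defined custom OpenAI endpoints) have no entry,
// which is why getModelParameters falls back to { temperature: 0.1, topP: 0.1 }.
const provider: string = 'my_custom_provider';
const params = llmProviderParameters[provider as keyof typeof llmProviderParameters]?.[AgentNameEnum.Planner];
```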