@@ -1,56 +1,34 @@
-import { type ProviderConfig, AgentNameEnum, OLLAMA_PROVIDER } from '@extension/storage';
+import { type ProviderConfig, type ModelConfig, ProviderTypeEnum } from '@extension/storage';
 import { ChatOpenAI } from '@langchain/openai';
 import { ChatAnthropic } from '@langchain/anthropic';
 import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatOllama } from '@langchain/ollama';
 
-// Provider constants
-const OPENAI_PROVIDER = 'openai';
-const ANTHROPIC_PROVIDER = 'anthropic';
-const GEMINI_PROVIDER = 'gemini';
-
 // create a chat model based on the agent name, the model name and provider
-export function createChatModel(
-  agentName: string,
-  providerName: string,
-  providerConfig: ProviderConfig,
-  modelName: string,
-): BaseChatModel {
-  const maxTokens = 2000;
-  const maxCompletionTokens = 5000;
-  let temperature = 0;
-  let topP = 0.001;
+export function createChatModel(providerConfig: ProviderConfig, modelConfig: ModelConfig): BaseChatModel {
+  const maxTokens = 1024 * 4;
+  const maxCompletionTokens = 1024 * 4;
+  const temperature = (modelConfig.parameters?.temperature ?? 0.1) as number;
+  const topP = (modelConfig.parameters?.topP ?? 0.1) as number;
 
-  console.log('providerName', providerName);
-  console.log('providerConfig', providerConfig);
+  console.log('modelConfig', modelConfig);
 
-  switch (providerName) {
-    case OPENAI_PROVIDER: {
-      if (agentName === AgentNameEnum.Planner) {
-        temperature = 0.02;
-      }
+  switch (providerConfig.type) {
+    case ProviderTypeEnum.OpenAI: {
       const args: {
         model: string;
         apiKey: string;
-        configuration: Record<string, unknown>;
         modelKwargs?: { max_completion_tokens: number };
         topP?: number;
         temperature?: number;
         maxTokens?: number;
       } = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
-        configuration: {},
       };
-      if (providerConfig.baseUrl) {
-        args.configuration = {
-          baseURL: providerConfig.baseUrl,
-        };
-      }
-
       // O series models have different parameters
-      if (modelName.startsWith('o')) {
+      if (modelConfig.modelName.startsWith('o')) {
         args.modelKwargs = {
           max_completion_tokens: maxCompletionTokens,
         };
@@ -61,11 +39,9 @@ export function createChatModel(
       }
       return new ChatOpenAI(args);
     }
-    case ANTHROPIC_PROVIDER: {
-      temperature = 0.1;
-      topP = 0.1;
+    case ProviderTypeEnum.Anthropic: {
       const args = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
         maxTokens,
         temperature,
@@ -79,21 +55,16 @@ export function createChatModel(
       }
       return new ChatAnthropic(args);
     }
-    case GEMINI_PROVIDER: {
-      temperature = 0.5;
-      topP = 0.8;
+    case ProviderTypeEnum.Gemini: {
       const args = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
         temperature,
         topP,
       };
       return new ChatGoogleGenerativeAI(args);
     }
-    case OLLAMA_PROVIDER: {
-      if (agentName === AgentNameEnum.Planner) {
-        temperature = 0.02;
-      }
+    case ProviderTypeEnum.Ollama: {
       const args: {
         model: string;
         apiKey?: string;
@@ -106,58 +77,36 @@
           num_ctx: number;
         };
       } = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
         baseUrl: providerConfig.baseUrl ?? 'http://localhost:11434',
+        topP,
+        temperature,
+        maxTokens,
         options: {
           num_ctx: 128000,
         },
       };
-
-      // O series models have different parameters
-      if (modelName.startsWith('o')) {
-        args.modelKwargs = {
-          max_completion_tokens: maxCompletionTokens,
-        };
-      } else {
-        args.topP = topP;
-        args.temperature = temperature;
-        args.maxTokens = maxTokens;
-      }
       return new ChatOllama(args);
     }
     default: {
-      if (agentName === AgentNameEnum.Planner) {
-        temperature = 0.02;
-      }
       const args: {
         model: string;
         apiKey: string;
         configuration: Record<string, unknown>;
-        modelKwargs?: { max_completion_tokens: number };
         topP?: number;
         temperature?: number;
         maxTokens?: number;
       } = {
-        model: modelName,
+        model: modelConfig.modelName,
         apiKey: providerConfig.apiKey,
-        configuration: {},
-      };
-
-      args.configuration = {
-        baseURL: providerConfig.baseUrl,
+        configuration: {
+          baseURL: providerConfig.baseUrl,
+        },
+        topP,
+        temperature,
+        maxTokens,
       };
-
-      // O series models have different parameters
-      if (modelName.startsWith('o')) {
-        args.modelKwargs = {
-          max_completion_tokens: maxCompletionTokens,
-        };
-      } else {
-        args.topP = topP;
-        args.temperature = temperature;
-        args.maxTokens = maxTokens;
-      }
       return new ChatOpenAI(args);
     }
   }
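
For reference, a minimal usage sketch of the refactored two-argument signature. The config literals and the './helper' import path are illustrative assumptions, not values taken from this change, and they only show the fields createChatModel actually reads (type, apiKey, baseUrl, modelName, parameters); the full storage types may require additional fields.

// Hypothetical example of calling the new createChatModel(providerConfig, modelConfig).
import { ProviderTypeEnum, type ProviderConfig, type ModelConfig } from '@extension/storage';
import { createChatModel } from './helper'; // assumed module path

const providerConfig: ProviderConfig = {
  type: ProviderTypeEnum.OpenAI, // the provider is now carried by the config, not a separate string argument
  apiKey: 'sk-...',              // placeholder key
};

const modelConfig: ModelConfig = {
  modelName: 'gpt-4o',                         // assumed model name
  parameters: { temperature: 0.1, topP: 0.1 }, // optional; the helper falls back to 0.1/0.1 when omitted
};

const chatModel = createChatModel(providerConfig, modelConfig);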