123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- import type { z } from 'zod';
- import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
- import type { AgentContext, AgentOutput } from '../types';
- import type { BasePrompt } from '../prompts/base';
- import type { BaseMessage } from '@langchain/core/messages';
- import { createLogger } from '@src/background/log';
- import type { Action } from '../actions/builder';
- import { convertInputMessages, extractJsonFromModelOutput, removeThinkTags } from '../messages/utils';
// Logger shared by everything in this module; created under the name 'agent'.
const logger = createLogger('agent');
/**
 * Free-form bag of per-call options. The agent spreads it into the options
 * argument of every chat-model `invoke` call without inspecting the keys, so
 * the value type is intentionally left open.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type CallOptions = Record<string, any>;
/**
 * Required collaborators handed to every agent at construction time.
 */
export interface BaseAgentOptions {
  /** Prompt object stored on the agent. */
  prompt: BasePrompt;
  /** Chat model the agent sends its messages to. */
  chatLLM: BaseChatModel;
  /** Context object stored on the agent. */
  context: AgentContext;
}
/**
 * Optional tuning knobs for an agent. Every field may be omitted.
 */
export interface ExtraAgentOptions {
  /** Extra options spread into each chat-model `invoke` call. */
  callOptions?: CallOptions;
  /**
   * Tool-calling method to use. The value 'auto' asks the agent to choose one
   * from the chat model's class name; any other non-empty value is used as given.
   */
  toolCallingMethod?: string;
  /** Agent identifier; the agent falls back to 'agent' when this is absent or empty. */
  id?: string;
}
/**
 * Base class for all agents.
 *
 * Holds a chat model, a prompt, a context and a Zod schema, and provides
 * `invoke`, which calls the model and returns its output parsed against the
 * schema. Subclasses supply `execute`.
 *
 * @typeParam T - The Zod schema for the model output
 * @typeParam M - The type of the result field of the agent output
 */
export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
  protected id: string;
  protected chatLLM: BaseChatModel;
  protected prompt: BasePrompt;
  protected context: AgentContext;
  protected actions: Record<string, Action> = {};
  protected modelOutputSchema: T;
  protected toolCallingMethod: string | null;
  protected chatModelLibrary: string;
  protected modelName: string;
  protected withStructuredOutput: boolean;
  protected callOptions?: CallOptions;
  protected modelOutputToolName: string;

  /** Type-only handle so members can refer to the parsed output as `this['ModelOutput']`. */
  declare ModelOutput: z.infer<T>;

  /**
   * @param modelOutputSchema - Zod schema the model output is parsed against
   * @param options - Required collaborators: chat model, prompt and context
   * @param extraOptions - Optional id, tool-calling method and per-call options
   */
  constructor(modelOutputSchema: T, options: BaseAgentOptions, extraOptions?: Partial<ExtraAgentOptions>) {
    // base options
    this.modelOutputSchema = modelOutputSchema;
    this.chatLLM = options.chatLLM;
    this.prompt = options.prompt;
    this.context = options.context;
    // TODO: fix this, the name is not correct in production environment
    // NOTE(review): presumably because the build renames classes — confirm.
    this.chatModelLibrary = this.chatLLM.constructor.name;
    // Order matters: setWithStructuredOutput() reads this.modelName.
    this.modelName = this.getModelName();
    this.withStructuredOutput = this.setWithStructuredOutput();

    // extra options
    this.id = extraOptions?.id || 'agent';
    // Order matters: setToolCallingMethod() reads this.chatModelLibrary.
    this.toolCallingMethod = this.setToolCallingMethod(extraOptions?.toolCallingMethod);
    this.callOptions = extraOptions?.callOptions;
    this.modelOutputToolName = `${this.id}_output`;
  }

  /**
   * Resolve the model name from the chat model instance.
   *
   * Checks `modelName`, `model_name` and `model` in that order and returns the
   * first one holding a non-empty string. A property that exists but is not a
   * usable string is skipped rather than returned.
   *
   * @returns The model name, or 'Unknown' when none of the properties is usable
   */
  private getModelName(): string {
    for (const key of ['modelName', 'model_name', 'model']) {
      const candidate: unknown = Reflect.get(this.chatLLM, key);
      if (typeof candidate === 'string' && candidate.length > 0) {
        return candidate;
      }
    }
    return 'Unknown';
  }

  /**
   * Resolve the tool-calling method.
   *
   * @param toolCallingMethod - Requested method; 'auto' selects one from the
   *   chat model's class name
   * @returns The method to use, or null when none applies
   */
  private setToolCallingMethod(toolCallingMethod?: string): string | null {
    if (toolCallingMethod === 'auto') {
      switch (this.chatModelLibrary) {
        case 'ChatGoogleGenerativeAI':
          return null;
        case 'ChatOpenAI':
        case 'AzureChatOpenAI':
        case 'ChatGroq':
        case 'ChatXAI':
          return 'function_calling';
        default:
          return null;
      }
    }
    return toolCallingMethod || null;
  }

  /**
   * Decide whether to use structured output based on the model name.
   *
   * @returns false for 'deepseek-reasoner' and 'deepseek-r1', true otherwise
   */
  private setWithStructuredOutput(): boolean {
    return this.modelName !== 'deepseek-reasoner' && this.modelName !== 'deepseek-r1';
  }

  /**
   * Call the chat model and return its output parsed against the schema.
   *
   * With structured output enabled, the schema is passed to the model through
   * `withStructuredOutput`. Otherwise the raw text reply is cleaned of think
   * tags, JSON is extracted from it and validated against the schema.
   *
   * @param inputMessages - Messages to send to the model
   * @returns The parsed model output
   * @throws Error when the call fails or the reply cannot be parsed or validated
   */
  async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
    // Use structured output
    if (this.withStructuredOutput) {
      const structuredLlm = this.chatLLM.withStructuredOutput(this.modelOutputSchema, {
        includeRaw: true,
        name: this.modelOutputToolName,
      });

      try {
        const response = await structuredLlm.invoke(inputMessages, {
          ...this.callOptions,
        });
        if (response.parsed) {
          return response.parsed;
        }
        logger.error('Failed to parse response', response);
        // Thrown inside the try on purpose: the catch below wraps it with the model name.
        throw new Error('Could not parse response with structured output');
      } catch (error) {
        const errorMessage = `Failed to invoke ${this.modelName} with structured output: ${error}`;
        throw new Error(errorMessage);
      }
    }

    // Without structured output support, need to extract JSON from model output manually
    const convertedInputMessages = convertInputMessages(inputMessages, this.modelName);
    const response = await this.chatLLM.invoke(convertedInputMessages, {
      ...this.callOptions,
    });

    // Work on a local copy of the content instead of mutating the response message.
    let content: unknown = response.content;
    if (typeof content === 'string') {
      const cleaned = removeThinkTags(content);
      content = cleaned;
      try {
        const extractedJson = extractJsonFromModelOutput(cleaned);
        const parsed = this.validateModelOutput(extractedJson);
        if (parsed) {
          return parsed;
        }
      } catch (error) {
        const errorMessage = `Failed to extract JSON from response: ${error}`;
        throw new Error(errorMessage);
      }
    }

    // Log the content itself; interpolating the message object would only print "[object Object]".
    const rendered = typeof content === 'string' ? content : JSON.stringify(content);
    logger.error(`Failed to parse response: ${rendered}`);
    throw new Error('Could not parse response');
  }

  /** Execute the agent and return the result. */
  abstract execute(): Promise<AgentOutput<M>>;

  /**
   * Validate data against the model output schema.
   *
   * @param data - Candidate value, typically JSON extracted from a model reply
   * @returns The parsed output, or undefined when the schema or data is falsy
   * @throws Error when the data does not satisfy the schema
   */
  protected validateModelOutput(data: unknown): this['ModelOutput'] | undefined {
    if (!this.modelOutputSchema || !data) return undefined;
    try {
      return this.modelOutputSchema.parse(data);
    } catch (error) {
      logger.error('validateModelOutput', error);
      throw new Error('Could not validate model output');
    }
  }
}
|