base.ts 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import type { z } from 'zod';
  2. import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
  3. import type { AgentContext, AgentOutput } from '../types';
  4. import type { BasePrompt } from '../prompts/base';
  5. import type { BaseMessage } from '@langchain/core/messages';
  6. import { createLogger } from '@src/background/log';
  7. import type { Action } from '../actions/builder';
  8. import { convertInputMessages, extractJsonFromModelOutput, removeThinkTags } from '../messages/utils';
  // Module-level logger for agent code; 'agent' is the name handed to createLogger
  // (presumably used as the log tag — confirm in '@src/background/log').
  const logger = createLogger('agent');

  /**
   * Free-form, provider-specific options. BaseAgent spreads them into every chat-model
   * `invoke` call it makes.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  export type CallOptions = { [key: string]: any };
  /**
   * Required construction options shared by every agent.
   *
   * The Zod output schema is NOT part of these options; it is passed to the BaseAgent
   * constructor as its own first argument.
   */
  export interface BaseAgentOptions {
    /** Chat model the agent sends its messages to. */
    chatLLM: BaseChatModel;
    /** Agent context; BaseAgent only stores it for use by subclasses. */
    context: AgentContext;
    /** Prompt associated with the agent; BaseAgent only stores it for use by subclasses. */
    prompt: BasePrompt;
  }
  /**
   * Optional agent settings. The BaseAgent constructor takes a `Partial` of this shape.
   */
  export interface ExtraAgentOptions {
    /** Agent identifier. BaseAgent falls back to `'agent'` when it is missing or empty. */
    id?: string;
    /**
     * Tool-calling method. `'auto'` makes BaseAgent choose one from the chat model's
     * class name; any other non-empty value is used verbatim.
     */
    toolCallingMethod?: string;
    /** Options spread into each chat-model `invoke` call that BaseAgent makes. */
    callOptions?: CallOptions;
  }
  /**
   * Base class for all agents.
   *
   * An agent wraps a LangChain chat model and turns a list of input messages into a
   * value validated against `modelOutputSchema` (see `invoke`). Subclasses implement
   * `execute`.
   *
   * @typeParam T - The Zod schema for the model output
   * @typeParam M - The type of the result field of the agent output
   */
  export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
    /** Agent identifier; defaults to `'agent'`. Also used to derive `modelOutputToolName`. */
    protected id: string;
    protected chatLLM: BaseChatModel;
    protected prompt: BasePrompt;
    protected context: AgentContext;
    /** Action registry available to subclasses; empty by default. */
    protected actions: Record<string, Action> = {};
    /** Schema every model response is validated against. */
    protected modelOutputSchema: T;
    /** Resolved tool-calling method, or `null` when none applies. */
    protected toolCallingMethod: string | null;
    /** Runtime class name of `chatLLM` (e.g. `'ChatOpenAI'`). */
    protected chatModelLibrary: string;
    /** Model identifier read from `chatLLM`, or `'Unknown'` if none could be found. */
    protected modelName: string;
    /** Whether `invoke` uses the chat model's native structured-output support. */
    protected withStructuredOutput: boolean;
    /** Extra options spread into every chat-model `invoke` call. */
    protected callOptions?: CallOptions;
    /** Name passed to `withStructuredOutput`; always `${id}_output`. */
    protected modelOutputToolName: string;
    // Type-only declaration (nothing is emitted at runtime): lets signatures refer to
    // the schema's inferred output type as `this['ModelOutput']`.
    declare ModelOutput: z.infer<T>;

    /**
     * @param modelOutputSchema - Zod schema used to validate model output.
     * @param options - Chat model, context and prompt.
     * @param extraOptions - Optional id, tool-calling method and per-call options.
     */
    constructor(modelOutputSchema: T, options: BaseAgentOptions, extraOptions?: Partial<ExtraAgentOptions>) {
      // base options
      this.modelOutputSchema = modelOutputSchema;
      this.chatLLM = options.chatLLM;
      this.prompt = options.prompt;
      this.context = options.context;
      // TODO: fix this, the name is not correct in production environment
      // NOTE(review): presumably because the production build renames classes — confirm.
      this.chatModelLibrary = this.chatLLM.constructor.name;
      // Order matters: setWithStructuredOutput() reads this.modelName.
      this.modelName = this.getModelName();
      this.withStructuredOutput = this.setWithStructuredOutput();
      // extra options
      this.id = extraOptions?.id || 'agent';
      // setToolCallingMethod() reads this.chatModelLibrary, assigned above.
      this.toolCallingMethod = this.setToolCallingMethod(extraOptions?.toolCallingMethod);
      this.callOptions = extraOptions?.callOptions;
      this.modelOutputToolName = `${this.id}_output`;
    }

    /**
     * Resolve the model identifier from the chat model instance.
     *
     * Chat model classes expose it under different property names, so `modelName`,
     * `model_name` and `model` are tried in that order. A property only counts if it
     * holds a non-empty string. A key that exists but is `undefined` (or not a string)
     * no longer leaks out as the "name".
     *
     * @returns The first usable identifier, or `'Unknown'`.
     */
    private getModelName(): string {
      const candidateKeys: readonly string[] = ['modelName', 'model_name', 'model'];
      for (const key of candidateKeys) {
        const value: unknown = Reflect.get(this.chatLLM, key);
        if (typeof value === 'string' && value.length > 0) {
          return value;
        }
      }
      return 'Unknown';
    }

    /**
     * Resolve the tool-calling method.
     *
     * `'auto'` is mapped from the chat model's class name:
     * - `'function_calling'` for ChatOpenAI, AzureChatOpenAI, ChatGroq and ChatXAI.
     * - `null` for ChatGoogleGenerativeAI and every other class.
     *
     * Any other value is returned unchanged; a missing or empty value becomes `null`.
     */
    private setToolCallingMethod(toolCallingMethod?: string): string | null {
      if (toolCallingMethod === 'auto') {
        switch (this.chatModelLibrary) {
          case 'ChatGoogleGenerativeAI':
            return null;
          case 'ChatOpenAI':
          case 'AzureChatOpenAI':
          case 'ChatGroq':
          case 'ChatXAI':
            return 'function_calling';
          default:
            return null;
        }
      }
      return toolCallingMethod || null;
    }

    /**
     * Decide whether `invoke` should use native structured output, based on the model name.
     * `deepseek-reasoner` and `deepseek-r1` are excluded and go through manual JSON
     * extraction instead.
     */
    private setWithStructuredOutput(): boolean {
      if (this.modelName === 'deepseek-reasoner' || this.modelName === 'deepseek-r1') {
        return false;
      }
      return true;
    }

    /**
     * Send `inputMessages` to the chat model and return its output parsed with
     * `modelOutputSchema`.
     *
     * With structured output enabled, the chat model's `withStructuredOutput` wrapper
     * (raw output included, tool name `modelOutputToolName`) does the parsing.
     *
     * Otherwise the manual path runs:
     * 1. The messages are converted with `convertInputMessages`.
     * 2. A string reply is passed through `removeThinkTags`.
     * 3. JSON is extracted with `extractJsonFromModelOutput`.
     * 4. The extracted JSON is validated against the schema.
     *
     * @throws Error when the structured call fails or yields nothing parsed. Also thrown
     *   when JSON extraction or validation fails, or when the reply is not a string or
     *   parses to nothing. Errors from the unstructured model call propagate unchanged.
     */
    async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
      // Use structured output
      if (this.withStructuredOutput) {
        const structuredLlm = this.chatLLM.withStructuredOutput(this.modelOutputSchema, {
          includeRaw: true,
          name: this.modelOutputToolName,
        });
        try {
          const response = await structuredLlm.invoke(inputMessages, {
            ...this.callOptions,
          });
          if (response.parsed) {
            return response.parsed;
          }
          logger.error('Failed to parse response', response);
          throw new Error('Could not parse response with structured output');
        } catch (error) {
          const errorMessage = `Failed to invoke ${this.modelName} with structured output: ${error}`;
          throw new Error(errorMessage);
        }
      }

      // Without structured output support, need to extract JSON from model output manually
      const convertedInputMessages = convertInputMessages(inputMessages, this.modelName);
      const response = await this.chatLLM.invoke(convertedInputMessages, {
        ...this.callOptions,
      });
      if (typeof response.content === 'string') {
        // Clean into a local rather than mutating the message returned by the model.
        const content = removeThinkTags(response.content);
        try {
          const extractedJson = extractJsonFromModelOutput(content);
          const parsed = this.validateModelOutput(extractedJson);
          if (parsed) {
            return parsed;
          }
        } catch (error) {
          const errorMessage = `Failed to extract JSON from response: ${error}`;
          throw new Error(errorMessage);
        }
      }
      // Log the message object itself (same as the structured path) instead of
      // interpolating it into a template string, which does not show its fields.
      logger.error('Failed to parse response', response);
      throw new Error('Could not parse response');
    }

    /** Execute the agent and return the result. */
    abstract execute(): Promise<AgentOutput<M>>;

    /**
     * Validate raw model output against `modelOutputSchema`.
     *
     * @param data - Candidate output, e.g. JSON extracted from the model's text.
     * @returns The parsed value, or `undefined` when there is no schema or `data` is falsy.
     * @throws Error('Could not validate model output') when schema parsing fails; the
     *   underlying error is logged first.
     */
    protected validateModelOutput(data: unknown): this['ModelOutput'] | undefined {
      if (!this.modelOutputSchema || !data) return undefined;
      try {
        return this.modelOutputSchema.parse(data);
      } catch (error) {
        logger.error('validateModelOutput', error);
        throw new Error('Could not validate model output');
      }
    }
  }