|
@@ -2,9 +2,10 @@ import type { z } from 'zod';
|
|
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
|
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
|
import type { AgentContext, AgentOutput } from '../types';
|
|
import type { AgentContext, AgentOutput } from '../types';
|
|
import type { BasePrompt } from '../prompts/base';
|
|
import type { BasePrompt } from '../prompts/base';
|
|
-import { type BaseMessage, AIMessage, ToolMessage } from '@langchain/core/messages';
|
|
|
|
|
|
+import { type BaseMessage, AIMessage, ToolMessage, HumanMessage } from '@langchain/core/messages';
|
|
import { createLogger } from '@src/background/log';
|
|
import { createLogger } from '@src/background/log';
|
|
import type { Action } from '../actions/builder';
|
|
import type { Action } from '../actions/builder';
|
|
|
|
+import { convertMessagesForNonFunctionCallingModels, mergeSuccessiveMessages } from '../messages/service';
|
|
|
|
|
|
const logger = createLogger('agent');
|
|
const logger = createLogger('agent');
|
|
|
|
|
|
@@ -107,6 +108,27 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
|
|
return text.replace(THINK_TAGS, '');
|
|
return text.replace(THINK_TAGS, '');
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
+ * Convert input messages to a format that is compatible with the model
|
|
|
|
+ * @param inputMessages - The input messages to convert
|
|
|
|
+ * @param modelName - The optional model name to determine conversion strategy
|
|
|
|
+ * @returns The converted input messages
|
|
|
|
+ */
|
|
|
|
+ protected convertInputMessages(inputMessages: BaseMessage[], modelName?: string): BaseMessage[] {
|
|
|
|
+ if (!modelName) {
|
|
|
|
+ return inputMessages;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (modelName === 'deepseek-reasoner' || modelName.startsWith('deepseek-r1')) {
|
|
|
|
+ const convertedInputMessages = convertMessagesForNonFunctionCallingModels(inputMessages);
|
|
|
|
+ let mergedInputMessages = mergeSuccessiveMessages(convertedInputMessages, HumanMessage);
|
|
|
|
+ mergedInputMessages = mergeSuccessiveMessages(mergedInputMessages, AIMessage);
|
|
|
|
+ return mergedInputMessages;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return inputMessages;
|
|
|
|
+ }
|
|
|
|
+
|
|
async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
|
|
async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
|
|
// Use structured output
|
|
// Use structured output
|
|
if (this.withStructuredOutput) {
|
|
if (this.withStructuredOutput) {
|
|
@@ -115,17 +137,25 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
|
|
name: this.modelOutputToolName,
|
|
name: this.modelOutputToolName,
|
|
});
|
|
});
|
|
|
|
|
|
- const response = await structuredLlm.invoke(inputMessages, {
|
|
|
|
- ...this.callOptions,
|
|
|
|
- });
|
|
|
|
- if (response.parsed) {
|
|
|
|
- return response.parsed;
|
|
|
|
|
|
+ try {
|
|
|
|
+ const response = await structuredLlm.invoke(inputMessages, {
|
|
|
|
+ ...this.callOptions,
|
|
|
|
+ });
|
|
|
|
+
|
|
|
|
+ if (response.parsed) {
|
|
|
|
+ return response.parsed;
|
|
|
|
+ }
|
|
|
|
+ logger.error('Failed to parse response', response);
|
|
|
|
+ throw new Error('Could not parse response with structured output');
|
|
|
|
+ } catch (error) {
|
|
|
|
+ const errorMessage = `Failed to invoke ${this.modelName} with structured output: ${error}`;
|
|
|
|
+ throw new Error(errorMessage);
|
|
}
|
|
}
|
|
- throw new Error('Could not parse response');
|
|
|
|
}
|
|
}
|
|
|
|
|
|
// Without structured output support, need to extract JSON from model output manually
|
|
// Without structured output support, need to extract JSON from model output manually
|
|
- const response = await this.chatLLM.invoke(inputMessages, {
|
|
|
|
|
|
+ const convertedInputMessages = this.convertInputMessages(inputMessages, this.modelName);
|
|
|
|
+ const response = await this.chatLLM.invoke(convertedInputMessages, {
|
|
...this.callOptions,
|
|
...this.callOptions,
|
|
});
|
|
});
|
|
if (typeof response.content === 'string') {
|
|
if (typeof response.content === 'string') {
|
|
@@ -137,10 +167,12 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
|
|
return parsed;
|
|
return parsed;
|
|
}
|
|
}
|
|
} catch (error) {
|
|
} catch (error) {
|
|
- logger.error('Could not parse response', response);
|
|
|
|
- throw new Error('Could not parse response');
|
|
|
|
|
|
+ const errorMessage = `Failed to extract JSON from response: ${error}`;
|
|
|
|
+ throw new Error(errorMessage);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ const errorMessage = `Failed to parse response: ${response}`;
|
|
|
|
+ logger.error(errorMessage);
|
|
throw new Error('Could not parse response');
|
|
throw new Error('Could not parse response');
|
|
}
|
|
}
|
|
|
|
|
|
@@ -150,7 +182,12 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
|
|
// Helper method to validate metadata
|
|
// Helper method to validate metadata
|
|
protected validateModelOutput(data: unknown): this['ModelOutput'] | undefined {
|
|
protected validateModelOutput(data: unknown): this['ModelOutput'] | undefined {
|
|
if (!this.modelOutputSchema || !data) return undefined;
|
|
if (!this.modelOutputSchema || !data) return undefined;
|
|
- return this.modelOutputSchema.parse(data);
|
|
|
|
|
|
+ try {
|
|
|
|
+ return this.modelOutputSchema.parse(data);
|
|
|
|
+ } catch (error) {
|
|
|
|
+ logger.error('validateModelOutput', error);
|
|
|
|
+ throw new Error('Could not validate model output');
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
// Add the model output to the memory
|
|
// Add the model output to the memory
|
|
@@ -202,7 +239,7 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
|
|
return JSON.parse(cleanedContent);
|
|
return JSON.parse(cleanedContent);
|
|
} catch (e) {
|
|
} catch (e) {
|
|
logger.warning(`Failed to parse model output: ${content} ${e}`);
|
|
logger.warning(`Failed to parse model output: ${content} ${e}`);
|
|
- throw new Error('Could not parse response.');
|
|
|
|
|
|
+ throw new Error('Failed to extract JSON from model output.');
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|