فهرست منبع

upgrade messages

alexchenzl 3 ماه پیش
والد
کامیت
0d2f9a11af

+ 5 - 86
chrome-extension/src/background/agent/agents/base.ts

@@ -2,10 +2,10 @@ import type { z } from 'zod';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import type { AgentContext, AgentOutput } from '../types';
 import type { BasePrompt } from '../prompts/base';
-import { type BaseMessage, AIMessage, ToolMessage, HumanMessage } from '@langchain/core/messages';
+import type { BaseMessage } from '@langchain/core/messages';
 import { createLogger } from '@src/background/log';
 import type { Action } from '../actions/builder';
-import { convertMessagesForNonFunctionCallingModels, mergeSuccessiveMessages } from '../messages/service';
+import { convertInputMessages, extractJsonFromModelOutput, removeThinkTags } from '../messages/utils';
 
 const logger = createLogger('agent');
 
@@ -24,8 +24,6 @@ export interface ExtraAgentOptions {
   callOptions?: CallOptions;
 }
 
-const THINK_TAGS = /<think>[\s\S]*?<\/think>/;
-
 /**
  * Base class for all agents
  * @param T - The Zod schema for the model output
@@ -103,32 +101,6 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
     return true;
   }
 
-  // Remove think tags from the model output
-  protected removeThinkTags(text: string): string {
-    return text.replace(THINK_TAGS, '');
-  }
-
-  /**
-   * Convert input messages to a format that is compatible with the model
-   * @param inputMessages - The input messages to convert
-   * @param modelName - The optional model name to determine conversion strategy
-   * @returns The converted input messages
-   */
-  protected convertInputMessages(inputMessages: BaseMessage[], modelName?: string): BaseMessage[] {
-    if (!modelName) {
-      return inputMessages;
-    }
-
-    if (modelName === 'deepseek-reasoner' || modelName.startsWith('deepseek-r1')) {
-      const convertedInputMessages = convertMessagesForNonFunctionCallingModels(inputMessages);
-      let mergedInputMessages = mergeSuccessiveMessages(convertedInputMessages, HumanMessage);
-      mergedInputMessages = mergeSuccessiveMessages(mergedInputMessages, AIMessage);
-      return mergedInputMessages;
-    }
-
-    return inputMessages;
-  }
-
   async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
     // Use structured output
     if (this.withStructuredOutput) {
@@ -154,14 +126,14 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
     }
 
     // Without structured output support, need to extract JSON from model output manually
-    const convertedInputMessages = this.convertInputMessages(inputMessages, this.modelName);
+    const convertedInputMessages = convertInputMessages(inputMessages, this.modelName);
     const response = await this.chatLLM.invoke(convertedInputMessages, {
       ...this.callOptions,
     });
     if (typeof response.content === 'string') {
-      response.content = this.removeThinkTags(response.content);
+      response.content = removeThinkTags(response.content);
       try {
-        const extractedJson = this.extractJsonFromModelOutput(response.content);
+        const extractedJson = extractJsonFromModelOutput(response.content);
         const parsed = this.validateModelOutput(extractedJson);
         if (parsed) {
           return parsed;
@@ -189,57 +161,4 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
       throw new Error('Could not validate model output');
     }
   }
-
-  // Add the model output to the memory
-  protected addModelOutputToMemory(modelOutput: this['ModelOutput']): void {
-    const messageManager = this.context.messageManager;
-    const toolCallId = String(messageManager.nextToolId());
-    const toolCalls = [
-      {
-        name: this.modelOutputToolName,
-        args: modelOutput,
-        id: toolCallId,
-        type: 'tool_call' as const,
-      },
-    ];
-
-    const toolCallMessage = new AIMessage({
-      content: 'tool call',
-      tool_calls: toolCalls,
-    });
-    messageManager.addMessageWithTokens(toolCallMessage);
-
-    const toolMessage = new ToolMessage({
-      content: 'tool call response placeholder',
-      tool_call_id: toolCallId,
-    });
-    messageManager.addMessageWithTokens(toolMessage);
-  }
-
-  /**
-   * Extract JSON from raw string model output, handling both plain JSON and code-block-wrapped JSON.
-   *
-   * some models not supporting tool calls well like deepseek-reasoner, so we need to extract the JSON from the output
-   * @param content - The content of the model output
-   * @returns The JSON object
-   */
-  protected extractJsonFromModelOutput(content: string): unknown {
-    try {
-      let cleanedContent = content;
-      // If content is wrapped in code blocks, extract just the JSON part
-      if (content.includes('```')) {
-        // Find the JSON content between code blocks
-        cleanedContent = cleanedContent.split('```')[1];
-        // Remove language identifier if present (e.g., 'json\n')
-        if (cleanedContent.includes('json\n')) {
-          cleanedContent = cleanedContent.replace(/^json\s*/, '');
-        }
-      }
-      // Parse the cleaned content
-      return JSON.parse(cleanedContent);
-    } catch (e) {
-      logger.warning(`Failed to parse model output: ${content} ${e}`);
-      throw new Error('Failed to extract JSON from model output.');
-    }
-  }
 }

+ 8 - 2
chrome-extension/src/background/agent/agents/navigator.ts

@@ -201,7 +201,6 @@ export class NavigatorAgent extends BaseAgent<z.ZodType, NavigatorResult> {
     }
 
     const messageManager = this.context.messageManager;
-    const options = this.context.options;
     // Handle results that should be included in memory
     if (this.context.actionResults.length > 0) {
       let index = 0;
@@ -213,11 +212,18 @@ export class NavigatorAgent extends BaseAgent<z.ZodType, NavigatorResult> {
             messageManager.addMessageWithTokens(msg);
           }
           if (r.error) {
-            const msg = new HumanMessage(`Action error: ${r.error.toString().slice(-options.maxErrorLength)}`);
+            // Get error text and convert to string
+            const errorText = r.error.toString().trim();
+
+            // Get only the last line of the error
+            const lastLine = errorText.split('\n').pop() || '';
+
+            const msg = new HumanMessage(`Action error: ${lastLine}`);
             logger.info('Adding action error to memory', msg.content);
             messageManager.addMessageWithTokens(msg);
           }
           // reset this action result to empty, we dont want to add it again in the state message
+          // NOTE: in python version, all action results are reset to empty, but in ts version, only those included in memory are reset to empty
           this.context.actionResults[index] = new ActionResult();
         }
         index++;

+ 132 - 141
chrome-extension/src/background/agent/messages/service.ts

@@ -1,79 +1,103 @@
-import { type BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage } from '@langchain/core/messages';
-import { MessageHistory, type MessageMetadata, type ManagedMessage } from '@src/background/agent/messages/views';
+import { type BaseMessage, AIMessage, HumanMessage, type SystemMessage, ToolMessage } from '@langchain/core/messages';
+import { MessageHistory, MessageMetadata, ManagedMessage } from '@src/background/agent/messages/views';
 import { createLogger } from '@src/background/log';
 
 const logger = createLogger('MessageManager');
 
+export class MessageManagerSettings {
+  maxInputTokens = 128000;
+  estimatedCharactersPerToken = 3;
+  imageTokens = 800;
+  includeAttributes: string[] = [];
+  messageContext?: string;
+  sensitiveData?: Record<string, string>;
+  availableFilePaths?: string[];
+
+  constructor(
+    options: {
+      maxInputTokens?: number;
+      estimatedCharactersPerToken?: number;
+      imageTokens?: number;
+      includeAttributes?: string[];
+      messageContext?: string;
+      sensitiveData?: Record<string, string>;
+      availableFilePaths?: string[];
+    } = {},
+  ) {
+    if (options.maxInputTokens !== undefined) this.maxInputTokens = options.maxInputTokens;
+    if (options.estimatedCharactersPerToken !== undefined)
+      this.estimatedCharactersPerToken = options.estimatedCharactersPerToken;
+    if (options.imageTokens !== undefined) this.imageTokens = options.imageTokens;
+    if (options.includeAttributes !== undefined) this.includeAttributes = options.includeAttributes;
+    if (options.messageContext !== undefined) this.messageContext = options.messageContext;
+    if (options.sensitiveData !== undefined) this.sensitiveData = options.sensitiveData;
+    if (options.availableFilePaths !== undefined) this.availableFilePaths = options.availableFilePaths;
+  }
+}
+
 export default class MessageManager {
-  private maxInputTokens: number;
   private history: MessageHistory;
-  private estimatedCharactersPerToken: number;
-  private readonly IMG_TOKENS: number;
-  private sensitiveData?: Record<string, string>;
   private toolId: number;
+  private settings: MessageManagerSettings;
 
-  constructor({
-    maxInputTokens = 128000,
-    estimatedCharactersPerToken = 3,
-    imageTokens = 800,
-    sensitiveData,
-  }: {
-    maxInputTokens?: number;
-    estimatedCharactersPerToken?: number;
-    imageTokens?: number;
-    sensitiveData?: Record<string, string>;
-  } = {}) {
-    this.maxInputTokens = maxInputTokens;
+  constructor(settings: MessageManagerSettings = new MessageManagerSettings()) {
+    this.settings = settings;
     this.history = new MessageHistory();
-    this.estimatedCharactersPerToken = estimatedCharactersPerToken;
-    this.IMG_TOKENS = imageTokens;
-    this.sensitiveData = sensitiveData;
     this.toolId = 1;
   }
 
   public initTaskMessages(systemMessage: SystemMessage, task: string, messageContext?: string): void {
     // Add system message
-    this.addMessageWithTokens(systemMessage);
+    this.addMessageWithTokens(systemMessage, 'init');
 
     // Add context message if provided
     if (messageContext && messageContext.length > 0) {
       const contextMessage = new HumanMessage({
         content: `Context for the task: ${messageContext}`,
       });
-      this.addMessageWithTokens(contextMessage);
+      this.addMessageWithTokens(contextMessage, 'init');
     }
 
     // Add task instructions
     const taskMessage = MessageManager.taskInstructions(task);
-    this.addMessageWithTokens(taskMessage);
+    this.addMessageWithTokens(taskMessage, 'init');
 
     // Add sensitive data info if sensitive data is provided
-    if (this.sensitiveData) {
-      const info = `Here are placeholders for sensitive data: ${Object.keys(this.sensitiveData)}`;
+    if (this.settings.sensitiveData) {
+      const info = `Here are placeholders for sensitive data: ${Object.keys(this.settings.sensitiveData)}`;
       const infoMessage = new HumanMessage({
         content: `${info}\nTo use them, write <secret>the placeholder name</secret>`,
       });
-      this.addMessageWithTokens(infoMessage);
+      this.addMessageWithTokens(infoMessage, 'init');
     }
 
     // Add example output
     const placeholderMessage = new HumanMessage({
       content: 'Example output:',
     });
-    this.addMessageWithTokens(placeholderMessage);
+    this.addMessageWithTokens(placeholderMessage, 'init');
 
     const toolCallId = this.nextToolId();
     const toolCalls = [
       {
-        name: 'navigator_output',
+        name: 'AgentOutput',
         args: {
           current_state: {
-            page_summary: 'On the page are company a,b,c wtih their revenue 1,2,3.',
-            evaluation_previous_goal: 'Success - I opend the first page',
-            memory: 'Starting with the new task. I have completed 1/10 steps',
-            next_goal: 'Click on company a',
+            evaluation_previous_goal:
+              `Success - I successfully clicked on the 'Apple' link from the Google Search results page, 
+              which directed me to the 'Apple' company homepage. This is a good start toward finding 
+              the best place to buy a new iPhone as the Apple website often list iPhones for sale.`.trim(),
+            memory: `I searched for 'iPhone retailers' on Google. From the Google Search results page, 
+              I used the 'click_element' tool to click on a element labelled 'Best Buy' but calling 
+              the tool did not direct me to a new page. I then used the 'click_element' tool to click 
+              on a element labelled 'Apple' which redirected me to the 'Apple' company homepage. 
+              Currently at step 3/15.`.trim(),
+            next_goal: `Looking at reported structure of the current page, I can see the item '[127]<h3 iPhone/>' 
+              in the content. I think this button will lead to more information and potentially prices 
+              for iPhones. I'll click on the link to 'iPhone' at index [127] using the 'click_element' 
+              tool and hope to see prices on the next page.`.trim(),
           },
-          action: [{ click_element: { index: 0 } }],
+          action: [{ click_element: { index: 127 } }],
         },
         id: String(toolCallId),
         type: 'tool_call' as const,
@@ -81,22 +105,25 @@ export default class MessageManager {
     ];
 
     const exampleToolCall = new AIMessage({
-      content: 'example tool call',
+      content: '',
       tool_calls: toolCalls,
     });
-    this.addMessageWithTokens(exampleToolCall);
-
-    const toolMessage = new ToolMessage({
-      content: 'Browser started',
-      tool_call_id: String(toolCallId),
-    });
-    this.addMessageWithTokens(toolMessage);
+    this.addMessageWithTokens(exampleToolCall, 'init');
+    this.addToolMessage('Browser started', toolCallId, 'init');
 
     // Add history start marker
     const historyStartMessage = new HumanMessage({
       content: '[Your task history memory starts here]',
     });
     this.addMessageWithTokens(historyStartMessage);
+
+    // Add available file paths if provided
+    if (this.settings.availableFilePaths && this.settings.availableFilePaths.length > 0) {
+      const filepathsMsg = new HumanMessage({
+        content: `Here are file paths you can use: ${this.settings.availableFilePaths}`,
+      });
+      this.addMessageWithTokens(filepathsMsg, 'init');
+    }
   }
 
   public nextToolId(): number {
@@ -141,7 +168,7 @@ export default class MessageManager {
   public addPlan(plan?: string, position?: number): void {
     if (plan) {
       const msg = new AIMessage({ content: plan });
-      this.addMessageWithTokens(msg, position);
+      this.addMessageWithTokens(msg, null, position);
     }
   }
 
@@ -153,11 +180,37 @@ export default class MessageManager {
     this.addMessageWithTokens(stateMessage);
   }
 
+  /**
+   * Adds a model output message to the history
+   * @param modelOutput - The model output
+   */
+  public addModelOutput(modelOutput: Record<string, any>): void {
+    const toolCallId = this.nextToolId();
+    const toolCalls = [
+      {
+        name: 'AgentOutput',
+        args: modelOutput,
+        id: String(toolCallId),
+        type: 'tool_call' as const,
+      },
+    ];
+
+    const msg = new AIMessage({
+      content: 'tool call',
+      tool_calls: toolCalls,
+    });
+    this.addMessageWithTokens(msg);
+
+    // Need a placeholder for the tool response here to avoid errors sometimes
+    // NOTE: in browser-use, it uses an empty string
+    this.addToolMessage('tool call response placeholder', toolCallId);
+  }
+
   /**
    * Removes the last state message from the history
    */
   public removeLastStateMessage(): void {
-    this.history.removeLastHumanMessage();
+    this.history.removeLastStateMessage();
   }
 
   public getMessages(): BaseMessage[] {
@@ -167,32 +220,29 @@ export default class MessageManager {
     logger.debug(`Messages in history: ${this.history.messages.length}:`);
 
     for (const m of this.history.messages) {
-      totalInputTokens += m.metadata.inputTokens;
-      logger.debug(`${m.message.constructor.name} - Token count: ${m.metadata.inputTokens}`);
+      totalInputTokens += m.metadata.tokens;
+      logger.debug(`${m.message.constructor.name} - Token count: ${m.metadata.tokens}`);
     }
 
     logger.debug(`Total input tokens: ${totalInputTokens}`);
     return messages;
   }
 
-  public getMessagesWithTokens(): ManagedMessage[] {
-    return this.history.messages;
-  }
-
   /**
    * Adds a message to the history with the token count metadata
    * @param message - The BaseMessage object to add
+   * @param messageType - The type of the message (optional)
    * @param position - The optional position to add the message, if not provided, the message will be added to the end of the history
    */
-  public addMessageWithTokens(message: BaseMessage, position?: number): void {
+  public addMessageWithTokens(message: BaseMessage, messageType?: string | null, position?: number): void {
     let filteredMessage = message;
     // filter out sensitive data if provided
-    if (this.sensitiveData) {
+    if (this.settings.sensitiveData) {
       filteredMessage = this._filterSensitiveData(message);
     }
 
     const tokenCount = this._countTokens(filteredMessage);
-    const metadata: MessageMetadata = { inputTokens: tokenCount };
+    const metadata: MessageMetadata = new MessageMetadata(tokenCount, messageType);
     this.history.addMessage(filteredMessage, metadata, position);
   }
 
@@ -204,9 +254,11 @@ export default class MessageManager {
   private _filterSensitiveData(message: BaseMessage): BaseMessage {
     const replaceSensitive = (value: string): string => {
       let filteredValue = value;
-      if (!this.sensitiveData) return filteredValue;
+      if (!this.settings.sensitiveData) return filteredValue;
 
-      for (const [key, val] of Object.entries(this.sensitiveData)) {
+      for (const [key, val] of Object.entries(this.settings.sensitiveData)) {
+        // Skip empty values to match Python behavior
+        if (!val) continue;
         filteredValue = filteredValue.replace(val, `<secret>${key}</secret>`);
       }
       return filteredValue;
@@ -216,7 +268,8 @@ export default class MessageManager {
       message.content = replaceSensitive(message.content);
     } else if (Array.isArray(message.content)) {
       message.content = message.content.map(item => {
-        if (typeof item === 'object' && 'text' in item) {
+        // Add null check to match Python's isinstance() behavior
+        if (typeof item === 'object' && item !== null && 'text' in item) {
           return { ...item, text: replaceSensitive(item.text) };
         }
         return item;
@@ -237,7 +290,7 @@ export default class MessageManager {
     if (Array.isArray(message.content)) {
       for (const item of message.content) {
         if ('image_url' in item) {
-          tokens += this.IMG_TOKENS;
+          tokens += this.settings.imageTokens;
         } else if (typeof item === 'object' && 'text' in item) {
           tokens += this._countTextTokens(item.text);
         }
@@ -261,7 +314,7 @@ export default class MessageManager {
    * @returns The number of tokens in the text
    */
   private _countTextTokens(text: string): number {
-    return Math.floor(text.length / this.estimatedCharactersPerToken);
+    return Math.floor(text.length / this.settings.estimatedCharactersPerToken);
   }
 
   /**
@@ -270,7 +323,7 @@ export default class MessageManager {
    * Get current message list, potentially trimmed to max tokens
    */
   public cutMessages(): void {
-    let diff = this.history.totalTokens - this.maxInputTokens;
+    let diff = this.history.totalTokens - this.settings.maxInputTokens;
     if (diff <= 0) return;
 
     const lastMsg = this.history.messages[this.history.messages.length - 1];
@@ -280,11 +333,11 @@ export default class MessageManager {
       let text = '';
       lastMsg.message.content = lastMsg.message.content.filter(item => {
         if ('image_url' in item) {
-          diff -= this.IMG_TOKENS;
-          lastMsg.metadata.inputTokens -= this.IMG_TOKENS;
-          this.history.totalTokens -= this.IMG_TOKENS;
+          diff -= this.settings.imageTokens;
+          lastMsg.metadata.tokens -= this.settings.imageTokens;
+          this.history.totalTokens -= this.settings.imageTokens;
           logger.debug(
-            `Removed image with ${this.IMG_TOKENS} tokens - total tokens now: ${this.history.totalTokens}/${this.maxInputTokens}`,
+            `Removed image with ${this.settings.imageTokens} tokens - total tokens now: ${this.history.totalTokens}/${this.settings.maxInputTokens}`,
           );
           return false;
         }
@@ -301,104 +354,42 @@ export default class MessageManager {
 
     // if still over, remove text from state message proportionally to the number of tokens needed with buffer
     // Calculate the proportion of content to remove
-    const proportionToRemove = diff / lastMsg.metadata.inputTokens;
+    const proportionToRemove = diff / lastMsg.metadata.tokens;
     if (proportionToRemove > 0.99) {
       throw new Error(
         `Max token limit reached - history is too long - reduce the system prompt or task. proportion_to_remove: ${proportionToRemove}`,
       );
     }
     logger.debug(
-      `Removing ${(proportionToRemove * 100).toFixed(2)}% of the last message (${(proportionToRemove * lastMsg.metadata.inputTokens).toFixed(2)} / ${lastMsg.metadata.inputTokens.toFixed(2)} tokens)`,
+      `Removing ${(proportionToRemove * 100).toFixed(2)}% of the last message (${(proportionToRemove * lastMsg.metadata.tokens).toFixed(2)} / ${lastMsg.metadata.tokens.toFixed(2)} tokens)`,
     );
 
     const content = lastMsg.message.content as string;
     const charactersToRemove = Math.floor(content.length * proportionToRemove);
     const newContent = content.slice(0, -charactersToRemove);
 
-    this.history.removeMessage(-1);
+    // remove tokens and old long message
+    this.history.removeLastStateMessage();
 
+    // new message with updated content
     const msg = new HumanMessage({ content: newContent });
     this.addMessageWithTokens(msg);
 
     const finalMsg = this.history.messages[this.history.messages.length - 1];
     logger.debug(
-      `Added message with ${finalMsg.metadata.inputTokens} tokens - total tokens now: ${this.history.totalTokens}/${this.maxInputTokens} - total messages: ${this.history.messages.length}`,
+      `Added message with ${finalMsg.metadata.tokens} tokens - total tokens now: ${this.history.totalTokens}/${this.settings.maxInputTokens} - total messages: ${this.history.messages.length}`,
     );
   }
-}
-
-/**
- * Converts messages for non-function-calling models
- * @param inputMessages - The BaseMessage objects to convert
- * @returns The converted BaseMessage objects
- */
-export function convertMessagesForNonFunctionCallingModels(inputMessages: BaseMessage[]): BaseMessage[] {
-  return inputMessages.map(message => {
-    if (message instanceof HumanMessage || message instanceof SystemMessage) {
-      return message;
-    }
-    if (message instanceof ToolMessage) {
-      return new HumanMessage({
-        content: `Tool Response: ${message.content}`,
-      });
-    }
-    if (message instanceof AIMessage) {
-      // if it's an AIMessage with tool_calls, convert it to a normal AIMessage
-      if ('tool_calls' in message && message.tool_calls) {
-        const toolCallsStr = message.tool_calls
-          .map(tc => {
-            if (
-              'function' in tc &&
-              typeof tc.function === 'object' &&
-              tc.function &&
-              'name' in tc.function &&
-              'arguments' in tc.function
-            ) {
-              // For Groq, we need to format function calls differently
-              return `Function: ${tc.function.name}\nArguments: ${JSON.stringify(tc.function.arguments)}`;
-            }
-            return `Tool Call: ${JSON.stringify(tc)}`;
-          })
-          .join('\n');
-        return new AIMessage({ content: toolCallsStr });
-      }
-      return message;
-    }
-    throw new Error(`Unknown message type: ${message.constructor.name}`);
-  });
-}
 
-/**
- * Some models like deepseek-reasoner dont allow multiple human messages in a row. This function merges them into one."
- * @param messages - The BaseMessage objects to merge
- * @param classToMerge - The class of the messages to merge
- * @returns The merged BaseMessage objects
- */
-export function mergeSuccessiveMessages(messages: BaseMessage[], classToMerge: typeof BaseMessage): BaseMessage[] {
-  const mergedMessages: BaseMessage[] = [];
-  let streak = 0;
-
-  for (const message of messages) {
-    if (message instanceof classToMerge) {
-      streak += 1;
-      if (streak > 1) {
-        const lastMessage = mergedMessages[mergedMessages.length - 1];
-        if (Array.isArray(message.content)) {
-          const firstContent = message.content[0];
-          if ('text' in firstContent) {
-            lastMessage.content += firstContent.text;
-          }
-        } else {
-          lastMessage.content += message.content;
-        }
-      } else {
-        mergedMessages.push(message);
-      }
-    } else {
-      mergedMessages.push(message);
-      streak = 0;
-    }
+  /**
+   * Adds a tool message to the history
+   * @param content - The content of the tool message
+   * @param toolCallId - The tool call id of the tool message, if not provided, a new tool call id will be generated
+   * @param messageType - The type of the tool message
+   */
+  public addToolMessage(content: string, toolCallId?: number, messageType?: string | null): void {
+    const id = toolCallId ?? this.nextToolId();
+    const msg = new ToolMessage({ content, tool_call_id: String(id) });
+    this.addMessageWithTokens(msg, messageType);
   }
-
-  return mergedMessages;
 }

+ 139 - 0
chrome-extension/src/background/agent/messages/utils.ts

@@ -0,0 +1,139 @@
+import { type BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage } from '@langchain/core/messages';
+
+export function removeThinkTags(text: string): string {
+  // Step 1: Remove well-formed <think>...</think>
+  const thinkTagsRegex = /<think>[\s\S]*?<\/think>/g;
+  let result = text.replace(thinkTagsRegex, '');
+
+  // Step 2: If there's an unmatched closing tag </think>,
+  // remove everything up to and including that.
+  const strayCloseTagRegex = /[\s\S]*?<\/think>/g;
+  result = result.replace(strayCloseTagRegex, '');
+
+  return result.trim();
+}
+
+/**
+ * Extract JSON from model output, handling both plain JSON and code-block-wrapped JSON.
+ * @param content - The string content that potentially contains JSON.
+ * @returns Parsed JSON object
+ * @throws Error if JSON parsing fails
+ */
+export function extractJsonFromModelOutput(content: string): Record<string, unknown> {
+  try {
+    let processedContent = content;
+
+    // If content is wrapped in code blocks, extract just the JSON part
+    if (processedContent.includes('```')) {
+      // Find the JSON content between code blocks
+      const parts = processedContent.split('```');
+      processedContent = parts[1];
+
+      // Remove language identifier if present (e.g., 'json\n')
+      if (processedContent.includes('\\n')) {
+        const newlineIndex = processedContent.indexOf('\\n');
+        processedContent = processedContent.substring(newlineIndex + 1);
+      }
+    }
+
+    // Parse the cleaned content
+    return JSON.parse(processedContent);
+  } catch (e) {
+    console.warn(`Failed to parse model output: ${content} ${e instanceof Error ? e.message : String(e)}`);
+    throw new Error('Could not parse response.');
+  }
+}
+
+/**
+ * Convert input messages to a format that is compatible with the planner model
+ * @param inputMessages - List of messages to convert
+ * @param modelName - Name of the model to convert messages for
+ * @returns Converted list of messages
+ */
+export function convertInputMessages(inputMessages: BaseMessage[], modelName: string | null): BaseMessage[] {
+  if (modelName === null) {
+    return inputMessages;
+  }
+  if (modelName === 'deepseek-reasoner' || modelName.includes('deepseek-r1')) {
+    const convertedInputMessages = convertMessagesForNonFunctionCallingModels(inputMessages);
+    let mergedInputMessages = mergeSuccessiveMessages(convertedInputMessages, HumanMessage);
+    mergedInputMessages = mergeSuccessiveMessages(mergedInputMessages, AIMessage);
+    return mergedInputMessages;
+  }
+  return inputMessages;
+}
+
+/**
+ * Convert messages for non-function-calling models
+ * @param inputMessages - List of messages to convert
+ * @returns Converted list of messages
+ */
+function convertMessagesForNonFunctionCallingModels(inputMessages: BaseMessage[]): BaseMessage[] {
+  const outputMessages: BaseMessage[] = [];
+
+  for (const message of inputMessages) {
+    if (message instanceof HumanMessage || message instanceof SystemMessage) {
+      outputMessages.push(message);
+    } else if (message instanceof ToolMessage) {
+      outputMessages.push(new HumanMessage({ content: message.content }));
+    } else if (message instanceof AIMessage) {
+      if (message.tool_calls) {
+        const toolCalls = JSON.stringify(message.tool_calls);
+        outputMessages.push(new AIMessage({ content: toolCalls }));
+      } else {
+        outputMessages.push(message);
+      }
+    } else {
+      throw new Error(`Unknown message type: ${message.constructor.name}`);
+    }
+  }
+
+  return outputMessages;
+}
+
+/**
+ * Merge successive messages of the same type into one message
+ * Some models like deepseek-reasoner don't allow multiple human messages in a row
+ * @param messages - List of messages to merge
+ * @param classToMerge - Message class type to merge
+ * @returns Merged list of messages
+ */
+function mergeSuccessiveMessages(
+  messages: BaseMessage[],
+  classToMerge: typeof HumanMessage | typeof AIMessage,
+): BaseMessage[] {
+  const mergedMessages: BaseMessage[] = [];
+  let streak = 0;
+
+  for (const message of messages) {
+    if (message instanceof classToMerge) {
+      streak += 1;
+      if (streak > 1) {
+        const lastMessage = mergedMessages[mergedMessages.length - 1];
+        if (Array.isArray(message.content)) {
+          // Handle array content case
+          if (typeof lastMessage.content === 'string') {
+            const textContent = message.content.find(
+              item => typeof item === 'object' && 'type' in item && item.type === 'text',
+            );
+            if (textContent && 'text' in textContent) {
+              lastMessage.content += textContent.text;
+            }
+          }
+        } else {
+          // Handle string content case
+          if (typeof lastMessage.content === 'string' && typeof message.content === 'string') {
+            lastMessage.content += message.content;
+          }
+        }
+      } else {
+        mergedMessages.push(message);
+      }
+    } else {
+      mergedMessages.push(message);
+      streak = 0;
+    }
+  }
+
+  return mergedMessages;
+}

+ 47 - 9
chrome-extension/src/background/agent/messages/views.ts

@@ -1,19 +1,30 @@
-import { type BaseMessage, HumanMessage } from '@langchain/core/messages';
+import { type BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
 
-export interface MessageMetadata {
-  inputTokens: number;
+export class MessageMetadata {
+  tokens: number;
+  message_type: string | null = null;
+
+  constructor(tokens: number, message_type?: string | null) {
+    this.tokens = tokens;
+    this.message_type = message_type ?? null;
+  }
 }
 
-export interface ManagedMessage {
+export class ManagedMessage {
   message: BaseMessage;
   metadata: MessageMetadata;
+
+  constructor(message: BaseMessage, metadata: MessageMetadata) {
+    this.message = message;
+    this.metadata = metadata;
+  }
 }
 
 export class MessageHistory {
   messages: ManagedMessage[] = [];
   totalTokens = 0;
 
-  addMessage(message: BaseMessage, metadata: MessageMetadata = { inputTokens: 0 }, position?: number): void {
+  addMessage(message: BaseMessage, metadata: MessageMetadata, position?: number): void {
     const managedMessage: ManagedMessage = {
       message,
       metadata,
@@ -24,13 +35,13 @@ export class MessageHistory {
     } else {
       this.messages.splice(position, 0, managedMessage);
     }
-    this.totalTokens += metadata.inputTokens;
+    this.totalTokens += metadata.tokens;
   }
 
   removeMessage(index = -1): void {
     if (this.messages.length > 0) {
       const msg = this.messages.splice(index, 1)[0];
-      this.totalTokens -= msg.metadata.inputTokens;
+      this.totalTokens -= msg.metadata.tokens;
     }
   }
 
@@ -38,11 +49,38 @@ export class MessageHistory {
    * Removes the last message from the history if it is a human message.
    * This is used to remove the state message from the history.
    */
-  removeLastHumanMessage(): void {
+  removeLastStateMessage(): void {
     if (this.messages.length > 2 && this.messages[this.messages.length - 1].message instanceof HumanMessage) {
       const msg = this.messages.pop();
       if (msg) {
-        this.totalTokens -= msg.metadata.inputTokens;
+        this.totalTokens -= msg.metadata.tokens;
+      }
+    }
+  }
+
+  /**
+   * Get all messages
+   */
+  getMessages(): BaseMessage[] {
+    return this.messages.map(m => m.message);
+  }
+
+  /**
+   * Get total tokens in history
+   */
+  getTotalTokens(): number {
+    return this.totalTokens;
+  }
+
+  /**
+   * Remove oldest non-system message
+   */
+  removeOldestMessage(): void {
+    for (let i = 0; i < this.messages.length; i++) {
+      if (!(this.messages[i].message instanceof SystemMessage)) {
+        const msg = this.messages.splice(i, 1)[0];
+        this.totalTokens -= msg.metadata.tokens;
+        break;
       }
     }
   }