service.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. import { type BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage } from '@langchain/core/messages';
  2. import { MessageHistory, type MessageMetadata, type ManagedMessage } from '@src/background/agent/messages/views';
  3. import { createLogger } from '@src/background/log';
  4. const logger = createLogger('MessageManager');
  /**
   * Maintains the ordered message history sent to the model for one agent run.
   *
   * Every message stored through this class is first scrubbed of configured
   * sensitive values and annotated with an estimated token count, so the history
   * can later be trimmed to `maxInputTokens` by {@link MessageManager.cutMessages}.
   *
   * Token counts are rough estimates (characters divided by
   * `estimatedCharactersPerToken`, plus a flat cost per image); no tokenizer is used.
   */
  export default class MessageManager {
    private readonly maxInputTokens: number;
    private readonly history: MessageHistory;
    private readonly estimatedCharactersPerToken: number;
    private readonly IMG_TOKENS: number;
    private readonly sensitiveData?: Record<string, string>;
    private toolId: number;

    /**
     * @param options.maxInputTokens - Token budget enforced by `cutMessages` (default 128000)
     * @param options.estimatedCharactersPerToken - Divisor used for the text token estimate (default 3)
     * @param options.imageTokens - Flat token cost charged for each image part (default 800)
     * @param options.sensitiveData - Map of placeholder name to secret value; secret values are
     *   replaced by `<secret>name</secret>` in every stored message
     */
    constructor({
      maxInputTokens = 128000,
      estimatedCharactersPerToken = 3,
      imageTokens = 800,
      sensitiveData,
    }: {
      maxInputTokens?: number;
      estimatedCharactersPerToken?: number;
      imageTokens?: number;
      sensitiveData?: Record<string, string>;
    } = {}) {
      this.maxInputTokens = maxInputTokens;
      this.history = new MessageHistory();
      this.estimatedCharactersPerToken = estimatedCharactersPerToken;
      this.IMG_TOKENS = imageTokens;
      this.sensitiveData = sensitiveData;
      this.toolId = 1;
    }

    /**
     * Seeds the history for a new task, in this order: system message, optional
     * context, task instructions, optional sensitive-data placeholder hint, an
     * example tool call with its tool response, and the history start marker.
     * @param systemMessage - The system prompt
     * @param task - The raw description of the task
     * @param messageContext - Optional extra context; skipped when empty
     */
    public initTaskMessages(systemMessage: SystemMessage, task: string, messageContext?: string): void {
      this.addMessageWithTokens(systemMessage);

      if (messageContext && messageContext.length > 0) {
        this.addMessageWithTokens(new HumanMessage({ content: `Context for the task: ${messageContext}` }));
      }

      this.addMessageWithTokens(MessageManager.taskInstructions(task));

      // Tell the model which placeholders exist; only the names are listed, never the values.
      if (this.sensitiveData) {
        const info = `Here are placeholders for sensitive data: ${Object.keys(this.sensitiveData).join(',')}`;
        this.addMessageWithTokens(
          new HumanMessage({ content: `${info}\nTo use them, write <secret>the placeholder name</secret>` }),
        );
      }

      // One worked example of the expected output, expressed as a tool call plus its response.
      this.addMessageWithTokens(new HumanMessage({ content: 'Example output:' }));

      const exampleCallId = String(this.nextToolId());
      this.addMessageWithTokens(
        new AIMessage({
          content: 'example tool call',
          tool_calls: [
            {
              name: 'navigator_output',
              args: {
                current_state: {
                  page_summary: 'On the page are company a,b,c wtih their revenue 1,2,3.',
                  evaluation_previous_goal: 'Success - I opend the first page',
                  memory: 'Starting with the new task. I have completed 1/10 steps',
                  next_goal: 'Click on company a',
                },
                action: [{ click_element: { index: 0 } }],
              },
              id: exampleCallId,
              type: 'tool_call' as const,
            },
          ],
        }),
      );
      this.addMessageWithTokens(new ToolMessage({ content: 'Browser started', tool_call_id: exampleCallId }));

      this.addMessageWithTokens(new HumanMessage({ content: '[Your task history memory starts here]' }));
    }

    /**
     * Returns the current tool call id and advances the counter.
     * @returns The id to use for the next tool call (starts at 1)
     */
    public nextToolId(): number {
      const current = this.toolId;
      this.toolId = current + 1;
      return current;
    }

    /**
     * Creates the task instructions
     * @param task - The raw description of the task
     * @returns A HumanMessage object containing the task instructions
     */
    private static taskInstructions(task: string): HumanMessage {
      const content = `Your ultimate task is: """${task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.`;
      return new HumanMessage({ content });
    }

    /**
     * Returns the number of messages in the history
     * @returns The number of messages in the history
     */
    public length(): number {
      return this.history.messages.length;
    }

    /**
     * Adds a new task to execute, it will be executed based on the history
     * @param newTask - The raw description of the new task
     */
    public addNewTask(newTask: string): void {
      const content = `Your new ultimate task is: """${newTask}""". Take the previous context into account and finish your new ultimate task. `;
      this.addMessageWithTokens(new HumanMessage({ content }));
    }

    /**
     * Adds a plan message (as an AIMessage) to the history
     * @param plan - The raw description of the plan; nothing is added when missing or empty
     * @param position - The position to add the plan; appended when omitted
     */
    public addPlan(plan?: string, position?: number): void {
      if (!plan) return;
      this.addMessageWithTokens(new AIMessage({ content: plan }), position);
    }

    /**
     * Adds a state message to the history
     * @param stateMessage - The HumanMessage object containing the state
     */
    public addStateMessage(stateMessage: HumanMessage): void {
      this.addMessageWithTokens(stateMessage);
    }

    /**
     * Removes the last state message from the history
     */
    public removeLastStateMessage(): void {
      this.history.removeLastHumanMessage();
    }

    /**
     * Returns the stored messages without their metadata and logs per-message
     * and total token estimates at debug level.
     * @returns The messages in history order
     */
    public getMessages(): BaseMessage[] {
      const managed = this.history.messages;
      logger.debug(`Messages in history: ${managed.length}:`);

      let totalInputTokens = 0;
      for (const { message, metadata } of managed) {
        totalInputTokens += metadata.inputTokens;
        logger.debug(`${message.constructor.name} - Token count: ${metadata.inputTokens}`);
      }
      logger.debug(`Total input tokens: ${totalInputTokens}`);

      return managed.map(entry => entry.message);
    }

    /**
     * Returns the stored messages together with their token metadata.
     * @returns The live history array (not a copy)
     */
    public getMessagesWithTokens(): ManagedMessage[] {
      return this.history.messages;
    }

    /**
     * Adds a message to the history with the token count metadata
     * @param message - The BaseMessage object to add; when sensitive data is configured
     *   its content is redacted in place before it is stored
     * @param position - The optional position to add the message, if not provided, the message will be added to the end of the history
     */
    public addMessageWithTokens(message: BaseMessage, position?: number): void {
      const stored = this.sensitiveData ? this._filterSensitiveData(message) : message;
      const metadata: MessageMetadata = { inputTokens: this._countTokens(stored) };
      this.history.addMessage(stored, metadata, position);
    }

    /**
     * Replaces every occurrence of every configured secret value in the message
     * content with `<secret>name</secret>`.
     *
     * All secrets are matched in a single pass, longest first, so a short secret
     * cannot split a longer one and inserted placeholders are never re-scanned.
     * Empty secret values are ignored. The message is modified in place.
     * @param message - The BaseMessage object to filter
     * @returns The same message object, with its content redacted
     */
    private _filterSensitiveData(message: BaseMessage): BaseMessage {
      const nameBySecret = new Map<string, string>();
      for (const [name, secret] of Object.entries(this.sensitiveData ?? {})) {
        // When two placeholders share a value the first one wins.
        if (typeof secret === 'string' && secret.length > 0 && !nameBySecret.has(secret)) {
          nameBySecret.set(secret, name);
        }
      }
      if (nameBySecret.size === 0) return message;

      // Secrets are literal text: escape regex metacharacters before building the alternation.
      const pattern = new RegExp(
        [...nameBySecret.keys()]
          .sort((a, b) => b.length - a.length)
          .map(secret => secret.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'))
          .join('|'),
        'g',
      );
      // A function replacer keeps `$` sequences in placeholder names from being interpreted.
      const redact = (text: string): string =>
        text.replace(pattern, match => `<secret>${nameBySecret.get(match) ?? ''}</secret>`);

      if (typeof message.content === 'string') {
        message.content = redact(message.content);
      } else if (Array.isArray(message.content)) {
        message.content = message.content.map(item => {
          if (typeof item === 'object' && 'text' in item) {
            return { ...item, text: redact(item.text) };
          }
          return item;
        });
      }
      return message;
    }

    /**
     * Estimates the tokens in the message
     *
     * Array content: each image part costs the flat image price and each text
     * part is estimated from its length. String content: estimated from its
     * length, plus the JSON form of any tool calls the message carries.
     * @param message - The BaseMessage object to count the tokens
     * @returns The estimated number of tokens in the message
     */
    private _countTokens(message: BaseMessage): number {
      const { content } = message;

      if (Array.isArray(content)) {
        let tokens = 0;
        for (const item of content) {
          if ('image_url' in item) {
            tokens += this.IMG_TOKENS;
          } else if (typeof item === 'object' && 'text' in item) {
            tokens += this._countTextTokens(item.text);
          }
        }
        return tokens;
      }

      let text = content;
      // Only real tool calls count; an absent or empty list must not inflate the estimate.
      if ('tool_calls' in message && Array.isArray(message.tool_calls) && message.tool_calls.length > 0) {
        text += JSON.stringify(message.tool_calls);
      }
      return this._countTextTokens(text);
    }

    /**
     * Counts the tokens in the text
     * Rough estimate, no tokenizer provided for now
     * @param text - The text to count the tokens
     * @returns The number of tokens in the text
     */
    private _countTextTokens(text: string): number {
      return Math.floor(text.length / this.estimatedCharactersPerToken);
    }

    /**
     * Trims the last message when the history exceeds the max input tokens.
     *
     * First, if the last message has array content, its images are dropped and
     * its text parts are flattened into a single string. If the history is still
     * over budget, the tail of that text is cut proportionally and the message is
     * re-added as a HumanMessage. Only the last message is ever shortened.
     * @throws Error when more than 99% of the last message would have to be removed
     */
    public cutMessages(): void {
      let diff = this.history.totalTokens - this.maxInputTokens;
      if (diff <= 0) return;

      const lastMsg = this.history.messages[this.history.messages.length - 1];
      if (!lastMsg) return;

      // if list with image remove image
      if (Array.isArray(lastMsg.message.content)) {
        let text = '';
        for (const item of lastMsg.message.content) {
          if ('image_url' in item) {
            diff -= this.IMG_TOKENS;
            lastMsg.metadata.inputTokens -= this.IMG_TOKENS;
            this.history.totalTokens -= this.IMG_TOKENS;
            logger.debug(
              `Removed image with ${this.IMG_TOKENS} tokens - total tokens now: ${this.history.totalTokens}/${this.maxInputTokens}`,
            );
          } else if ('text' in item) {
            text += item.text;
          }
        }
        lastMsg.message.content = text;
      }
      if (diff <= 0) return;

      // Still over: cut text from the end of the last message in proportion to the tokens needed.
      const proportionToRemove = diff / lastMsg.metadata.inputTokens;
      if (proportionToRemove > 0.99) {
        throw new Error(
          `Max token limit reached - history is too long - reduce the system prompt or task. proportion_to_remove: ${proportionToRemove}`,
        );
      }
      logger.debug(
        `Removing ${(proportionToRemove * 100).toFixed(2)}% of the last message (${(proportionToRemove * lastMsg.metadata.inputTokens).toFixed(2)} / ${lastMsg.metadata.inputTokens.toFixed(2)} tokens)`,
      );

      const content = lastMsg.message.content as string;
      const charactersToRemove = Math.floor(content.length * proportionToRemove);
      // Explicit end index: `slice(0, -0)` would return '' and wipe the whole message.
      const newContent = content.slice(0, content.length - charactersToRemove);

      this.history.removeMessage(-1);
      this.addMessageWithTokens(new HumanMessage({ content: newContent }));

      const finalMsg = this.history.messages[this.history.messages.length - 1];
      logger.debug(
        `Added message with ${finalMsg.metadata.inputTokens} tokens - total tokens now: ${this.history.totalTokens}/${this.maxInputTokens} - total messages: ${this.history.messages.length}`,
      );
    }

    /**
     * Converts messages for non-function-calling models
     *
     * Human and system messages pass through unchanged, tool messages become
     * human messages prefixed with "Tool Response: ", and AI messages that carry
     * at least one tool call are replaced by an AI message whose content is a
     * text rendering of those calls. AI messages without tool calls pass through.
     * @param inputMessages - The BaseMessage objects to convert
     * @returns The converted BaseMessage objects
     * @throws Error when a message is none of the four known message classes
     */
    public convertMessagesForNonFunctionCallingModels(inputMessages: BaseMessage[]): BaseMessage[] {
      return inputMessages.map(message => {
        if (message instanceof HumanMessage || message instanceof SystemMessage) {
          return message;
        }
        if (message instanceof ToolMessage) {
          return new HumanMessage({
            content: `Tool Response: ${message.content}`,
          });
        }
        if (message instanceof AIMessage) {
          const toolCalls = message.tool_calls;
          // An empty list is truthy; converting it would replace the message's text with ''.
          if (!toolCalls || toolCalls.length === 0) {
            return message;
          }
          const toolCallsStr = toolCalls
            .map(tc => {
              if (
                'function' in tc &&
                typeof tc.function === 'object' &&
                tc.function &&
                'name' in tc.function &&
                'arguments' in tc.function
              ) {
                // For Groq, we need to format function calls differently
                return `Function: ${tc.function.name}\nArguments: ${JSON.stringify(tc.function.arguments)}`;
              }
              return `Tool Call: ${JSON.stringify(tc)}`;
            })
            .join('\n');
          return new AIMessage({ content: toolCallsStr });
        }
        throw new Error(`Unknown message type: ${message.constructor.name}`);
      });
    }

    /**
     * Some models like deepseek-reasoner don't allow multiple human messages in a row. This function merges them into one.
     *
     * For a follow-up message with array content only the text of its first part
     * is merged (nothing when that part is missing or has no text).
     *
     * NOTE(review): the first message of each run is extended in place, so message
     * objects shared with the history are modified by this call.
     * @param messages - The BaseMessage objects to merge
     * @param classToMerge - The class of the messages to merge
     * @returns The merged BaseMessage objects
     */
    public mergeSuccessiveMessages(messages: BaseMessage[], classToMerge: typeof BaseMessage): BaseMessage[] {
      const mergedMessages: BaseMessage[] = [];
      let previousMatched = false;

      for (const message of messages) {
        const matches = message instanceof classToMerge;
        const target = mergedMessages[mergedMessages.length - 1];

        if (matches && previousMatched && target) {
          let extra = '';
          if (Array.isArray(message.content)) {
            const firstContent = message.content[0];
            if (firstContent && 'text' in firstContent) {
              extra = firstContent.text;
            }
          } else {
            extra = message.content;
          }

          if (typeof target.content === 'string') {
            target.content += extra;
          } else if (extra.length > 0) {
            // Array content: append a text part instead of stringifying the array.
            target.content = [...target.content, { type: 'text', text: extra }];
          }
        } else {
          mergedMessages.push(message);
        }
        previousMatched = matches;
      }
      return mergedMessages;
    }
  }