navigator.ts 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import { z } from 'zod';
  2. import { BaseAgent, type BaseAgentOptions, type ExtraAgentOptions } from './base';
  3. import { createLogger } from '@src/background/log';
  4. import { ActionResult, type AgentOutput } from '../types';
  5. import type { Action } from '../actions/builder';
  6. import { buildDynamicActionSchema } from '../actions/builder';
  7. import { agentBrainSchema } from '../types';
  8. import { type BaseMessage, HumanMessage } from '@langchain/core/messages';
  9. import { Actors, ExecutionState } from '../event/types';
  10. import {
  11. ChatModelAuthError,
  12. ChatModelForbiddenError,
  13. isAuthenticationError,
  14. isForbiddenError,
  15. LLM_FORBIDDEN_ERROR_MESSAGE,
  16. } from './errors';
  17. import { jsonNavigatorOutputSchema } from '../actions/json_schema';
  18. import { geminiNavigatorOutputSchema } from '../actions/json_gemini';
  19. import { calcBranchPathHashSet } from '@src/background/dom/views';
  20. const logger = createLogger('NavigatorAgent');
  21. export class NavigatorActionRegistry {
  22. private actions: Record<string, Action> = {};
  23. constructor(actions: Action[]) {
  24. for (const action of actions) {
  25. this.registerAction(action);
  26. }
  27. }
  28. registerAction(action: Action): void {
  29. this.actions[action.name()] = action;
  30. }
  31. unregisterAction(name: string): void {
  32. delete this.actions[name];
  33. }
  34. getAction(name: string): Action | undefined {
  35. return this.actions[name];
  36. }
  37. setupModelOutputSchema(): z.ZodType {
  38. const actionSchema = buildDynamicActionSchema(Object.values(this.actions));
  39. return z.object({
  40. current_state: agentBrainSchema,
  41. action: z.array(actionSchema),
  42. });
  43. }
  44. }
  45. export interface NavigatorResult {
  46. done: boolean;
  47. }
  48. export class NavigatorAgent extends BaseAgent<z.ZodType, NavigatorResult> {
  49. private actionRegistry: NavigatorActionRegistry;
  50. private jsonSchema: Record<string, unknown>;
  51. constructor(
  52. actionRegistry: NavigatorActionRegistry,
  53. options: BaseAgentOptions,
  54. extraOptions?: Partial<ExtraAgentOptions>,
  55. ) {
  56. super(actionRegistry.setupModelOutputSchema(), options, { ...extraOptions, id: 'navigator' });
  57. this.actionRegistry = actionRegistry;
  58. this.jsonSchema = this.modelName.startsWith('gemini') ? geminiNavigatorOutputSchema : jsonNavigatorOutputSchema;
  59. // logger.info('Navigator zod schema', JSON.stringify(zodToJsonSchema(this.modelOutputSchema), null, 2));
  60. }
  61. async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
  62. // Use structured output
  63. if (this.withStructuredOutput) {
  64. const structuredLlm = this.chatLLM.withStructuredOutput(this.jsonSchema, {
  65. includeRaw: true,
  66. name: this.modelOutputToolName,
  67. });
  68. let response = undefined;
  69. try {
  70. response = await structuredLlm.invoke(inputMessages, {
  71. ...this.callOptions,
  72. });
  73. if (response.parsed) {
  74. return response.parsed;
  75. }
  76. } catch (error) {
  77. const errorMessage = `Failed to invoke ${this.modelName} with structured output: ${error}`;
  78. throw new Error(errorMessage);
  79. }
  80. // Use type assertion to access the properties
  81. const rawResponse = response.raw as BaseMessage & {
  82. tool_calls?: Array<{
  83. args: {
  84. currentState: typeof agentBrainSchema._type;
  85. action: z.infer<ReturnType<typeof buildDynamicActionSchema>>;
  86. };
  87. }>;
  88. };
  89. // sometimes LLM returns an empty content, but with one or more tool calls, so we need to check the tool calls
  90. if (rawResponse.tool_calls && rawResponse.tool_calls.length > 0) {
  91. logger.info('Navigator structuredLlm tool call with empty content', rawResponse.tool_calls);
  92. // only use the first tool call
  93. const toolCall = rawResponse.tool_calls[0];
  94. return {
  95. current_state: toolCall.args.currentState,
  96. action: [...toolCall.args.action],
  97. };
  98. }
  99. throw new Error('Could not parse response');
  100. }
  101. throw new Error('Navigator needs to work with LLM that supports tool calling');
  102. }
  103. async execute(): Promise<AgentOutput<NavigatorResult>> {
  104. const agentOutput: AgentOutput<NavigatorResult> = {
  105. id: this.id,
  106. };
  107. let cancelled = false;
  108. try {
  109. this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_START, 'Navigating...');
  110. const messageManager = this.context.messageManager;
  111. // add the browser state message
  112. await this.addStateMessageToMemory();
  113. // check if the task is paused or stopped
  114. if (this.context.paused || this.context.stopped) {
  115. cancelled = true;
  116. return agentOutput;
  117. }
  118. // call the model to get the actions to take
  119. const inputMessages = messageManager.getMessages();
  120. const modelOutput = await this.invoke(inputMessages);
  121. // check if the task is paused or stopped
  122. if (this.context.paused || this.context.stopped) {
  123. cancelled = true;
  124. return agentOutput;
  125. }
  126. // remove the last state message from memory before adding the model output
  127. this.removeLastStateMessageFromMemory();
  128. this.addModelOutputToMemory(modelOutput);
  129. // take the actions
  130. const actionResults = await this.doMultiAction(modelOutput);
  131. this.context.actionResults = actionResults;
  132. // check if the task is paused or stopped
  133. if (this.context.paused || this.context.stopped) {
  134. cancelled = true;
  135. return agentOutput;
  136. }
  137. // emit event
  138. this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_OK, 'Navigation done');
  139. let done = false;
  140. if (actionResults.length > 0 && actionResults[actionResults.length - 1].isDone) {
  141. done = true;
  142. }
  143. agentOutput.result = { done };
  144. return agentOutput;
  145. } catch (error) {
  146. this.removeLastStateMessageFromMemory();
  147. // Check if this is an authentication error
  148. if (isAuthenticationError(error)) {
  149. throw new ChatModelAuthError('Navigator API Authentication failed. Please verify your API key', error);
  150. }
  151. if (isForbiddenError(error)) {
  152. throw new ChatModelForbiddenError(LLM_FORBIDDEN_ERROR_MESSAGE, error);
  153. }
  154. const errorMessage = error instanceof Error ? error.message : String(error);
  155. const errorString = `Navigation failed: ${errorMessage}`;
  156. logger.error(errorString);
  157. this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_FAIL, errorString);
  158. agentOutput.error = errorMessage;
  159. return agentOutput;
  160. } finally {
  161. // if the task is cancelled, remove the last state message from memory and emit event
  162. if (cancelled) {
  163. this.removeLastStateMessageFromMemory();
  164. this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_CANCEL, 'Navigation cancelled');
  165. }
  166. }
  167. }
  168. /**
  169. * Add the state message to the memory
  170. */
  171. public async addStateMessageToMemory() {
  172. if (this.context.stateMessageAdded) {
  173. return;
  174. }
  175. const messageManager = this.context.messageManager;
  176. // Handle results that should be included in memory
  177. if (this.context.actionResults.length > 0) {
  178. let index = 0;
  179. for (const r of this.context.actionResults) {
  180. if (r.includeInMemory) {
  181. if (r.extractedContent) {
  182. const msg = new HumanMessage(`Action result: ${r.extractedContent}`);
  183. // logger.info('Adding action result to memory', msg.content);
  184. messageManager.addMessageWithTokens(msg);
  185. }
  186. if (r.error) {
  187. // Get error text and convert to string
  188. const errorText = r.error.toString().trim();
  189. // Get only the last line of the error
  190. const lastLine = errorText.split('\n').pop() || '';
  191. const msg = new HumanMessage(`Action error: ${lastLine}`);
  192. logger.info('Adding action error to memory', msg.content);
  193. messageManager.addMessageWithTokens(msg);
  194. }
  195. // reset this action result to empty, we dont want to add it again in the state message
  196. // NOTE: in python version, all action results are reset to empty, but in ts version, only those included in memory are reset to empty
  197. this.context.actionResults[index] = new ActionResult();
  198. }
  199. index++;
  200. }
  201. }
  202. const state = await this.prompt.getUserMessage(this.context);
  203. messageManager.addStateMessage(state);
  204. this.context.stateMessageAdded = true;
  205. }
  206. /**
  207. * Remove the last state message from the memory
  208. */
  209. protected async removeLastStateMessageFromMemory() {
  210. if (!this.context.stateMessageAdded) return;
  211. const messageManager = this.context.messageManager;
  212. messageManager.removeLastStateMessage();
  213. this.context.stateMessageAdded = false;
  214. }
  215. private async addModelOutputToMemory(modelOutput: this['ModelOutput']) {
  216. const messageManager = this.context.messageManager;
  217. messageManager.addModelOutput(modelOutput);
  218. }
  219. private async doMultiAction(response: this['ModelOutput']): Promise<ActionResult[]> {
  220. const results: ActionResult[] = [];
  221. let errCount = 0;
  222. logger.info('Actions', response.action);
  223. // sometimes response.action is a string, but not an array as expected, so we need to parse it as an array
  224. let actions: Record<string, unknown>[] = [];
  225. if (Array.isArray(response.action)) {
  226. // if the item is null, skip it
  227. actions = response.action.filter((item: unknown) => item !== null);
  228. if (actions.length === 0) {
  229. logger.warning('No valid actions found', response.action);
  230. }
  231. } else if (typeof response.action === 'string') {
  232. try {
  233. logger.warning('Unexpected action format', response.action);
  234. // try to parse the action as an JSON object
  235. actions = JSON.parse(response.action);
  236. } catch (error) {
  237. logger.error('Invalid action format', response.action);
  238. throw new Error('Invalid action output format');
  239. }
  240. } else {
  241. // if the action is neither an array nor a string, it should be an object
  242. actions = [response.action];
  243. }
  244. const browserContext = this.context.browserContext;
  245. const browserState = await browserContext.getState();
  246. const cachedPathHashes = await calcBranchPathHashSet(browserState);
  247. await browserContext.removeHighlight();
  248. for (const [i, action] of actions.entries()) {
  249. const actionName = Object.keys(action)[0];
  250. const actionArgs = action[actionName];
  251. try {
  252. // check if the task is paused or stopped
  253. if (this.context.paused || this.context.stopped) {
  254. return results;
  255. }
  256. const actionInstance = this.actionRegistry.getAction(actionName);
  257. if (actionInstance === undefined) {
  258. throw new Error(`Action ${actionName} not exists`);
  259. }
  260. const indexArg = actionInstance.getIndexArg(actionArgs);
  261. if (i > 0 && indexArg !== null) {
  262. const newState = await browserContext.getState();
  263. const newPathHashes = await calcBranchPathHashSet(newState);
  264. // next action requires index but there are new elements on the page
  265. if (!newPathHashes.isSubsetOf(cachedPathHashes)) {
  266. const msg = `Something new appeared after action ${i} / ${actions.length}`;
  267. logger.info(msg);
  268. results.push(
  269. new ActionResult({
  270. extractedContent: msg,
  271. includeInMemory: true,
  272. }),
  273. );
  274. break;
  275. }
  276. }
  277. const result = await actionInstance.call(actionArgs);
  278. if (result === undefined) {
  279. throw new Error(`Action ${actionName} returned undefined`);
  280. }
  281. results.push(result);
  282. // check if the task is paused or stopped
  283. if (this.context.paused || this.context.stopped) {
  284. return results;
  285. }
  286. // TODO: wait for 1 second for now, need to optimize this to avoid unnecessary waiting
  287. await new Promise(resolve => setTimeout(resolve, 1000));
  288. } catch (error) {
  289. const errorMessage = error instanceof Error ? error.message : String(error);
  290. logger.error('doAction error', actionName, actionArgs, errorMessage);
  291. // unexpected error, emit event
  292. this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.ACT_FAIL, errorMessage);
  293. errCount++;
  294. if (errCount > 3) {
  295. throw new Error('Too many errors in actions');
  296. }
  297. results.push(
  298. new ActionResult({
  299. error: errorMessage,
  300. isDone: false,
  301. includeInMemory: true,
  302. }),
  303. );
  304. }
  305. }
  306. return results;
  307. }
  308. }