navigator.ts 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import { z } from 'zod';
  2. import { BaseAgent, type BaseAgentOptions, type ExtraAgentOptions } from './base';
  3. import { createLogger } from '@src/background/log';
  4. import { ActionResult, type AgentOutput } from '../types';
  5. import type { Action } from '../actions/builder';
  6. import { buildDynamicActionSchema } from '../actions/builder';
  7. import { agentBrainSchema } from '../types';
  8. import { type BaseMessage, HumanMessage } from '@langchain/core/messages';
  9. import { Actors, ExecutionState } from '../event/types';
  10. import { isAuthenticationError } from '@src/background/utils';
  11. import { ChatModelAuthError } from './errors';
  12. import { jsonNavigatorOutputSchema } from '../actions/json_schema';
  13. import { geminiNavigatorOutputSchema } from '../actions/json_gemini';
  14. const logger = createLogger('NavigatorAgent');
  /**
   * Registry of the actions the navigator can execute, keyed by `action.name()`.
   *
   * The registry is backed by a `Map` instead of a plain object so that a
   * lookup with a name that happens to match an `Object.prototype` member
   * (`toString`, `constructor`, `__proto__`, ...) yields `undefined` rather than
   * an inherited value. Names passed to `getAction` are not guaranteed to be
   * registered names, so an unknown name must always resolve to `undefined`.
   */
  export class NavigatorActionRegistry {
    /** Registered actions in insertion order; re-registering a name replaces the entry. */
    private readonly actions = new Map<string, Action>();

    /**
     * @param actions Actions to register up front; later entries with the same
     *   name overwrite earlier ones.
     */
    constructor(actions: Action[]) {
      for (const action of actions) {
        this.registerAction(action);
      }
    }

    /** Adds `action` under its own name, replacing any action already using that name. */
    registerAction(action: Action): void {
      this.actions.set(action.name(), action);
    }

    /** Removes the action registered under `name`; a no-op when the name is unknown. */
    unregisterAction(name: string): void {
      this.actions.delete(name);
    }

    /**
     * Looks up an action by name.
     *
     * @returns The registered action, or `undefined` when nothing was registered
     *   under `name` (including names of built-in object members).
     */
    getAction(name: string): Action | undefined {
      return this.actions.get(name);
    }

    /**
     * Builds the zod schema for the model output: the agent brain state plus an
     * array of actions drawn from the currently registered actions.
     */
    setupModelOutputSchema(): z.ZodType {
      const actionSchema = buildDynamicActionSchema(Array.from(this.actions.values()));
      return z.object({
        current_state: agentBrainSchema,
        action: z.array(actionSchema),
      });
    }
  }
  /**
   * Result payload attached to the navigator's agent output after a step.
   */
  export interface NavigatorResult {
    /** Whether the step finished the task. */
    done: boolean;
  }
  /**
   * Agent that asks the chat model which browser actions to take next and then
   * executes them through a {@link NavigatorActionRegistry}.
   *
   * One call to `execute()` is one navigation step: push the current browser
   * state into memory, invoke the model, replace the state message with the
   * model output, run the returned actions and report whether the last action
   * marked the task as done.
   */
  export class NavigatorAgent extends BaseAgent<z.ZodType, NavigatorResult> {
    private actionRegistry: NavigatorActionRegistry;
    private jsonSchema: Record<string, unknown>;

    /**
     * @param actionRegistry Registry used both to build the model output schema
     *   and to resolve action names at execution time.
     * @param options Base agent options, forwarded to `BaseAgent`.
     * @param extraOptions Optional extra options; `id` is always forced to `'navigator'`.
     */
    constructor(
      actionRegistry: NavigatorActionRegistry,
      options: BaseAgentOptions,
      extraOptions?: Partial<ExtraAgentOptions>,
    ) {
      super(actionRegistry.setupModelOutputSchema(), options, { ...extraOptions, id: 'navigator' });
      this.actionRegistry = actionRegistry;
      // Models whose name starts with 'gemini' get the Gemini-specific JSON schema.
      this.jsonSchema = this.modelName.startsWith('gemini') ? geminiNavigatorOutputSchema : jsonNavigatorOutputSchema;
      // logger.info('Navigator zod schema', JSON.stringify(zodToJsonSchema(this.modelOutputSchema), null, 2));
    }

    /**
     * Invokes the chat model with structured output and returns the model output.
     *
     * When the structured call yields no parsed value, the first tool call on the
     * raw message is used as a fallback.
     *
     * @param inputMessages Conversation to send to the model.
     * @returns The parsed output, or an object rebuilt from the first tool call.
     * @throws Error when structured output is not enabled, when the model call
     *   fails, or when neither a parsed value nor a usable tool call is present.
     */
    async invoke(inputMessages: BaseMessage[]): Promise<this['ModelOutput']> {
      if (!this.withStructuredOutput) {
        throw new Error('Navigator needs to work with LLM that supports tool calling');
      }

      // Use structured output
      const structuredLlm = this.chatLLM.withStructuredOutput(this.jsonSchema, {
        includeRaw: true,
        name: this.modelOutputToolName,
      });

      let response = undefined;
      try {
        response = await structuredLlm.invoke(inputMessages, {
          ...this.callOptions,
        });
        if (response.parsed) {
          return response.parsed;
        }
      } catch (error) {
        const errorMessage = `Failed to invoke ${this.modelName} with structured output: ${error}`;
        throw new Error(errorMessage);
      }

      // Use type assertion to access the properties. Both key spellings are
      // declared because the output schema uses `current_state` while this
      // fallback historically read `currentState`.
      const rawResponse = response.raw as BaseMessage & {
        tool_calls?: Array<{
          args?: {
            current_state?: typeof agentBrainSchema._type;
            currentState?: typeof agentBrainSchema._type;
            action?: unknown;
          };
        }>;
      };

      // sometimes LLM returns an empty content, but with one or more tool calls, so we need to check the tool calls
      if (rawResponse.tool_calls && rawResponse.tool_calls.length > 0) {
        logger.info('Navigator structuredLlm tool call with empty content', rawResponse.tool_calls);
        // only use the first tool call
        const args = rawResponse.tool_calls[0]?.args;
        if (args) {
          return {
            current_state: args.current_state ?? args.currentState,
            // Copy only real arrays: spreading a string would split it into
            // characters and spreading undefined would throw. Other shapes are
            // passed through for doMultiAction to normalise.
            action: Array.isArray(args.action) ? [...args.action] : args.action,
          };
        }
      }
      throw new Error('Could not parse response');
    }

    /**
     * Runs one navigation step.
     *
     * Emits `STEP_START`, then `STEP_OK` on success, `STEP_FAIL` on a
     * non-authentication error, or `STEP_CANCEL` when the context is paused or
     * stopped at one of the checkpoints.
     *
     * @returns The agent output. `result.done` is set on success, `error` is set
     *   on a handled failure, and neither is set when the step was cancelled.
     * @throws ChatModelAuthError when the failure is recognised as an
     *   authentication error.
     */
    async execute(): Promise<AgentOutput<NavigatorResult>> {
      const agentOutput: AgentOutput<NavigatorResult> = {
        id: this.id,
      };

      let cancelled = false;

      try {
        this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_START, 'Navigating...');

        const messageManager = this.context.messageManager;
        // add the browser state message
        await this.addStateMessageToMemory();
        // check if the task is paused or stopped
        if (this.context.paused || this.context.stopped) {
          cancelled = true;
          return agentOutput;
        }

        // call the model to get the actions to take
        const inputMessages = messageManager.getMessages();
        const modelOutput = await this.invoke(inputMessages);
        // check if the task is paused or stopped
        if (this.context.paused || this.context.stopped) {
          cancelled = true;
          return agentOutput;
        }

        // remove the last state message from memory before adding the model output;
        // awaited so the removal is complete (and any failure is caught below)
        // before the model output is appended
        await this.removeLastStateMessageFromMemory();
        this.addModelOutputToMemory(modelOutput);

        // take the actions
        const actionResults = await this.doMultiAction(modelOutput);
        this.context.actionResults = actionResults;
        // check if the task is paused or stopped
        if (this.context.paused || this.context.stopped) {
          cancelled = true;
          return agentOutput;
        }

        // emit event
        this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_OK, 'Navigation done');

        // the step is done only when the last action result says so
        const lastResult = actionResults[actionResults.length - 1];
        const done = Boolean(lastResult?.isDone);
        agentOutput.result = { done };
        return agentOutput;
      } catch (error) {
        await this.removeLastStateMessageFromMemory();
        // Check if this is an authentication error
        if (isAuthenticationError(error)) {
          throw new ChatModelAuthError('Navigator API Authentication failed. Please verify your API key', error);
        }

        const errorMessage = error instanceof Error ? error.message : String(error);
        const errorString = `Navigation failed: ${errorMessage}`;
        logger.error(errorString);
        this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_FAIL, errorString);
        agentOutput.error = errorMessage;
        return agentOutput;
      } finally {
        // if the task is cancelled, remove the last state message from memory and emit event
        if (cancelled) {
          await this.removeLastStateMessageFromMemory();
          this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.STEP_CANCEL, 'Navigation cancelled');
        }
      }
    }

    /**
     * Add the state message to the memory.
     *
     * Does nothing when a state message is already present. Before the state
     * message, every previous action result flagged `includeInMemory` is written
     * to memory (extracted content and/or error) and then replaced by an empty
     * `ActionResult` so it is not added again.
     */
    public async addStateMessageToMemory() {
      if (this.context.stateMessageAdded) {
        return;
      }

      const messageManager = this.context.messageManager;
      const options = this.context.options;

      // Handle results that should be included in memory
      for (const [index, r] of this.context.actionResults.entries()) {
        if (!r.includeInMemory) {
          continue;
        }
        if (r.extractedContent) {
          const msg = new HumanMessage(`Action result: ${r.extractedContent}`);
          // logger.info('Adding action result to memory', msg.content);
          messageManager.addMessageWithTokens(msg);
        }
        if (r.error) {
          // keep only the tail of the error text, bounded by options.maxErrorLength
          const msg = new HumanMessage(`Action error: ${r.error.toString().slice(-options.maxErrorLength)}`);
          logger.info('Adding action error to memory', msg.content);
          messageManager.addMessageWithTokens(msg);
        }
        // reset this action result to empty, we dont want to add it again in the state message
        this.context.actionResults[index] = new ActionResult();
      }

      const state = await this.prompt.getUserMessage(this.context);
      messageManager.addStateMessage(state);
      this.context.stateMessageAdded = true;
    }

    /**
     * Remove the last state message from the memory.
     *
     * Does nothing when no state message is currently recorded as added.
     */
    protected async removeLastStateMessageFromMemory() {
      if (!this.context.stateMessageAdded) return;
      const messageManager = this.context.messageManager;
      messageManager.removeLastStateMessage();
      this.context.stateMessageAdded = false;
    }

    /**
     * Converts the `action` field of a model output into a list of action objects.
     *
     * Accepted shapes: an array (null/undefined entries are dropped), a JSON
     * string (parsed, then treated like any other shape), a single object
     * (wrapped in an array), or null/undefined (no actions).
     *
     * @throws Error with message 'Invalid action output format' when a string
     *   value is not valid JSON.
     */
    private normalizeActions(rawAction: unknown): Record<string, unknown>[] {
      let candidate: unknown = rawAction;

      // sometimes the action is a string, but not an array as expected, so we need to parse it
      if (typeof candidate === 'string') {
        try {
          logger.warning('Unexpected action format', rawAction);
          // try to parse the action as an JSON object
          candidate = JSON.parse(candidate);
        } catch (error) {
          logger.error('Invalid action format', rawAction);
          throw new Error('Invalid action output format');
        }
      }

      if (candidate === null || candidate === undefined) {
        logger.warning('No valid actions found', rawAction);
        return [];
      }

      if (!Array.isArray(candidate)) {
        // if the action is neither an array nor a string, it should be an object
        return [candidate as Record<string, unknown>];
      }

      // if the item is null or undefined, skip it
      const actions = candidate.filter((item: unknown) => item !== null && item !== undefined);
      if (actions.length === 0) {
        logger.warning('No valid actions found', rawAction);
      }
      return actions;
    }

    /**
     * Executes the actions of a model output one after another.
     *
     * Each action object is expected to have the action name as its first key
     * and the arguments as that key's value. A failing action is recorded as an
     * error `ActionResult` and execution continues; after more than three
     * failures the whole batch is aborted. Execution stops early, returning the
     * results collected so far, when the context is paused or stopped.
     *
     * @param response Model output whose `action` field lists the actions.
     * @returns One `ActionResult` per executed (or failed) action, in order.
     * @throws Error when the action field cannot be parsed or too many actions fail.
     */
    private async doMultiAction(response: this['ModelOutput']): Promise<ActionResult[]> {
      const results: ActionResult[] = [];
      let errCount = 0;

      logger.info('Actions', response.action);
      const actions = this.normalizeActions(response.action);

      for (const action of actions) {
        const actionName = Object.keys(action)[0];
        const actionArgs = action[actionName];
        try {
          // check if the task is paused or stopped
          if (this.context.paused || this.context.stopped) {
            return results;
          }

          const result = await this.actionRegistry.getAction(actionName)?.call(actionArgs);
          if (result === undefined) {
            throw new Error(`Action ${actionName} not exists or returned undefined`);
          }
          results.push(result);
          // check if the task is paused or stopped
          if (this.context.paused || this.context.stopped) {
            return results;
          }
          // TODO: wait for 1 second for now, need to optimize this to avoid unnecessary waiting
          await new Promise(resolve => setTimeout(resolve, 1000));
        } catch (error) {
          const errorMessage = error instanceof Error ? error.message : String(error);
          logger.error('doAction error', actionName, actionArgs, errorMessage);
          // unexpected error, emit event
          this.context.emitEvent(Actors.NAVIGATOR, ExecutionState.ACT_FAIL, errorMessage);
          errCount++;
          if (errCount > 3) {
            throw new Error('Too many errors in actions');
          }
          // keep the failure in memory so the model sees it on the next step
          results.push(
            new ActionResult({
              error: errorMessage,
              isDone: false,
              includeInMemory: true,
            }),
          );
        }
      }
      return results;
    }
  }