Browse Source

merge Grok model support, refine helper methods for model settings

alexchenzl 5 months ago
parent
commit
dc93cb7717

+ 1 - 0
chrome-extension/package.json

@@ -23,6 +23,7 @@
     "@langchain/google-genai": "0.1.10",
     "@langchain/ollama": "^0.2.0",
     "@langchain/openai": "^0.4.2",
+    "@langchain/xai": "^0.0.2",
     "puppeteer-core": "24.1.1",
     "webextension-polyfill": "^0.12.0",
     "zod": "^3.24.1"

+ 2 - 0
chrome-extension/src/background/agent/agents/base.ts

@@ -84,6 +84,8 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
           return null;
         case 'ChatOpenAI':
         case 'AzureChatOpenAI':
+        case 'ChatGroq':
+        case 'ChatXAI':
           return 'function_calling';
         default:
           return null;

+ 15 - 5
chrome-extension/src/background/agent/helper.ts

@@ -2,6 +2,7 @@ import { type ProviderConfig, type ModelConfig, ProviderTypeEnum } from '@extens
 import { ChatOpenAI } from '@langchain/openai';
 import { ChatAnthropic } from '@langchain/anthropic';
 import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
+import { ChatXAI } from '@langchain/xai';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatOllama } from '@langchain/ollama';
 
@@ -53,6 +54,9 @@ export function createChatModel(providerConfig: ProviderConfig, modelConfig: Mod
   const temperature = (modelConfig.parameters?.temperature ?? 0.1) as number;
   const topP = (modelConfig.parameters?.topP ?? 0.1) as number;
 
+  console.log('providerConfig', providerConfig);
+  console.log('modelConfig', modelConfig);
+
   switch (modelConfig.provider) {
     case ProviderTypeEnum.OpenAI: {
       return createOpenAIChatModel(providerConfig, modelConfig);
@@ -66,11 +70,6 @@ export function createChatModel(providerConfig: ProviderConfig, modelConfig: Mod
         topP,
         clientOptions: {},
       };
-      if (providerConfig.baseUrl) {
-        args.clientOptions = {
-          baseURL: providerConfig.baseUrl,
-        };
-      }
       return new ChatAnthropic(args);
     }
     case ProviderTypeEnum.Gemini: {
@@ -82,6 +81,17 @@ export function createChatModel(providerConfig: ProviderConfig, modelConfig: Mod
       };
       return new ChatGoogleGenerativeAI(args);
     }
+    case ProviderTypeEnum.Grok: {
+      const args = {
+        model: modelConfig.modelName,
+        apiKey: providerConfig.apiKey,
+        temperature,
+        topP,
+        maxTokens,
+        configuration: {},
+      };
+      return new ChatXAI(args) as BaseChatModel;
+    }
     case ProviderTypeEnum.Ollama: {
       const args: {
         model: string;

+ 20 - 4
chrome-extension/src/background/agent/messages/service.ts

@@ -337,13 +337,29 @@ export default class MessageManager {
         return message;
       }
       if (message instanceof ToolMessage) {
-        return new HumanMessage({ content: message.content });
+        return new HumanMessage({
+          content: `Tool Response: ${message.content}`,
+        });
       }
       if (message instanceof AIMessage) {
         // if it's an AIMessage with tool_calls, convert it to a normal AIMessage
-        if ('tool_calls' in message) {
-          const toolCalls = JSON.stringify(message.tool_calls);
-          return new AIMessage({ content: toolCalls });
+        if ('tool_calls' in message && message.tool_calls) {
+          const toolCallsStr = message.tool_calls
+            .map(tc => {
+              if (
+                'function' in tc &&
+                typeof tc.function === 'object' &&
+                tc.function &&
+                'name' in tc.function &&
+                'arguments' in tc.function
+              ) {
+                // For Grok, we need to format function calls differently
+                return `Function: ${tc.function.name}\nArguments: ${JSON.stringify(tc.function.arguments)}`;
+              }
+              return `Tool Call: ${JSON.stringify(tc)}`;
+            })
+            .join('\n');
+          return new AIMessage({ content: toolCallsStr });
         }
         return message;
       }

+ 20 - 18
packages/storage/lib/settings/llmProviders.ts

@@ -1,7 +1,7 @@
 import { StorageEnum } from '../base/enums';
 import { createStorage } from '../base/base';
 import type { BaseStorage } from '../base/types';
-import { llmProviderModelNames, ProviderTypeEnum } from './types';
+import { type AgentNameEnum, llmProviderModelNames, llmProviderParameters, ProviderTypeEnum } from './types';
 
 // Interface for a single provider configuration
 export interface ProviderConfig {
@@ -39,11 +39,13 @@ const storage = createStorage<LLMKeyRecord>(
 );
 
 // Helper function to determine provider type from provider name
+// Make sure to update this function if you add a new provider type
 export function getProviderTypeByProviderId(providerId: string): ProviderTypeEnum {
   switch (providerId) {
     case ProviderTypeEnum.OpenAI:
     case ProviderTypeEnum.Anthropic:
     case ProviderTypeEnum.Gemini:
+    case ProviderTypeEnum.Grok:
     case ProviderTypeEnum.Ollama:
       return providerId;
     default:
@@ -52,6 +54,7 @@ export function getProviderTypeByProviderId(providerId: string): ProviderTypeEnu
 }
 
 // Helper function to get display name from provider id
+// Make sure to update this function if you add a new provider type
 export function getDefaultDisplayNameFromProviderId(providerId: string): string {
   switch (providerId) {
     case ProviderTypeEnum.OpenAI:
@@ -60,6 +63,8 @@ export function getDefaultDisplayNameFromProviderId(providerId: string): string
       return 'Anthropic';
     case ProviderTypeEnum.Gemini:
       return 'Gemini';
+    case ProviderTypeEnum.Grok:
+      return 'Grok';
     case ProviderTypeEnum.Ollama:
       return 'Ollama';
     default:
@@ -68,32 +73,21 @@ export function getDefaultDisplayNameFromProviderId(providerId: string): string
 }
 
 // Get default configuration for built-in providers
+// Make sure to update this function if you add a new provider type
 export function getDefaultProviderConfig(providerId: string): ProviderConfig {
   switch (providerId) {
     case ProviderTypeEnum.OpenAI:
-      return {
-        apiKey: '',
-        name: getDefaultDisplayNameFromProviderId(ProviderTypeEnum.OpenAI),
-        type: ProviderTypeEnum.OpenAI,
-        modelNames: [...(llmProviderModelNames[ProviderTypeEnum.OpenAI] || [])],
-        createdAt: Date.now(),
-      };
     case ProviderTypeEnum.Anthropic:
-      return {
-        apiKey: '',
-        name: getDefaultDisplayNameFromProviderId(ProviderTypeEnum.Anthropic),
-        type: ProviderTypeEnum.Anthropic,
-        modelNames: [...(llmProviderModelNames[ProviderTypeEnum.Anthropic] || [])],
-        createdAt: Date.now(),
-      };
     case ProviderTypeEnum.Gemini:
+    case ProviderTypeEnum.Grok:
       return {
         apiKey: '',
-        name: getDefaultDisplayNameFromProviderId(ProviderTypeEnum.Gemini),
-        type: ProviderTypeEnum.Gemini,
-        modelNames: [...(llmProviderModelNames[ProviderTypeEnum.Gemini] || [])],
+        name: getDefaultDisplayNameFromProviderId(providerId),
+        type: providerId,
+        modelNames: [...(llmProviderModelNames[providerId] || [])],
         createdAt: Date.now(),
       };
+
     case ProviderTypeEnum.Ollama:
       return {
         apiKey: 'ollama', // Set default API key for Ollama
@@ -115,6 +109,14 @@ export function getDefaultProviderConfig(providerId: string): ProviderConfig {
   }
 }
 
+export function getDefaultAgentModelParams(providerId: string, agentName: AgentNameEnum): Record<string, number> {
+  const newParameters = llmProviderParameters[providerId as keyof typeof llmProviderParameters]?.[agentName] || {
+    temperature: 0.1,
+    topP: 0.1,
+  };
+  return newParameters;
+}
+
 // Helper function to ensure backward compatibility for provider configs
 function ensureBackwardCompatibility(providerId: string, config: ProviderConfig): ProviderConfig {
   const updatedConfig = { ...config };

+ 17 - 1
packages/storage/lib/settings/types.ts

@@ -12,13 +12,14 @@ export enum ProviderTypeEnum {
   OpenAI = 'openai',
   Anthropic = 'anthropic',
   Gemini = 'gemini',
+  Grok = 'grok',
   Ollama = 'ollama',
   CustomOpenAI = 'custom_openai',
 }
 
 // Default supported models for each built-in provider
 export const llmProviderModelNames = {
-  [ProviderTypeEnum.OpenAI]: ['gpt-4o', 'gpt-4o-mini', 'o1', 'o1-mini', 'o3-mini'],
+  [ProviderTypeEnum.OpenAI]: ['gpt-4o', 'gpt-4o-mini', 'o1', 'o3-mini'],
   [ProviderTypeEnum.Anthropic]: ['claude-3-7-sonnet-latest', 'claude-3-5-sonnet-latest', 'claude-3-5-haiku-latest'],
   [ProviderTypeEnum.Gemini]: [
     'gemini-2.0-flash',
@@ -26,6 +27,7 @@ export const llmProviderModelNames = {
     'gemini-2.0-pro-exp-02-05',
     // 'gemini-2.0-flash-thinking-exp-01-21', // TODO: not support function calling for now
   ],
+  [ProviderTypeEnum.Grok]: ['grok-2', 'grok-2-vision'],
   [ProviderTypeEnum.Ollama]: [],
   // Custom OpenAI providers don't have predefined models as they are user-defined
 };
@@ -74,6 +76,20 @@ export const llmProviderParameters = {
       topP: 0.1,
     },
   },
+  [ProviderTypeEnum.Grok]: {
+    [AgentNameEnum.Planner]: {
+      temperature: 0.7,
+      topP: 0.9,
+    },
+    [AgentNameEnum.Navigator]: {
+      temperature: 0.7,
+      topP: 0.9,
+    },
+    [AgentNameEnum.Validator]: {
+      temperature: 0.7,
+      topP: 0.9,
+    },
+  },
   [ProviderTypeEnum.Ollama]: {
     [AgentNameEnum.Planner]: {
       temperature: 0,

+ 2 - 5
pages/options/src/components/ModelSettings.tsx

@@ -7,9 +7,9 @@ import {
   AgentNameEnum,
   llmProviderModelNames,
   ProviderTypeEnum,
-  llmProviderParameters,
   getDefaultDisplayNameFromProviderId,
   getDefaultProviderConfig,
+  getDefaultAgentModelParams,
 } from '@extension/storage';
 
 interface ModelSettingsProps {
@@ -443,10 +443,7 @@ export const ModelSettings = ({ isDarkMode = false }: ModelSettingsProps) => {
     const [provider, model] = modelValue.split('>');
 
     // Set parameters based on provider type
-    const newParameters = llmProviderParameters[provider as keyof typeof llmProviderParameters]?.[agentName] || {
-      temperature: 0.1,
-      topP: 0.1,
-    };
+    const newParameters = getDefaultAgentModelParams(provider, agentName);
 
     setModelParameters(prev => ({
       ...prev,

+ 1 - 1
pages/side-panel/src/components/ChatInput.tsx

@@ -82,7 +82,7 @@ export default function ChatInput({
           onChange={handleTextChange}
           onKeyDown={handleKeyDown}
           disabled={disabled}
-          rows={4}
+          rows={5}
           className={`w-full resize-none border-none p-2 focus:outline-none ${
             disabled
               ? isDarkMode

+ 1 - 1
pages/side-panel/src/components/TemplateList.tsx

@@ -14,7 +14,7 @@ interface TemplateListProps {
 const TemplateList: React.FC<TemplateListProps> = ({ templates, onTemplateSelect, isDarkMode = false }) => {
   return (
     <div className="p-2">
-      <h3 className={`mb-3 text-sm font-medium ${isDarkMode ? 'text-gray-200' : 'text-gray-700'}`}>Templates</h3>
+      <h3 className={`mb-3 text-sm font-medium ${isDarkMode ? 'text-gray-200' : 'text-gray-700'}`}>Quick Start</h3>
       <div className="grid grid-cols-1 gap-3 sm:grid-cols-2">
         {templates.map(template => (
           <button

+ 59 - 0
pnpm-lock.yaml

@@ -126,6 +126,9 @@ importers:
       '@langchain/openai':
         specifier: ^0.4.2
         version: 0.4.2(@langchain/core@0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1)))(ws@8.18.0)
+      '@langchain/xai':
+        specifier: ^0.0.2
+        version: 0.0.2(@langchain/core@0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1)))(ws@8.18.0)
       puppeteer-core:
         specifier: 24.1.1
         version: 24.1.1
@@ -765,6 +768,18 @@ packages:
     peerDependencies:
       '@langchain/core': '>=0.3.29 <0.4.0'
 
+  '@langchain/openai@0.4.4':
+    resolution: {integrity: sha512-UZybJeMd8+UX7Kn47kuFYfqKdBCeBUWNqDtmAr6ZUIMMnlsNIb6MkrEEhGgAEjGCpdT4CU8U/DyyddTz+JayOQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      '@langchain/core': '>=0.3.39 <0.4.0'
+
+  '@langchain/xai@0.0.2':
+    resolution: {integrity: sha512-wVOs7SfJs4VWk/oiHJomaoaZ+r9nQhPqbEXlQ2D8L0d54PxYhb1ILR9rub9LT1RpqazSX8HG4A8+hX4R01qkSg==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      '@langchain/core': '>=0.2.21 <0.4.0'
+
   '@laynezh/vite-plugin-lib-assets@0.6.1':
     resolution: {integrity: sha512-pdIRW/PiJkuM7/OObjGBGfQmsWetmVObeez6uwT3nhP5cu2zT0L5QELq69caWD/v3QlPY3CPXVN0kZrzQzdvsQ==}
     peerDependencies:
@@ -3400,6 +3415,9 @@ packages:
   zod@3.24.1:
     resolution: {integrity: sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==}
 
+  zod@3.24.2:
+    resolution: {integrity: sha512-lY7CDW43ECgW9u1TcT3IoXHflywfVqDYze4waEz812jR/bZ8FHDsl7pFQoSZTz5N+2NqRXs8GBwnAwo3ZNxqhQ==}
+
 snapshots:
 
   '@alloc/quick-lru@5.2.0': {}
@@ -3693,6 +3711,26 @@ snapshots:
       - encoding
       - ws
 
+  '@langchain/openai@0.4.4(@langchain/core@0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1)))(ws@8.18.0)':
+    dependencies:
+      '@langchain/core': 0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1))
+      js-tiktoken: 1.0.17
+      openai: 4.82.0(ws@8.18.0)(zod@3.24.2)
+      zod: 3.24.2
+      zod-to-json-schema: 3.24.1(zod@3.24.2)
+    transitivePeerDependencies:
+      - encoding
+      - ws
+
+  '@langchain/xai@0.0.2(@langchain/core@0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1)))(ws@8.18.0)':
+    dependencies:
+      '@langchain/core': 0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1))
+      '@langchain/openai': 0.4.4(@langchain/core@0.3.37(openai@4.82.0(ws@8.18.0)(zod@3.24.1)))(ws@8.18.0)
+      zod: 3.24.2
+    transitivePeerDependencies:
+      - encoding
+      - ws
+
   '@laynezh/vite-plugin-lib-assets@0.6.1(vite@6.0.5(@types/node@22.7.4)(jiti@1.21.6)(sass@1.79.4)(terser@5.34.1)(tsx@4.19.2)(yaml@2.5.1))':
     dependencies:
       escape-string-regexp: 4.0.0
@@ -5655,6 +5693,21 @@ snapshots:
     transitivePeerDependencies:
       - encoding
 
+  openai@4.82.0(ws@8.18.0)(zod@3.24.2):
+    dependencies:
+      '@types/node': 18.19.74
+      '@types/node-fetch': 2.6.12
+      abort-controller: 3.0.0
+      agentkeepalive: 4.6.0
+      form-data-encoder: 1.7.2
+      formdata-node: 4.4.1
+      node-fetch: 2.7.0
+    optionalDependencies:
+      ws: 8.18.0
+      zod: 3.24.2
+    transitivePeerDependencies:
+      - encoding
+
   optionator@0.9.4:
     dependencies:
       deep-is: 0.1.4
@@ -6617,4 +6670,10 @@ snapshots:
     dependencies:
       zod: 3.24.1
 
+  zod-to-json-schema@3.24.1(zod@3.24.2):
+    dependencies:
+      zod: 3.24.2
+
   zod@3.24.1: {}
+
+  zod@3.24.2: {}