
Merge pull request #49 from hataiit9x/groq_grok_ai

feat: Add Grok and Groq provider support
Ashu, 5 months ago
commit a6036867fe

+ 2 - 0
chrome-extension/package.json

@@ -21,7 +21,9 @@
     "@langchain/anthropic": "^0.3.12",
     "@langchain/anthropic": "^0.3.12",
     "@langchain/core": "^0.3.37",
     "@langchain/core": "^0.3.37",
     "@langchain/google-genai": "0.1.10",
     "@langchain/google-genai": "0.1.10",
+    "@langchain/groq": "^0.1.3",
     "@langchain/openai": "^0.4.2",
     "@langchain/openai": "^0.4.2",
+    "@langchain/xai": "^0.0.2",
     "puppeteer-core": "24.1.1",
     "puppeteer-core": "24.1.1",
     "webextension-polyfill": "^0.12.0",
     "webextension-polyfill": "^0.12.0",
     "zod": "^3.24.1"
     "zod": "^3.24.1"

+ 2 - 0
chrome-extension/src/background/agent/agents/base.ts

@@ -84,6 +84,8 @@ export abstract class BaseAgent<T extends z.ZodType, M = unknown> {
           return null;
         case 'ChatOpenAI':
         case 'AzureChatOpenAI':
+        case 'ChatGroq':
+        case 'ChatXAI':
           return 'function_calling';
         default:
           return null;
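For context, a minimal hypothetical sketch of how a per-model-class mapping like the switch above is typically consumed (the helper and call-site names here are assumptions, not code from this PR): models whose class name maps to 'function_calling' can bind an output schema through native tool calling, while the rest fall back to plain invocation and manual JSON parsing.

```typescript
// Hypothetical sketch; function names and call sites are assumed, not from this PR.
import { z } from 'zod';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';

const ActionSchema = z.object({ action: z.string(), args: z.record(z.string()) });

function structuredOutputMethod(model: BaseChatModel): 'function_calling' | null {
  switch (model.constructor.name) {
    case 'ChatOpenAI':
    case 'AzureChatOpenAI':
    case 'ChatGroq': // added by this PR
    case 'ChatXAI':  // added by this PR
      return 'function_calling';
    default:
      return null;
  }
}

async function plan(model: BaseChatModel, prompt: string) {
  // Models with native tool calling can bind the schema directly; others fall
  // back to plain invocation plus manual JSON parsing of the reply.
  if (structuredOutputMethod(model) === 'function_calling') {
    return model.withStructuredOutput(ActionSchema).invoke(prompt);
  }
  const raw = await model.invoke(prompt);
  return ActionSchema.parse(JSON.parse(raw.content as string));
}
```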

+ 40 - 0
chrome-extension/src/background/agent/helper.ts

@@ -2,6 +2,8 @@ import { type ProviderConfig, LLMProviderEnum, AgentNameEnum } from '@extension/
 import { ChatOpenAI } from '@langchain/openai';
 import { ChatAnthropic } from '@langchain/anthropic';
 import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
+import { ChatGroq } from '@langchain/groq';
+import { ChatXAI } from '@langchain/xai';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';

 // create a chat model based on the agent name, the model name and provider
@@ -72,6 +74,44 @@ export function createChatModel(
       };
       return new ChatGoogleGenerativeAI(args);
     }
+    case LLMProviderEnum.Groq: {
+      temperature = 0.7;
+      const args: any = {
+        model: modelName,
+        apiKey: providerConfig.apiKey,
+        temperature,
+        maxTokens,
+        configuration: {},
+        modelKwargs: {
+          stop: [],
+          stream: false,
+        },
+      };
+      if (providerConfig.baseUrl) {
+        args.configuration = {
+          baseURL: providerConfig.baseUrl,
+        };
+      }
+      return new ChatGroq(args);
+    }
+    case LLMProviderEnum.Grok: {
+      temperature = 0.7;
+      topP = 0.9;
+      const args: any = {
+        model: modelName,
+        apiKey: providerConfig.apiKey,
+        temperature,
+        topP,
+        maxTokens,
+        configuration: {},
+      };
+      if (providerConfig.baseUrl) {
+        args.configuration = {
+          baseURL: providerConfig.baseUrl,
+        };
+      }
+      return new ChatXAI(args);
+    }
     default: {
       throw new Error(`Provider ${providerName} not supported yet`);
     }
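For reference, here is what the two new cases construct, pulled out of the helper as a standalone sketch (the model names and environment-variable handling are illustrative, not from this PR):

```typescript
import { ChatGroq } from '@langchain/groq';
import { ChatXAI } from '@langchain/xai';

// Groq: fixed temperature 0.7; a custom endpoint can be passed when
// providerConfig.baseUrl is set, as in the case above.
const groq = new ChatGroq({
  model: 'llama-3.1-8b-instant', // any entry from the Groq model list in types.ts
  apiKey: process.env.GROQ_API_KEY ?? '',
  temperature: 0.7,
});

// Grok (xAI): same pattern, with topP 0.9 added.
const grok = new ChatXAI({
  model: 'grok-2',
  apiKey: process.env.XAI_API_KEY ?? '',
  temperature: 0.7,
  topP: 0.9,
});
```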

+ 20 - 4
chrome-extension/src/background/agent/messages/service.ts

@@ -337,13 +337,29 @@ export default class MessageManager {
         return message;
       }
       if (message instanceof ToolMessage) {
-        return new HumanMessage({ content: message.content });
+        return new HumanMessage({
+          content: `Tool Response: ${message.content}`,
+        });
       }
       if (message instanceof AIMessage) {
         // if it's an AIMessage with tool_calls, convert it to a normal AIMessage
-        if ('tool_calls' in message) {
-          const toolCalls = JSON.stringify(message.tool_calls);
-          return new AIMessage({ content: toolCalls });
+        if ('tool_calls' in message && message.tool_calls) {
+          const toolCallsStr = message.tool_calls
+            .map(tc => {
+              if (
+                'function' in tc &&
+                typeof tc.function === 'object' &&
+                tc.function &&
+                'name' in tc.function &&
+                'arguments' in tc.function
+              ) {
+                // For Groq, we need to format function calls differently
+                return `Function: ${tc.function.name}\nArguments: ${JSON.stringify(tc.function.arguments)}`;
+              }
+              return `Tool Call: ${JSON.stringify(tc)}`;
+            })
+            .join('\n');
+          return new AIMessage({ content: toolCallsStr });
         }
         return message;
       }
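A quick illustration of what the updated conversion produces (hypothetical messages, not from this PR): tool-calling AIMessages are flattened to text and ToolMessages become prefixed HumanMessages, so providers without native tool-call replay still see the full exchange.

```typescript
import { AIMessage, ToolMessage } from '@langchain/core/messages';

// An assistant turn that issued a tool call (LangChain ToolCall shape: name/args/id).
const assistantTurn = new AIMessage({
  content: '',
  tool_calls: [{ name: 'click_element', args: { index: 3 }, id: 'call_1' }],
});
// LangChain ToolCall objects have no `function` field, so this falls into the
// generic branch above and the converted content reads roughly:
//   Tool Call: {"name":"click_element","args":{"index":3},"id":"call_1"}

// The matching tool result.
const toolResult = new ToolMessage({ content: 'clicked', tool_call_id: 'call_1' });
// Converted to: new HumanMessage({ content: 'Tool Response: clicked' })
```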

+ 10 - 0
packages/storage/lib/settings/types.ts

@@ -9,6 +9,8 @@ export enum LLMProviderEnum {
   OpenAI = 'openai',
   Anthropic = 'anthropic',
   Gemini = 'gemini',
+  Groq = 'groq',
+  Grok = 'grok',
 }

 export const llmProviderModelNames = {
@@ -20,6 +22,14 @@ export const llmProviderModelNames = {
     'gemini-2.0-pro-exp-02-05',
     // 'gemini-2.0-flash-thinking-exp-01-21', // TODO: not support function calling for now
   ],
+  [LLMProviderEnum.Groq]: [
+    'llama-3.1-8b-instant',
+    'mixtral-8x7b-32768',
+    'llama2-70b-4096',
+    'llama-2-70b-4096',
+    'gemma-7b-it',
+  ],
+  [LLMProviderEnum.Grok]: ['grok-2', 'grok-2-vision'],
 };

 /**
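A small usage sketch of the extended enum and model list, e.g. to populate a provider/model picker (the import path is an assumption for illustration; adjust to wherever the storage package is re-exported):

```typescript
// Assumed import path, for illustration only.
import { LLMProviderEnum, llmProviderModelNames } from '@extension/storage';

// List the selectable models for a provider, e.g. to populate a dropdown.
function modelsFor(provider: LLMProviderEnum): string[] {
  return llmProviderModelNames[provider] ?? [];
}

console.log(modelsFor(LLMProviderEnum.Groq)); // ['llama-3.1-8b-instant', 'mixtral-8x7b-32768', ...]
console.log(modelsFor(LLMProviderEnum.Grok)); // ['grok-2', 'grok-2-vision']
```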

+ 25 - 3
pages/options/src/components/ModelSettings.tsx

@@ -180,12 +180,28 @@ export const ModelSettings = ({ isDarkMode = false }: ModelSettingsProps) => {
 
   const renderApiKeyInput = (provider: LLMProviderEnum) => {
     const buttonProps = getButtonProps(provider);
-    const needsBaseUrl = provider === LLMProviderEnum.OpenAI || provider === LLMProviderEnum.Anthropic;
+    const needsBaseUrl =
+      provider === LLMProviderEnum.OpenAI ||
+      provider === LLMProviderEnum.Anthropic ||
+      provider === LLMProviderEnum.Groq ||
+      provider === LLMProviderEnum.Grok;
 
     return (
       <div key={provider} className="mb-6">
         <div className="mb-2 flex items-center justify-between">
-          <h3 className={`text-lg font-medium ${isDarkMode ? 'text-gray-300' : 'text-gray-700'}`}>{provider}</h3>
+          <h3 className={`text-lg font-medium ${isDarkMode ? 'text-gray-300' : 'text-gray-700'}`}>
+            {provider === LLMProviderEnum.OpenAI
+              ? 'OpenAI'
+              : provider === LLMProviderEnum.Anthropic
+                ? 'Anthropic'
+                : provider === LLMProviderEnum.Gemini
+                  ? 'Gemini'
+                  : provider === LLMProviderEnum.Groq
+                    ? 'Groq AI'
+                    : provider === LLMProviderEnum.Grok
+                      ? 'Grok AI'
+                      : provider}
+          </h3>
           <div>
             <Button
               onClick={() => handleSave(provider)}
@@ -232,7 +248,13 @@ export const ModelSettings = ({ isDarkMode = false }: ModelSettingsProps) => {
                 value={apiKeys[provider]?.baseUrl || ''}
                 onChange={e => handleApiKeyChange(provider, apiKeys[provider]?.apiKey || '', e.target.value)}
                 className={`w-full rounded-md border ${isDarkMode ? 'border-slate-600 bg-slate-700 text-gray-200' : 'border-gray-300 bg-white text-gray-700'} px-3 py-2`}
-                placeholder={`Enter custom base URL for ${provider} (optional)`}
+                placeholder={
+                  provider === LLMProviderEnum.Groq
+                    ? 'https://api.groq.com/v1'
+                    : provider === LLMProviderEnum.Grok
+                      ? 'https://api.grok.x.ai/v1'
+                      : `Enter custom base URL for ${provider} (optional)`
+                }
               />
             </div>
           )}
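Side note: the nested ternaries for display names and base-URL placeholders could equally be expressed as lookup maps, which keeps the JSX flat as more providers are added. A sketch using only the strings already present in this diff (import path assumed):

```typescript
// Assumed import path, for illustration only.
import { LLMProviderEnum } from '@extension/storage';

// Per-provider UI strings, pulled from the ternaries above.
const providerDisplayNames: Partial<Record<LLMProviderEnum, string>> = {
  [LLMProviderEnum.OpenAI]: 'OpenAI',
  [LLMProviderEnum.Anthropic]: 'Anthropic',
  [LLMProviderEnum.Gemini]: 'Gemini',
  [LLMProviderEnum.Groq]: 'Groq AI',
  [LLMProviderEnum.Grok]: 'Grok AI',
};

const baseUrlPlaceholders: Partial<Record<LLMProviderEnum, string>> = {
  [LLMProviderEnum.Groq]: 'https://api.groq.com/v1',
  [LLMProviderEnum.Grok]: 'https://api.grok.x.ai/v1',
};

const displayName = (p: LLMProviderEnum) => providerDisplayNames[p] ?? p;
const baseUrlPlaceholder = (p: LLMProviderEnum) =>
  baseUrlPlaceholders[p] ?? `Enter custom base URL for ${p} (optional)`;
```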