Explorar o código

Fix model setting layout issue

alexchenzl hai 5 meses
pai
achega
e5bf77e6f1
Modificáronse 1 ficheiros con 229 adicións e 64 borrados
  1. 229 64
      pages/options/src/components/ModelSettings.tsx

+ 229 - 64
pages/options/src/components/ModelSettings.tsx

@@ -11,6 +11,7 @@ import {
   ANTHROPIC_PROVIDER,
   GEMINI_PROVIDER,
   OLLAMA_PROVIDER,
+  llmProviderParameters,
 } from '@extension/storage';
 
 export const ModelSettings = () => {
@@ -34,6 +35,11 @@ export const ModelSettings = () => {
     [AgentNameEnum.Planner]: '',
     [AgentNameEnum.Validator]: '',
   });
+  const [modelParameters, setModelParameters] = useState<Record<AgentNameEnum, { temperature: number; topP: number }>>({
+    [AgentNameEnum.Navigator]: { temperature: 0, topP: 0 },
+    [AgentNameEnum.Planner]: { temperature: 0, topP: 0 },
+    [AgentNameEnum.Validator]: { temperature: 0, topP: 0 },
+  });
   const [newModelInputs, setNewModelInputs] = useState<Record<string, string>>({});
   const [isProviderSelectorOpen, setIsProviderSelectorOpen] = useState(false);
   const newlyAddedProviderRef = useRef<string | null>(null);
@@ -63,7 +69,7 @@ export const ModelSettings = () => {
     loadProviders();
   }, []);
 
-  // Load existing agent models on mount
+  // Load existing agent models and parameters on mount
   useEffect(() => {
     const loadAgentModels = async () => {
       try {
@@ -77,6 +83,15 @@ export const ModelSettings = () => {
           const config = await agentModelStore.getAgentModel(agent);
           if (config) {
             models[agent] = config.modelName;
+            if (config.parameters?.temperature !== undefined || config.parameters?.topP !== undefined) {
+              setModelParameters(prev => ({
+                ...prev,
+                [agent]: {
+                  temperature: config.parameters?.temperature ?? prev[agent].temperature,
+                  topP: config.parameters?.topP ?? prev[agent].topP,
+                },
+              }));
+            }
           }
         }
         setSelectedModels(models);
@@ -271,6 +286,7 @@ export const ModelSettings = () => {
         name: providers[provider].name,
         modelNames: modelNames,
         type: providers[provider].type,
+        createdAt: providers[provider].createdAt,
       });
 
       // Clear any name errors on successful save
@@ -351,33 +367,41 @@ export const ModelSettings = () => {
   };
 
   const getAvailableModels = () => {
-    const models: string[] = [];
+    const models: Array<{ provider: string; providerName: string; model: string }> = [];
 
-    // First add models from configured providers
+    // Only get models from configured providers
     for (const [provider, config] of Object.entries(providers)) {
       if (config.apiKey) {
         const providerModels =
           config.modelNames || llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
-        models.push(...providerModels);
-      }
-    }
-
-    // If no models are available, return default models for the "Add Provider" buttons
-    if (models.length === 0) {
-      // Include default models for the default providers
-      const defaultProviders = [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GEMINI_PROVIDER, OLLAMA_PROVIDER];
-      for (const provider of defaultProviders) {
-        if (!providersFromStorage.has(provider) && !modifiedProviders.has(provider)) {
-          const defaultModels = llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
-          models.push(...defaultModels);
-        }
+        models.push(
+          ...providerModels.map(model => ({
+            provider,
+            providerName: config.name || provider,
+            model,
+          })),
+        );
       }
     }
 
-    return models.length ? models : [''];
+    return models;
   };
 
-  const handleModelChange = async (agentName: AgentNameEnum, model: string) => {
+  const handleModelChange = async (agentName: AgentNameEnum, modelValue: string) => {
+    // modelValue will be in format "provider>model"
+    const [provider, model] = modelValue.split('>');
+
+    // Set parameters based on provider type
+    const newParameters = llmProviderParameters[provider as keyof typeof llmProviderParameters]?.[agentName] || {
+      temperature: 0.1,
+      topP: 0.1,
+    };
+
+    setModelParameters(prev => ({
+      ...prev,
+      [agentName]: newParameters,
+    }));
+
     setSelectedModels(prev => ({
       ...prev,
       [agentName]: model,
@@ -385,56 +409,166 @@ export const ModelSettings = () => {
 
     try {
       if (model) {
-        // Determine provider from model name
+        await agentModelStore.setAgentModel(agentName, {
+          provider,
+          modelName: model,
+          parameters: newParameters,
+        });
+      } else {
+        // Reset storage if no model is selected
+        await agentModelStore.resetAgentModel(agentName);
+      }
+    } catch (error) {
+      console.error('Error saving agent model:', error);
+    }
+  };
+
+  const handleParameterChange = async (agentName: AgentNameEnum, paramName: 'temperature' | 'topP', value: number) => {
+    const newParameters = {
+      ...modelParameters[agentName],
+      [paramName]: value,
+    };
+
+    setModelParameters(prev => ({
+      ...prev,
+      [agentName]: newParameters,
+    }));
+
+    // Only update if we have a selected model
+    if (selectedModels[agentName]) {
+      try {
+        // Find provider
         let provider: string | undefined;
         for (const [providerKey, providerConfig] of Object.entries(providers)) {
           const modelNames =
             providerConfig.modelNames || llmProviderModelNames[providerKey as keyof typeof llmProviderModelNames] || [];
-          if (modelNames.includes(model)) {
+          if (modelNames.includes(selectedModels[agentName])) {
             provider = providerKey;
             break;
           }
         }
 
-        console.log('handleModelChange', provider, model);
         if (provider) {
           await agentModelStore.setAgentModel(agentName, {
             provider,
-            modelName: model,
+            modelName: selectedModels[agentName],
+            parameters: newParameters,
           });
         }
-      } else {
-        // Reset storage if no model is selected
-        await agentModelStore.resetAgentModel(agentName);
+      } catch (error) {
+        console.error('Error saving agent parameters:', error);
       }
-    } catch (error) {
-      console.error('Error saving agent model:', error);
     }
   };
 
   const renderModelSelect = (agentName: AgentNameEnum) => (
-    <div className="flex items-center justify-between">
-      <div>
-        <h3 className="text-lg font-medium text-gray-700">{agentName.charAt(0).toUpperCase() + agentName.slice(1)}</h3>
-        <p className="text-sm font-normal text-gray-500">{getAgentDescription(agentName)}</p>
-      </div>
-      <select
-        className="w-64 px-3 py-2 border rounded-md"
-        disabled={getAvailableModels().length <= 1}
-        value={selectedModels[agentName] || ''}
-        onChange={e => handleModelChange(agentName, e.target.value)}>
-        <option key="default" value="">
-          Choose model
-        </option>
-        {getAvailableModels().map(
-          model =>
-            model && (
-              <option key={model} value={model}>
-                {model}
+    <div className="bg-white p-4 rounded-lg border border-gray-200">
+      <h3 className="text-lg font-medium text-gray-700 mb-2">
+        {agentName.charAt(0).toUpperCase() + agentName.slice(1)}
+      </h3>
+      <p className="text-sm font-normal text-gray-500 mb-4">{getAgentDescription(agentName)}</p>
+
+      <div className="space-y-4">
+        {/* Model Selection */}
+        <div className="flex items-center">
+          <label htmlFor={`${agentName}-model`} className="w-24 text-sm font-medium text-gray-700">
+            Model
+          </label>
+          <select
+            id={`${agentName}-model`}
+            className="flex-1 px-3 py-2 border rounded-md"
+            disabled={getAvailableModels().length <= 1}
+            value={
+              selectedModels[agentName]
+                ? `${getProviderForModel(selectedModels[agentName])}>${selectedModels[agentName]}`
+                : ''
+            }
+            onChange={e => handleModelChange(agentName, e.target.value)}>
+            <option key="default" value="">
+              Choose model
+            </option>
+            {getAvailableModels().map(({ provider, providerName, model }) => (
+              <option key={`${provider}>${model}`} value={`${provider}>${model}`}>
+                {`${providerName} > ${model}`}
               </option>
-            ),
-        )}
-      </select>
+            ))}
+          </select>
+        </div>
+
+        {/* Temperature Slider */}
+        <div className="flex items-center">
+          <label htmlFor={`${agentName}-temperature`} className="w-24 text-sm font-medium text-gray-700">
+            Temperature
+          </label>
+          <div className="flex-1 flex items-center space-x-2">
+            <input
+              id={`${agentName}-temperature`}
+              type="range"
+              min="0"
+              max="2"
+              step="0.01"
+              value={modelParameters[agentName].temperature}
+              onChange={e => handleParameterChange(agentName, 'temperature', Number.parseFloat(e.target.value))}
+              className="flex-1"
+            />
+            <div className="flex items-center space-x-2">
+              <span className="text-sm text-gray-600 w-12">{modelParameters[agentName].temperature.toFixed(2)}</span>
+              <input
+                type="number"
+                min="0"
+                max="2"
+                step="0.01"
+                value={modelParameters[agentName].temperature}
+                onChange={e => {
+                  const value = Number.parseFloat(e.target.value);
+                  if (!Number.isNaN(value) && value >= 0 && value <= 2) {
+                    handleParameterChange(agentName, 'temperature', value);
+                  }
+                }}
+                className="w-20 px-2 py-1 text-sm border rounded-md"
+                aria-label={`${agentName} temperature number input`}
+              />
+            </div>
+          </div>
+        </div>
+
+        {/* Top P Slider */}
+        <div className="flex items-center">
+          <label htmlFor={`${agentName}-topP`} className="w-24 text-sm font-medium text-gray-700">
+            Top P
+          </label>
+          <div className="flex-1 flex items-center space-x-2">
+            <input
+              id={`${agentName}-topP`}
+              type="range"
+              min="0"
+              max="1"
+              step="0.001"
+              value={modelParameters[agentName].topP}
+              onChange={e => handleParameterChange(agentName, 'topP', Number.parseFloat(e.target.value))}
+              className="flex-1"
+            />
+            <div className="flex items-center space-x-2">
+              <span className="text-sm text-gray-600 w-12">{modelParameters[agentName].topP.toFixed(3)}</span>
+              <input
+                type="number"
+                min="0"
+                max="1"
+                step="0.001"
+                value={modelParameters[agentName].topP}
+                onChange={e => {
+                  const value = Number.parseFloat(e.target.value);
+                  if (!Number.isNaN(value) && value >= 0 && value <= 1) {
+                    handleParameterChange(agentName, 'topP', value);
+                  }
+                }}
+                className="w-20 px-2 py-1 text-sm border rounded-md"
+                aria-label={`${agentName} top P number input`}
+              />
+            </div>
+          </div>
+        </div>
+      </div>
     </div>
   );
 
@@ -645,6 +779,17 @@ export const ModelSettings = () => {
     }
   };
 
+  const getProviderForModel = (modelName: string): string => {
+    for (const [provider, config] of Object.entries(providers)) {
+      const modelNames =
+        config.modelNames || llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
+      if (modelNames.includes(modelName)) {
+        return provider;
+      }
+    }
+    return '';
+  };
+
   return (
     <section className="space-y-6">
       {/* LLM Providers Section */}
@@ -764,22 +909,24 @@ export const ModelSettings = () => {
                   {/* Base URL input (for custom_openai and ollama) */}
                   {(providerConfig.type === ProviderTypeEnum.CustomOpenAI ||
                     providerConfig.type === ProviderTypeEnum.Ollama) && (
-                    <div className="flex items-center">
-                      <label htmlFor={`${providerId}-base-url`} className="w-20 text-sm font-medium text-gray-700">
-                        Base URL{providerConfig.type === ProviderTypeEnum.CustomOpenAI ? '*' : ''}
-                      </label>
-                      <input
-                        id={`${providerId}-base-url`}
-                        type="text"
-                        placeholder={
-                          providerConfig.type === ProviderTypeEnum.CustomOpenAI
-                            ? 'Required for custom OpenAI providers'
-                            : 'Ollama base URL'
-                        }
-                        value={providerConfig.baseUrl || ''}
-                        onChange={e => handleApiKeyChange(providerId, providerConfig.apiKey || '', e.target.value)}
-                        className="flex-1 p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
-                      />
+                    <div className="flex flex-col">
+                      <div className="flex items-center">
+                        <label htmlFor={`${providerId}-base-url`} className="w-20 text-sm font-medium text-gray-700">
+                          Base URL{providerConfig.type === ProviderTypeEnum.CustomOpenAI ? '*' : ''}
+                        </label>
+                        <input
+                          id={`${providerId}-base-url`}
+                          type="text"
+                          placeholder={
+                            providerConfig.type === ProviderTypeEnum.CustomOpenAI
+                              ? 'Required for custom OpenAI providers'
+                              : 'Ollama base URL'
+                          }
+                          value={providerConfig.baseUrl || ''}
+                          onChange={e => handleApiKeyChange(providerId, providerConfig.apiKey || '', e.target.value)}
+                          className="flex-1 p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
+                        />
+                      </div>
                     </div>
                   )}
 
@@ -828,6 +975,24 @@ export const ModelSettings = () => {
                       <p className="text-xs text-gray-500 mt-1">Type and Press Enter or Space to add a model</p>
                     </div>
                   </div>
+
+                  {/* Ollama reminder at the bottom of the section */}
+                  {providerConfig.type === ProviderTypeEnum.Ollama && (
+                    <div className="mt-4 p-3 bg-amber-50 border border-amber-200 rounded-md">
+                      <p className="text-sm text-amber-700">
+                        <strong>Remember:</strong> Add{' '}
+                        <code className="bg-amber-100 px-1 py-0.5 rounded">OLLAMA_ORIGINS=chrome-extension://*</code>{' '}
+                        environment variable for the Ollama server.
+                        <a
+                          href="https://github.com/ollama/ollama/issues/6489"
+                          target="_blank"
+                          rel="noopener noreferrer"
+                          className="text-blue-600 hover:text-blue-800 ml-1">
+                          Learn more
+                        </a>
+                      </p>
+                    </div>
+                  )}
                 </div>
 
                 {/* Add divider except for the last item */}