|
@@ -11,6 +11,7 @@ import {
|
|
|
ANTHROPIC_PROVIDER,
|
|
|
GEMINI_PROVIDER,
|
|
|
OLLAMA_PROVIDER,
|
|
|
+ llmProviderParameters,
|
|
|
} from '@extension/storage';
|
|
|
|
|
|
export const ModelSettings = () => {
|
|
@@ -34,6 +35,11 @@ export const ModelSettings = () => {
|
|
|
[AgentNameEnum.Planner]: '',
|
|
|
[AgentNameEnum.Validator]: '',
|
|
|
});
|
|
|
+ const [modelParameters, setModelParameters] = useState<Record<AgentNameEnum, { temperature: number; topP: number }>>({
|
|
|
+ [AgentNameEnum.Navigator]: { temperature: 0, topP: 0 },
|
|
|
+ [AgentNameEnum.Planner]: { temperature: 0, topP: 0 },
|
|
|
+ [AgentNameEnum.Validator]: { temperature: 0, topP: 0 },
|
|
|
+ });
|
|
|
const [newModelInputs, setNewModelInputs] = useState<Record<string, string>>({});
|
|
|
const [isProviderSelectorOpen, setIsProviderSelectorOpen] = useState(false);
|
|
|
const newlyAddedProviderRef = useRef<string | null>(null);
|
|
@@ -63,7 +69,7 @@ export const ModelSettings = () => {
|
|
|
loadProviders();
|
|
|
}, []);
|
|
|
|
|
|
- // Load existing agent models on mount
|
|
|
+ // Load existing agent models and parameters on mount
|
|
|
useEffect(() => {
|
|
|
const loadAgentModels = async () => {
|
|
|
try {
|
|
@@ -77,6 +83,15 @@ export const ModelSettings = () => {
|
|
|
const config = await agentModelStore.getAgentModel(agent);
|
|
|
if (config) {
|
|
|
models[agent] = config.modelName;
|
|
|
+ if (config.parameters?.temperature !== undefined || config.parameters?.topP !== undefined) {
|
|
|
+ setModelParameters(prev => ({
|
|
|
+ ...prev,
|
|
|
+ [agent]: {
|
|
|
+ temperature: config.parameters?.temperature ?? prev[agent].temperature,
|
|
|
+ topP: config.parameters?.topP ?? prev[agent].topP,
|
|
|
+ },
|
|
|
+ }));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
setSelectedModels(models);
|
|
@@ -271,6 +286,7 @@ export const ModelSettings = () => {
|
|
|
name: providers[provider].name,
|
|
|
modelNames: modelNames,
|
|
|
type: providers[provider].type,
|
|
|
+ createdAt: providers[provider].createdAt,
|
|
|
});
|
|
|
|
|
|
// Clear any name errors on successful save
|
|
@@ -351,33 +367,41 @@ export const ModelSettings = () => {
|
|
|
};
|
|
|
|
|
|
const getAvailableModels = () => {
|
|
|
- const models: string[] = [];
|
|
|
+ const models: Array<{ provider: string; providerName: string; model: string }> = [];
|
|
|
|
|
|
- // First add models from configured providers
|
|
|
+ // Only get models from configured providers
|
|
|
for (const [provider, config] of Object.entries(providers)) {
|
|
|
if (config.apiKey) {
|
|
|
const providerModels =
|
|
|
config.modelNames || llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
|
|
|
- models.push(...providerModels);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // If no models are available, return default models for the "Add Provider" buttons
|
|
|
- if (models.length === 0) {
|
|
|
- // Include default models for the default providers
|
|
|
- const defaultProviders = [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GEMINI_PROVIDER, OLLAMA_PROVIDER];
|
|
|
- for (const provider of defaultProviders) {
|
|
|
- if (!providersFromStorage.has(provider) && !modifiedProviders.has(provider)) {
|
|
|
- const defaultModels = llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
|
|
|
- models.push(...defaultModels);
|
|
|
- }
|
|
|
+ models.push(
|
|
|
+ ...providerModels.map(model => ({
|
|
|
+ provider,
|
|
|
+ providerName: config.name || provider,
|
|
|
+ model,
|
|
|
+ })),
|
|
|
+ );
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return models.length ? models : [''];
|
|
|
+ return models;
|
|
|
};
|
|
|
|
|
|
- const handleModelChange = async (agentName: AgentNameEnum, model: string) => {
|
|
|
+ const handleModelChange = async (agentName: AgentNameEnum, modelValue: string) => {
|
|
|
+ // modelValue will be in format "provider>model"
|
|
|
+ const [provider, model] = modelValue.split('>');
|
|
|
+
|
|
|
+ // Set parameters based on provider type
|
|
|
+ const newParameters = llmProviderParameters[provider as keyof typeof llmProviderParameters]?.[agentName] || {
|
|
|
+ temperature: 0.1,
|
|
|
+ topP: 0.1,
|
|
|
+ };
|
|
|
+
|
|
|
+ setModelParameters(prev => ({
|
|
|
+ ...prev,
|
|
|
+ [agentName]: newParameters,
|
|
|
+ }));
|
|
|
+
|
|
|
setSelectedModels(prev => ({
|
|
|
...prev,
|
|
|
[agentName]: model,
|
|
@@ -385,56 +409,166 @@ export const ModelSettings = () => {
|
|
|
|
|
|
try {
|
|
|
if (model) {
|
|
|
- // Determine provider from model name
|
|
|
+ await agentModelStore.setAgentModel(agentName, {
|
|
|
+ provider,
|
|
|
+ modelName: model,
|
|
|
+ parameters: newParameters,
|
|
|
+ });
|
|
|
+ } else {
|
|
|
+ // Reset storage if no model is selected
|
|
|
+ await agentModelStore.resetAgentModel(agentName);
|
|
|
+ }
|
|
|
+ } catch (error) {
|
|
|
+ console.error('Error saving agent model:', error);
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ const handleParameterChange = async (agentName: AgentNameEnum, paramName: 'temperature' | 'topP', value: number) => {
|
|
|
+ const newParameters = {
|
|
|
+ ...modelParameters[agentName],
|
|
|
+ [paramName]: value,
|
|
|
+ };
|
|
|
+
|
|
|
+ setModelParameters(prev => ({
|
|
|
+ ...prev,
|
|
|
+ [agentName]: newParameters,
|
|
|
+ }));
|
|
|
+
|
|
|
+ // Only update if we have a selected model
|
|
|
+ if (selectedModels[agentName]) {
|
|
|
+ try {
|
|
|
+ // Find provider
|
|
|
let provider: string | undefined;
|
|
|
for (const [providerKey, providerConfig] of Object.entries(providers)) {
|
|
|
const modelNames =
|
|
|
providerConfig.modelNames || llmProviderModelNames[providerKey as keyof typeof llmProviderModelNames] || [];
|
|
|
- if (modelNames.includes(model)) {
|
|
|
+ if (modelNames.includes(selectedModels[agentName])) {
|
|
|
provider = providerKey;
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- console.log('handleModelChange', provider, model);
|
|
|
if (provider) {
|
|
|
await agentModelStore.setAgentModel(agentName, {
|
|
|
provider,
|
|
|
- modelName: model,
|
|
|
+ modelName: selectedModels[agentName],
|
|
|
+ parameters: newParameters,
|
|
|
});
|
|
|
}
|
|
|
- } else {
|
|
|
- // Reset storage if no model is selected
|
|
|
- await agentModelStore.resetAgentModel(agentName);
|
|
|
+ } catch (error) {
|
|
|
+ console.error('Error saving agent parameters:', error);
|
|
|
}
|
|
|
- } catch (error) {
|
|
|
- console.error('Error saving agent model:', error);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
const renderModelSelect = (agentName: AgentNameEnum) => (
|
|
|
- <div className="flex items-center justify-between">
|
|
|
- <div>
|
|
|
- <h3 className="text-lg font-medium text-gray-700">{agentName.charAt(0).toUpperCase() + agentName.slice(1)}</h3>
|
|
|
- <p className="text-sm font-normal text-gray-500">{getAgentDescription(agentName)}</p>
|
|
|
- </div>
|
|
|
- <select
|
|
|
- className="w-64 px-3 py-2 border rounded-md"
|
|
|
- disabled={getAvailableModels().length <= 1}
|
|
|
- value={selectedModels[agentName] || ''}
|
|
|
- onChange={e => handleModelChange(agentName, e.target.value)}>
|
|
|
- <option key="default" value="">
|
|
|
- Choose model
|
|
|
- </option>
|
|
|
- {getAvailableModels().map(
|
|
|
- model =>
|
|
|
- model && (
|
|
|
- <option key={model} value={model}>
|
|
|
- {model}
|
|
|
+ <div className="bg-white p-4 rounded-lg border border-gray-200">
|
|
|
+ <h3 className="text-lg font-medium text-gray-700 mb-2">
|
|
|
+ {agentName.charAt(0).toUpperCase() + agentName.slice(1)}
|
|
|
+ </h3>
|
|
|
+ <p className="text-sm font-normal text-gray-500 mb-4">{getAgentDescription(agentName)}</p>
|
|
|
+
|
|
|
+ <div className="space-y-4">
|
|
|
+ {/* Model Selection */}
|
|
|
+ <div className="flex items-center">
|
|
|
+ <label htmlFor={`${agentName}-model`} className="w-24 text-sm font-medium text-gray-700">
|
|
|
+ Model
|
|
|
+ </label>
|
|
|
+ <select
|
|
|
+ id={`${agentName}-model`}
|
|
|
+ className="flex-1 px-3 py-2 border rounded-md"
|
|
|
+ disabled={getAvailableModels().length <= 1}
|
|
|
+ value={
|
|
|
+ selectedModels[agentName]
|
|
|
+ ? `${getProviderForModel(selectedModels[agentName])}>${selectedModels[agentName]}`
|
|
|
+ : ''
|
|
|
+ }
|
|
|
+ onChange={e => handleModelChange(agentName, e.target.value)}>
|
|
|
+ <option key="default" value="">
|
|
|
+ Choose model
|
|
|
+ </option>
|
|
|
+ {getAvailableModels().map(({ provider, providerName, model }) => (
|
|
|
+ <option key={`${provider}>${model}`} value={`${provider}>${model}`}>
|
|
|
+ {`${providerName} > ${model}`}
|
|
|
</option>
|
|
|
- ),
|
|
|
- )}
|
|
|
- </select>
|
|
|
+ ))}
|
|
|
+ </select>
|
|
|
+ </div>
|
|
|
+
|
|
|
+ {/* Temperature Slider */}
|
|
|
+ <div className="flex items-center">
|
|
|
+ <label htmlFor={`${agentName}-temperature`} className="w-24 text-sm font-medium text-gray-700">
|
|
|
+ Temperature
|
|
|
+ </label>
|
|
|
+ <div className="flex-1 flex items-center space-x-2">
|
|
|
+ <input
|
|
|
+ id={`${agentName}-temperature`}
|
|
|
+ type="range"
|
|
|
+ min="0"
|
|
|
+ max="2"
|
|
|
+ step="0.01"
|
|
|
+ value={modelParameters[agentName].temperature}
|
|
|
+ onChange={e => handleParameterChange(agentName, 'temperature', Number.parseFloat(e.target.value))}
|
|
|
+ className="flex-1"
|
|
|
+ />
|
|
|
+ <div className="flex items-center space-x-2">
|
|
|
+ <span className="text-sm text-gray-600 w-12">{modelParameters[agentName].temperature.toFixed(2)}</span>
|
|
|
+ <input
|
|
|
+ type="number"
|
|
|
+ min="0"
|
|
|
+ max="2"
|
|
|
+ step="0.01"
|
|
|
+ value={modelParameters[agentName].temperature}
|
|
|
+ onChange={e => {
|
|
|
+ const value = Number.parseFloat(e.target.value);
|
|
|
+ if (!Number.isNaN(value) && value >= 0 && value <= 2) {
|
|
|
+ handleParameterChange(agentName, 'temperature', value);
|
|
|
+ }
|
|
|
+ }}
|
|
|
+ className="w-20 px-2 py-1 text-sm border rounded-md"
|
|
|
+ aria-label={`${agentName} temperature number input`}
|
|
|
+ />
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
+
|
|
|
+ {/* Top P Slider */}
|
|
|
+ <div className="flex items-center">
|
|
|
+ <label htmlFor={`${agentName}-topP`} className="w-24 text-sm font-medium text-gray-700">
|
|
|
+ Top P
|
|
|
+ </label>
|
|
|
+ <div className="flex-1 flex items-center space-x-2">
|
|
|
+ <input
|
|
|
+ id={`${agentName}-topP`}
|
|
|
+ type="range"
|
|
|
+ min="0"
|
|
|
+ max="1"
|
|
|
+ step="0.001"
|
|
|
+ value={modelParameters[agentName].topP}
|
|
|
+ onChange={e => handleParameterChange(agentName, 'topP', Number.parseFloat(e.target.value))}
|
|
|
+ className="flex-1"
|
|
|
+ />
|
|
|
+ <div className="flex items-center space-x-2">
|
|
|
+ <span className="text-sm text-gray-600 w-12">{modelParameters[agentName].topP.toFixed(3)}</span>
|
|
|
+ <input
|
|
|
+ type="number"
|
|
|
+ min="0"
|
|
|
+ max="1"
|
|
|
+ step="0.001"
|
|
|
+ value={modelParameters[agentName].topP}
|
|
|
+ onChange={e => {
|
|
|
+ const value = Number.parseFloat(e.target.value);
|
|
|
+ if (!Number.isNaN(value) && value >= 0 && value <= 1) {
|
|
|
+ handleParameterChange(agentName, 'topP', value);
|
|
|
+ }
|
|
|
+ }}
|
|
|
+ className="w-20 px-2 py-1 text-sm border rounded-md"
|
|
|
+ aria-label={`${agentName} top P number input`}
|
|
|
+ />
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
</div>
|
|
|
);
|
|
|
|
|
@@ -645,6 +779,17 @@ export const ModelSettings = () => {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+ const getProviderForModel = (modelName: string): string => {
|
|
|
+ for (const [provider, config] of Object.entries(providers)) {
|
|
|
+ const modelNames =
|
|
|
+ config.modelNames || llmProviderModelNames[provider as keyof typeof llmProviderModelNames] || [];
|
|
|
+ if (modelNames.includes(modelName)) {
|
|
|
+ return provider;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return '';
|
|
|
+ };
|
|
|
+
|
|
|
return (
|
|
|
<section className="space-y-6">
|
|
|
{/* LLM Providers Section */}
|
|
@@ -764,22 +909,24 @@ export const ModelSettings = () => {
|
|
|
{/* Base URL input (for custom_openai and ollama) */}
|
|
|
{(providerConfig.type === ProviderTypeEnum.CustomOpenAI ||
|
|
|
providerConfig.type === ProviderTypeEnum.Ollama) && (
|
|
|
- <div className="flex items-center">
|
|
|
- <label htmlFor={`${providerId}-base-url`} className="w-20 text-sm font-medium text-gray-700">
|
|
|
- Base URL{providerConfig.type === ProviderTypeEnum.CustomOpenAI ? '*' : ''}
|
|
|
- </label>
|
|
|
- <input
|
|
|
- id={`${providerId}-base-url`}
|
|
|
- type="text"
|
|
|
- placeholder={
|
|
|
- providerConfig.type === ProviderTypeEnum.CustomOpenAI
|
|
|
- ? 'Required for custom OpenAI providers'
|
|
|
- : 'Ollama base URL'
|
|
|
- }
|
|
|
- value={providerConfig.baseUrl || ''}
|
|
|
- onChange={e => handleApiKeyChange(providerId, providerConfig.apiKey || '', e.target.value)}
|
|
|
- className="flex-1 p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
|
|
|
- />
|
|
|
+ <div className="flex flex-col">
|
|
|
+ <div className="flex items-center">
|
|
|
+ <label htmlFor={`${providerId}-base-url`} className="w-20 text-sm font-medium text-gray-700">
|
|
|
+ Base URL{providerConfig.type === ProviderTypeEnum.CustomOpenAI ? '*' : ''}
|
|
|
+ </label>
|
|
|
+ <input
|
|
|
+ id={`${providerId}-base-url`}
|
|
|
+ type="text"
|
|
|
+ placeholder={
|
|
|
+ providerConfig.type === ProviderTypeEnum.CustomOpenAI
|
|
|
+ ? 'Required for custom OpenAI providers'
|
|
|
+ : 'Ollama base URL'
|
|
|
+ }
|
|
|
+ value={providerConfig.baseUrl || ''}
|
|
|
+ onChange={e => handleApiKeyChange(providerId, providerConfig.apiKey || '', e.target.value)}
|
|
|
+ className="flex-1 p-2 rounded-md bg-gray-50 border border-gray-200 focus:border-blue-400 focus:ring-2 focus:ring-blue-200 outline-none"
|
|
|
+ />
|
|
|
+ </div>
|
|
|
</div>
|
|
|
)}
|
|
|
|
|
@@ -828,6 +975,24 @@ export const ModelSettings = () => {
|
|
|
<p className="text-xs text-gray-500 mt-1">Type and Press Enter or Space to add a model</p>
|
|
|
</div>
|
|
|
</div>
|
|
|
+
|
|
|
+ {/* Ollama reminder at the bottom of the section */}
|
|
|
+ {providerConfig.type === ProviderTypeEnum.Ollama && (
|
|
|
+ <div className="mt-4 p-3 bg-amber-50 border border-amber-200 rounded-md">
|
|
|
+ <p className="text-sm text-amber-700">
|
|
|
+ <strong>Remember:</strong> Add{' '}
|
|
|
+ <code className="bg-amber-100 px-1 py-0.5 rounded">OLLAMA_ORIGINS=chrome-extension://*</code>{' '}
|
|
|
+ environment variable for the Ollama server.
|
|
|
+ <a
|
|
|
+ href="https://github.com/ollama/ollama/issues/6489"
|
|
|
+ target="_blank"
|
|
|
+ rel="noopener noreferrer"
|
|
|
+ className="text-blue-600 hover:text-blue-800 ml-1">
|
|
|
+ Learn more
|
|
|
+ </a>
|
|
|
+ </p>
|
|
|
+ </div>
|
|
|
+ )}
|
|
|
</div>
|
|
|
|
|
|
{/* Add divider except for the last item */}
|