llmProviders.ts 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import { StorageEnum } from '../base/enums';
  2. import { createStorage } from '../base/base';
  3. import type { BaseStorage } from '../base/types';
  4. import { type AgentNameEnum, llmProviderModelNames, llmProviderParameters, ProviderTypeEnum } from './types';
  /**
   * Configuration of a single LLM provider entry shown in the options page.
   * Only `apiKey` is mandatory; the remaining fields are optional.
   */
  export interface ProviderConfig {
    /** Always present, but may be an empty string for local models. */
    apiKey: string;
    /** Display name used in the options UI. */
    name?: string;
    /** Provider type; decides which LangChain ChatModel package is used. */
    type?: ProviderTypeEnum;
    /** Optional base URL for the provider's API. */
    baseUrl?: string;
    /** Chosen model names; when absent, the hard-coded names from `llmProviderModelNames` are used. */
    modelNames?: string[];
    /** Creation timestamp in milliseconds. */
    createdAt?: number;
  }
  /**
   * Persisted record holding every LLM provider configuration.
   * Keys are provider ids: a built-in provider's id equals its provider type,
   * while a custom provider uses its own custom id.
   */
  export interface LLMKeyRecord {
    providers: { [providerId: string]: ProviderConfig };
  }
  /**
   * Base storage API for `LLMKeyRecord`, extended with helpers that operate on one provider at a time.
   */
  export type LLMProviderStorage = BaseStorage<LLMKeyRecord> & {
    /** Resolves to every stored provider, keyed by provider id. */
    getAllProviders: () => Promise<Record<string, ProviderConfig>>;
    /** Resolves to the provider stored under `providerId`, or `undefined` if there is none. */
    getProvider: (providerId: string) => Promise<ProviderConfig | undefined>;
    /** Resolves to whether a provider is stored under `providerId`. */
    hasProvider: (providerId: string) => Promise<boolean>;
    /** Stores `config` under `providerId`, replacing any existing entry. */
    setProvider: (providerId: string, config: ProviderConfig) => Promise<void>;
    /** Removes the provider stored under `providerId`. */
    removeProvider: (providerId: string) => Promise<void>;
  };
  // Backing store for all LLM provider configurations, initialised with an empty provider map.
  // The key is still "llm-api-keys" (its original name) so data saved by earlier versions keeps loading.
  // Uses the Local storage area with liveUpdate enabled (see createStorage for what that implies).
  const storage = createStorage<LLMKeyRecord>('llm-api-keys', { providers: {} }, {
    storageEnum: StorageEnum.Local,
    liveUpdate: true,
  });
  /**
   * Maps a provider id to the provider type that backs it.
   *
   * Built-in providers use their id as their type; any other id is treated as a
   * custom OpenAI-compatible provider.
   * Keep the list below in sync when a new built-in provider type is added.
   *
   * @param providerId - Key under which the provider is stored.
   * @returns The matching built-in type, or `ProviderTypeEnum.CustomOpenAI`.
   */
  export function getProviderTypeByProviderId(providerId: string): ProviderTypeEnum {
    if (
      providerId === ProviderTypeEnum.OpenAI ||
      providerId === ProviderTypeEnum.Anthropic ||
      providerId === ProviderTypeEnum.DeepSeek ||
      providerId === ProviderTypeEnum.Gemini ||
      providerId === ProviderTypeEnum.Grok ||
      providerId === ProviderTypeEnum.Ollama
    ) {
      // The equality checks narrow providerId to a built-in enum member.
      return providerId;
    }
    return ProviderTypeEnum.CustomOpenAI;
  }
  /**
   * Returns the default display name for a provider id.
   *
   * Built-in providers get a fixed label; for any other (custom) id the id itself is used.
   * Keep the table below in sync when a new built-in provider type is added.
   *
   * @param providerId - Key under which the provider is stored.
   * @returns Human-readable provider name.
   */
  export function getDefaultDisplayNameFromProviderId(providerId: string): string {
    // A Map (rather than a plain object) avoids matching inherited keys such as "toString".
    const builtInLabels = new Map<string, string>([
      [ProviderTypeEnum.OpenAI, 'OpenAI'],
      [ProviderTypeEnum.Anthropic, 'Anthropic'],
      [ProviderTypeEnum.DeepSeek, 'DeepSeek'],
      [ProviderTypeEnum.Gemini, 'Gemini'],
      [ProviderTypeEnum.Grok, 'Grok'],
      [ProviderTypeEnum.Ollama, 'Ollama'],
    ]);
    return builtInLabels.get(providerId) ?? providerId;
  }
  /**
   * Builds a new default configuration for the given provider id.
   *
   * - Ollama: API key "ollama", no models, base URL http://localhost:11434.
   * - OpenAI / Anthropic / DeepSeek / Gemini / Grok: empty API key and a copy of the
   *   hard-coded model list from `llmProviderModelNames`.
   * - Any other id: a custom OpenAI-compatible provider with empty API key, base URL and model list.
   *
   * Keep these branches in sync when a new provider type is added.
   *
   * @param providerId - Key under which the provider will be stored.
   * @returns A fresh config object stamped with the current time.
   */
  export function getDefaultProviderConfig(providerId: string): ProviderConfig {
    if (providerId === ProviderTypeEnum.Ollama) {
      return {
        apiKey: 'ollama', // default API key for Ollama
        name: getDefaultDisplayNameFromProviderId(ProviderTypeEnum.Ollama),
        type: ProviderTypeEnum.Ollama,
        modelNames: [],
        baseUrl: 'http://localhost:11434',
        createdAt: Date.now(),
      };
    }

    const isHostedBuiltIn =
      providerId === ProviderTypeEnum.OpenAI ||
      providerId === ProviderTypeEnum.Anthropic ||
      providerId === ProviderTypeEnum.DeepSeek ||
      providerId === ProviderTypeEnum.Gemini ||
      providerId === ProviderTypeEnum.Grok;
    if (
      isHostedBuiltIn &&
      (providerId === ProviderTypeEnum.OpenAI ||
        providerId === ProviderTypeEnum.Anthropic ||
        providerId === ProviderTypeEnum.DeepSeek ||
        providerId === ProviderTypeEnum.Gemini ||
        providerId === ProviderTypeEnum.Grok)
    ) {
      return {
        apiKey: '',
        name: getDefaultDisplayNameFromProviderId(providerId),
        type: providerId,
        // Spread into a new array so this config never shares the hard-coded list.
        modelNames: [...(llmProviderModelNames[providerId] || [])],
        createdAt: Date.now(),
      };
    }

    // Anything else is a custom OpenAI-compatible provider; base URL and models start empty.
    return {
      apiKey: '',
      name: getDefaultDisplayNameFromProviderId(providerId),
      type: ProviderTypeEnum.CustomOpenAI,
      baseUrl: '',
      modelNames: [],
      createdAt: Date.now(),
    };
  }
  /**
   * Returns the default model parameters for one agent on one provider.
   *
   * Looks up `llmProviderParameters[providerId][agentName]`. When there is no entry for that
   * pair, it falls back to `{ temperature: 0.1, topP: 0.1 }`.
   *
   * @param providerId - Provider id used as the lookup key into `llmProviderParameters`.
   * @param agentName - Agent whose parameters are requested.
   * @returns A new object on every call, so callers may modify it freely.
   */
  export function getDefaultAgentModelParams(providerId: string, agentName: AgentNameEnum): Record<string, number> {
    const defaults = llmProviderParameters[providerId as keyof typeof llmProviderParameters]?.[agentName] ?? {
      temperature: 0.1,
      topP: 0.1,
    };
    // Shallow copy: returning the table entry itself would let one caller's edit change the
    // shared defaults seen by every later caller.
    return { ...defaults };
  }
  /**
   * Fills in fields that provider configs saved by older versions may lack.
   *
   * - `name`: the default display name for the id.
   * - `type`: the provider type derived from the id.
   * - `modelNames`: a copy of the hard-coded list for the id, or `[]` when there is none.
   * - `createdAt`: a fixed legacy timestamp, 4 March 2025 at local midnight.
   *
   * The input object is not mutated.
   *
   * @param providerId - Key under which the config is stored.
   * @param config - Stored configuration, possibly missing fields.
   * @returns A shallow copy of `config` with missing fields populated.
   */
  function ensureBackwardCompatibility(providerId: string, config: ProviderConfig): ProviderConfig {
    const updatedConfig = { ...config };
    if (!updatedConfig.name) {
      updatedConfig.name = getDefaultDisplayNameFromProviderId(providerId);
    }
    if (!updatedConfig.type) {
      updatedConfig.type = getProviderTypeByProviderId(providerId);
    }
    if (!updatedConfig.modelNames) {
      const defaultModelNames = llmProviderModelNames[providerId as keyof typeof llmProviderModelNames];
      // Copy the hard-coded list (as getDefaultProviderConfig does) so later edits to this config
      // cannot mutate the shared default.
      // Array.isArray also rejects inherited Object.prototype members (e.g. id "toString").
      updatedConfig.modelNames = Array.isArray(defaultModelNames) ? [...defaultModelNames] : [];
    }
    if (!updatedConfig.createdAt) {
      // Configs without createdAt get a fixed date, 03/04/2025 (4 March 2025, local midnight; month is 0-based).
      // The numeric Date constructor is fully specified, whereas parsing a non-ISO string such as
      // '03/04/2025' is implementation-defined.
      updatedConfig.createdAt = new Date(2025, 2, 4).getTime();
    }
    return updatedConfig;
  }
  /** Reads the persisted provider map, treating a missing record or a missing `providers` field as empty. */
  async function readStoredProviders(): Promise<Record<string, ProviderConfig>> {
    const data = await storage.get();
    return data?.providers ?? {};
  }

  /**
   * Own-property membership test for the provider map.
   * A plain `in` check or bare indexing also matches inherited Object.prototype members,
   * which would make ids such as "toString" or "constructor" look like stored providers.
   */
  function isOwnProviderId(providers: Record<string, ProviderConfig>, providerId: string): boolean {
    return Object.prototype.hasOwnProperty.call(providers, providerId);
  }

  /**
   * Provider-level API over the "llm-api-keys" storage.
   * It exposes the base storage methods plus helpers that read and write individual provider entries.
   * Read helpers pass each entry through ensureBackwardCompatibility.
   */
  export const llmProviderStore: LLMProviderStorage = {
    ...storage,

    /**
     * Stores `config` under `providerId`, replacing any existing entry.
     * Missing `name`, `type` and `createdAt` are filled with defaults derived from the id and the current time.
     * @throws Error if `providerId` is empty, `config.apiKey` is undefined, or `config.modelNames` is missing.
     */
    async setProvider(providerId: string, config: ProviderConfig) {
      if (!providerId) {
        throw new Error('Provider id cannot be empty');
      }
      if (config.apiKey === undefined) {
        throw new Error('API key must be provided (can be empty for local models)');
      }
      if (!config.modelNames) {
        throw new Error('Model names must be provided');
      }
      // Fill in any missing fields so every stored entry is complete.
      const completeConfig: ProviderConfig = {
        ...config,
        name: config.name || getDefaultDisplayNameFromProviderId(providerId),
        type: config.type || getProviderTypeByProviderId(providerId),
        modelNames: config.modelNames,
        createdAt: config.createdAt || Date.now(),
      };
      const providers = await readStoredProviders();
      await storage.set({
        providers: {
          ...providers,
          [providerId]: completeConfig,
        },
      });
    },

    /** Resolves to the provider stored under `providerId` (legacy fields filled in), or `undefined`. */
    async getProvider(providerId: string) {
      const providers = await readStoredProviders();
      const config = isOwnProviderId(providers, providerId) ? providers[providerId] : undefined;
      return config ? ensureBackwardCompatibility(providerId, config) : undefined;
    },

    /** Removes `providerId` and writes the record back (a write happens even if the id was absent). */
    async removeProvider(providerId: string) {
      const remaining = { ...(await readStoredProviders()) };
      delete remaining[providerId];
      await storage.set({ providers: remaining });
    },

    /** Resolves to true only when `providerId` is an own key of the stored provider map. */
    async hasProvider(providerId: string) {
      return isOwnProviderId(await readStoredProviders(), providerId);
    },

    /** Resolves to every stored provider with legacy fields filled in; `{}` when nothing is stored. */
    async getAllProviders() {
      const providers = { ...(await readStoredProviders()) };
      for (const [providerId, config] of Object.entries(providers)) {
        providers[providerId] = ensureBackwardCompatibility(providerId, config);
      }
      return providers;
    },
  };