llmProviders.ts 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import { StorageEnum } from '../base/enums';
  2. import { createStorage } from '../base/base';
  3. import type { BaseStorage } from '../base/types';
  4. import { llmProviderModelNames, ProviderTypeEnum } from './types';
  /**
   * Configuration of a single LLM provider as persisted in storage.
   *
   * Only `apiKey` is required; configs saved by older versions may lack
   * any of the optional fields.
   */
  export interface ProviderConfig {
    /** API key. Must be present, but may be an empty string (e.g. for local models). */
    apiKey: string;
    /** Display name shown in the options UI. */
    name?: string;
    /** Helps decide which LangChain ChatModel package to use for this provider. */
    type?: ProviderTypeEnum;
    /** Optional base URL override. */
    baseUrl?: string;
    /** Chosen model names; when omitted, the hard-coded names from `llmProviderModelNames` are used. */
    modelNames?: string[];
    /** Timestamp in milliseconds at which the provider was created. */
    createdAt?: number;
  }
  /**
   * Persisted record holding every configured LLM provider.
   *
   * Entries are keyed by provider id. For built-in providers the id is the
   * same as the provider type; custom providers use their own id.
   */
  export interface LLMKeyRecord {
    providers: { [providerId: string]: ProviderConfig };
  }
  /**
   * Public API of the provider store: the generic storage operations on the
   * whole record, plus helpers that work on a single provider id.
   */
  export type LLMProviderStorage = BaseStorage<LLMKeyRecord> & {
    /** Resolves to every stored provider, keyed by provider id. */
    getAllProviders: () => Promise<Record<string, ProviderConfig>>;
    /** Resolves to the provider stored under `providerId`, or `undefined` if there is none. */
    getProvider: (providerId: string) => Promise<ProviderConfig | undefined>;
    /** Resolves to whether a provider is stored under `providerId`. */
    hasProvider: (providerId: string) => Promise<boolean>;
    /** Stores `config` under `providerId`, replacing any existing entry. */
    setProvider: (providerId: string, config: ProviderConfig) => Promise<void>;
    /** Removes the entry stored under `providerId`. */
    removeProvider: (providerId: string) => Promise<void>;
  };
  /**
   * Underlying storage for LLM provider configurations, kept in the
   * `StorageEnum.Local` area with `liveUpdate` enabled.
   */
  const storage = createStorage<LLMKeyRecord>(
    'llm-api-keys', // storage key kept as "llm-api-keys" for backward compatibility
    { providers: {} }, // default value: no providers configured
    { storageEnum: StorageEnum.Local, liveUpdate: true },
  );
  /**
   * Derives the provider type from a provider id.
   *
   * Built-in providers use their type as their id, so an id equal to one of
   * the built-in types maps to that type; every other id falls back to
   * `ProviderTypeEnum.CustomOpenAI`.
   */
  function getProviderTypeByProviderId(providerId: string): ProviderTypeEnum {
    const builtInTypes = [
      ProviderTypeEnum.OpenAI,
      ProviderTypeEnum.Anthropic,
      ProviderTypeEnum.Gemini,
      ProviderTypeEnum.Ollama,
    ] as const;
    const matchedType = builtInTypes.find((candidate) => candidate === providerId);
    return matchedType ?? ProviderTypeEnum.CustomOpenAI;
  }
  /**
   * Returns the default display name for a provider id.
   *
   * Built-in provider ids map to a fixed label; any other id is returned
   * unchanged.
   */
  function getDisplayNameFromProviderId(providerId: string): string {
    const builtInDisplayNames = new Map<string, string>([
      [ProviderTypeEnum.OpenAI, 'OpenAI'],
      [ProviderTypeEnum.Anthropic, 'Anthropic'],
      [ProviderTypeEnum.Gemini, 'Gemini'],
      [ProviderTypeEnum.Ollama, 'Ollama'],
    ]);
    // Custom providers have no built-in label, so their id doubles as the display name by default.
    return builtInDisplayNames.get(providerId) ?? providerId;
  }
  /**
   * Fallback `createdAt` for configs saved before that field existed:
   * local midnight on 4 March 2025.
   *
   * This is the value V8 produced for the previous `new Date('03/04/2025')`.
   * It is built from numeric components because parsing non-ISO date strings
   * is implementation-defined.
   */
  const LEGACY_PROVIDER_CREATED_AT = new Date(2025, 2, 4).getTime(); // month is 0-based: 2 = March

  /**
   * Returns a fresh copy of the hard-coded default model names for a
   * provider id, or an empty array when the id has no defaults.
   *
   * Only own keys of `llmProviderModelNames` are considered. Without that
   * check, ids such as "toString" would pick up inherited `Object.prototype`
   * members.
   */
  function getDefaultModelNames(providerId: string): string[] {
    if (!Object.prototype.hasOwnProperty.call(llmProviderModelNames, providerId)) {
      return [];
    }
    const defaults = llmProviderModelNames[providerId as keyof typeof llmProviderModelNames];
    // Copy so a caller mutating the returned config cannot alter the shared defaults.
    return defaults ? [...defaults] : [];
  }

  /**
   * Fills in fields that configs saved by older versions may lack.
   *
   * Missing (falsy) `name`, `type`, `modelNames` and `createdAt` are derived
   * from `providerId` or set to fixed fallbacks. Fields that are already set
   * are kept as-is.
   *
   * @param providerId - Id the config is stored under; drives the derived defaults.
   * @param config - Stored config. It is not mutated.
   * @returns A new config object with the missing fields filled in.
   */
  function ensureBackwardCompatibility(providerId: string, config: ProviderConfig): ProviderConfig {
    const updatedConfig: ProviderConfig = { ...config };
    if (!updatedConfig.name) {
      updatedConfig.name = getDisplayNameFromProviderId(providerId);
    }
    if (!updatedConfig.type) {
      updatedConfig.type = getProviderTypeByProviderId(providerId);
    }
    if (!updatedConfig.modelNames) {
      updatedConfig.modelNames = getDefaultModelNames(providerId);
    }
    if (!updatedConfig.createdAt) {
      updatedConfig.createdAt = LEGACY_PROVIDER_CREATED_AT;
    }
    return updatedConfig;
  }
  /**
   * Reads the persisted providers map.
   *
   * A missing record, or a record without `providers`, is treated as an
   * empty map. The result is a shallow copy, so callers may modify it
   * before writing it back.
   */
  async function readStoredProviders(): Promise<Record<string, ProviderConfig>> {
    const data = await storage.get();
    return { ...(data?.providers ?? {}) };
  }

  /**
   * Own-property check for a provider id.
   *
   * A plain `in` or bracket lookup would also match keys inherited from
   * `Object.prototype` (e.g. "toString"), reporting providers that were
   * never stored.
   */
  function hasOwnProvider(providers: Record<string, ProviderConfig>, providerId: string): boolean {
    return Object.prototype.hasOwnProperty.call(providers, providerId);
  }

  /**
   * Store for LLM provider configurations.
   *
   * Exposes the underlying storage operations plus per-provider helpers.
   * Reads pass each config through `ensureBackwardCompatibility`.
   */
  export const llmProviderStore: LLMProviderStorage = {
    ...storage,

    /**
     * Validates `config` and saves it under `providerId`, replacing any existing entry.
     * A missing `name` or `type` is derived from the id, and `createdAt` defaults to now.
     * @throws Error if `providerId` is empty, `config.apiKey` is undefined, or `config.modelNames` is missing.
     */
    async setProvider(providerId: string, config: ProviderConfig) {
      if (!providerId) {
        throw new Error('Provider id cannot be empty');
      }
      if (config.apiKey === undefined) {
        throw new Error('API key must be provided (can be empty for local models)');
      }
      if (!config.modelNames) {
        throw new Error('Model names must be provided');
      }
      // Fill in missing fields so newly saved configs are complete.
      const completeConfig: ProviderConfig = {
        ...config,
        name: config.name || getDisplayNameFromProviderId(providerId),
        type: config.type || getProviderTypeByProviderId(providerId),
        modelNames: config.modelNames,
        createdAt: config.createdAt || Date.now(),
      };
      const providers = await readStoredProviders();
      await storage.set({
        providers: {
          ...providers,
          // A computed key in a literal always defines an own property (even for "__proto__").
          [providerId]: completeConfig,
        },
      });
    },

    /** Returns the stored config for `providerId` with missing fields filled in, or `undefined`. */
    async getProvider(providerId: string) {
      const providers = await readStoredProviders();
      const config = hasOwnProvider(providers, providerId) ? providers[providerId] : undefined;
      return config ? ensureBackwardCompatibility(providerId, config) : undefined;
    },

    /** Removes the entry for `providerId`; the remaining providers are written back. */
    async removeProvider(providerId: string) {
      const providers = await readStoredProviders();
      delete providers[providerId];
      await storage.set({ providers });
    },

    /** Returns whether an entry is stored under `providerId` (own keys only). */
    async hasProvider(providerId: string) {
      const providers = await readStoredProviders();
      return hasOwnProvider(providers, providerId);
    },

    /** Returns every stored provider, each with missing fields filled in. */
    async getAllProviders() {
      const providers = await readStoredProviders();
      for (const [providerId, config] of Object.entries(providers)) {
        providers[providerId] = ensureBackwardCompatibility(providerId, config);
      }
      return providers;
    },
  };