add model support for claude

tycoding, 1 year ago
parent
commit
354584e299

+ 1 - 0
langchat-biz/src/main/java/cn/tycoding/langchat/biz/component/ProviderEnum.java

@@ -29,6 +29,7 @@ public enum ProviderEnum {
     AZURE_OPENAI,
     GEMINI,
     OLLAMA,
+    CLAUDE,
     Q_FAN,
     Q_WEN,
     ZHIPU,

+ 5 - 0
langchat-core/pom.xml

@@ -58,6 +58,11 @@
             <artifactId>langchain4j-vertex-ai-gemini</artifactId>
             <version>${langchain4j.version}</version>
         </dependency>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-anthropic</artifactId>
+            <version>${langchain4j.version}</version>
+        </dependency>
         <dependency>
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-qianfan</artifactId>

+ 15 - 2
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderInitialize.java

@@ -23,6 +23,7 @@ import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.biz.service.AigcModelService;
 import cn.tycoding.langchat.common.component.SpringContextHolder;
 import cn.tycoding.langchat.core.consts.EmbedConst;
+import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel;
 import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
 import dev.langchain4j.model.azure.AzureOpenAiImageModel;
 import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
@@ -81,13 +82,13 @@ public class ProviderInitialize implements ApplicationContextAware {
             // Uninstall previously registered beans before registering them
             contextHolder.unregisterBean(model.getId());
 
-            llmHandler(model);
+            chatHandler(model);
             embeddingHandler(model);
             imageHandler(model);
         });
     }
 
-    private void llmHandler(AigcModel model) {
+    private void chatHandler(AigcModel model) {
         String type = model.getType();
         String provider = model.getProvider();
 
@@ -154,6 +155,18 @@ public class ProviderInitialize implements ApplicationContextAware {
                         .build();
                 contextHolder.registerBean(model.getId(), build);
             }
+            if (ProviderEnum.CLAUDE.name().equals(provider)) {
+                AnthropicStreamingChatModel build = AnthropicStreamingChatModel
+                        .builder()
+                        .baseUrl(model.getBaseUrl())
+                        .modelName(model.getModel())
+                        .temperature(model.getTemperature())
+                        .topP(model.getTopP())
+                        .logRequests(true)
+                        .logResponses(true)
+                        .build();
+                contextHolder.registerBean(model.getId(), build);
+            }
             if (ProviderEnum.Q_FAN.name().equals(provider)) {
                 QianfanStreamingChatModel build = QianfanStreamingChatModel
                         .builder()

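For context on how a streaming model registered this way is consumed, below is a minimal standalone sketch (not part of this commit) of streaming a completion through langchain4j's AnthropicStreamingChatModel. The environment-variable api key, hard-coded model name, and handler wiring are illustrative assumptions; in the hunk above, ProviderInitialize builds the model from AigcModel configuration instead.

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;

public class ClaudeStreamingSketch {
    public static void main(String[] args) {
        // Build the streaming model directly; the commit above does this with
        // values taken from the AigcModel entity and registers it as a Spring bean.
        StreamingChatLanguageModel model = AnthropicStreamingChatModel.builder()
                .apiKey(System.getenv("ANTHROPIC_API_KEY")) // placeholder for illustration
                .modelName("claude-3-haiku-20240307")
                .temperature(0.7)
                .topP(1.0)
                .logRequests(true)
                .logResponses(true)
                .build();

        // Tokens are pushed to the handler as Claude streams them back.
        model.generate("Say hello in one sentence.", new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                System.out.print(token); // partial token
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\n[done] " + response.finishReason());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}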
+ 1 - 1
langchat-server/src/main/resources/application.yml

@@ -8,7 +8,7 @@ spring:
     name: langchat
   # 默认执行的配置文件
   profiles:
-    active: dev
+    active: local
   main:
     allow-bean-definition-overriding: true
 

+ 11 - 0
langchat-ui/src/views/aigc/model/components/chat/columns.ts

@@ -90,6 +90,14 @@ export const ollamaColumns = [
   },
 ];
 
+export const claudeColumns = [
+  ...baseColumns,
+  {
+    title: 'Api Key',
+    key: 'apiKey',
+  },
+];
+
 export const qfanColumns = [
   ...baseColumns,
   {
@@ -115,6 +123,9 @@ export function getColumns(provider: string) {
     case ProviderEnum.OLLAMA: {
       return ollamaColumns;
     }
+    case ProviderEnum.CLAUDE: {
+      return claudeColumns;
+    }
     case ProviderEnum.Q_FAN: {
       return qfanColumns;
     }

+ 19 - 2
langchat-ui/src/views/aigc/model/components/chat/data.ts

@@ -19,6 +19,7 @@ export enum ProviderEnum {
   AZURE_OPENAI = 'AZURE_OPENAI',
   GEMINI = 'GEMINI',
   OLLAMA = 'OLLAMA',
+  CLAUDE = 'CLAUDE',
   Q_FAN = 'Q_FAN',
   Q_WEN = 'Q_WEN',
   ZHIPU = 'ZHIPU',
@@ -70,11 +71,23 @@ export const LLMProviders: any[] = [
   },
   {
     model: ProviderEnum.GEMINI,
-    name: 'GEMINI',
+    name: 'Gemini',
   },
   {
     model: ProviderEnum.OLLAMA,
-    name: 'OLLAMA',
+    name: 'Ollama',
+  },
+  {
+    model: ProviderEnum.CLAUDE,
+    name: 'Claude',
+    models: [
+      'claude-3-opus-20240229',
+      'claude-3-sonnet-20240229',
+      'claude-3-haiku-20240307',
+      'claude-2.1',
+      'claude-2.0',
+      'claude-instant-1.2',
+    ],
   },
   {
     model: ProviderEnum.Q_FAN,
@@ -120,3 +133,7 @@ export const LLMProviders: any[] = [
     models: ['glm-4', 'glm-3-turbo', 'chatglm_turbo'],
   },
 ];
+
+export function getTitle(provider: string) {
+  return LLMProviders.filter((i) => i.model === provider)[0].name;
+}

+ 3 - 1
langchat-ui/src/views/aigc/model/components/chat/edit.vue

@@ -21,6 +21,7 @@
   import { isNullOrWhitespace } from '@/utils/is';
   import { add, update } from '@/api/aigc/model';
   import { useMessage } from 'naive-ui';
+  import { getTitle } from './data';
 
   const props = defineProps<{
     provider: string;
@@ -32,7 +33,7 @@
   const title = computed(() => {
     return info.value == undefined || info.value.provider == undefined
       ? 'Add Model'
-      : info.value.provider;
+      : getTitle(info.value.provider);
   });
   const form: any = {
     responseLimit: 2000,
@@ -49,6 +50,7 @@
     isShow.value = true;
     await nextTick();
     info.value = record;
+    console.log(record);
     setFieldsValue({ ...form, ...record });
   }
 

+ 33 - 1
langchat-ui/src/views/aigc/model/components/chat/schemas.ts

@@ -206,11 +206,40 @@ export const ollamaSchemas: FormSchema[] = [
     label: 'Base Url',
     labelMessage: '注意对于大多数模型此参数仅代表中转地址,但是对于Ollama这类本地模型则是必填的',
     component: 'NInput',
-    rules: [{ required: true, message: '请输入Base Url', trigger: ['blur'] }],
+    rules: [
+      {
+        required: false,
+        trigger: ['blur'],
+        validator: (_, value: string) => {
+          const urlRegex = /^(https?:\/\/)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(\/.*)?$/;
+          if (isNullOrWhitespace(value) || urlRegex.test(value)) {
+            return true;
+          }
+          return new Error('URL格式错误');
+        },
+      },
+    ],
   },
   ...baseSchemas,
 ];
 
+export const claudeSchemas: FormSchema[] = [
+  ...baseHeadSchemas,
+  {
+    field: 'model',
+    label: '模型',
+    labelMessage: '该LLM供应商对应的模型版本号',
+    component: 'NSelect',
+    rules: [{ required: true, message: '请选择模型', trigger: ['blur'] }],
+    componentProps: {
+      filterable: true,
+      options: getModels(ProviderEnum.CLAUDE),
+    },
+  },
+  ...keySchemas,
+  ...baseSchemas,
+];
+
 export const qfanSchemas: FormSchema[] = [
   ...baseHeadSchemas,
   {
@@ -283,6 +312,9 @@ export function getSchemas(provider: string) {
     case ProviderEnum.OLLAMA: {
       return ollamaSchemas;
     }
+    case ProviderEnum.CLAUDE: {
+      return claudeSchemas;
+    }
     case ProviderEnum.Q_FAN: {
       return qfanSchemas;
     }