tycoding 1 rok pred
rodič
commit
9869e6f0ac

+ 15 - 0
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/component/ProviderRefreshEvent.java

@@ -0,0 +1,15 @@
+package cn.tycoding.langchat.aigc.component;
+
+import org.springframework.context.ApplicationEvent;
+
+/**
+ * @author tycoding
+ * @since 2024/6/16
+ */
+public class ProviderRefreshEvent extends ApplicationEvent {
+    private static final long serialVersionUID = 4109980679877560773L;
+
+    public ProviderRefreshEvent(Object source) {
+        super(source);
+    }
+}

+ 16 - 17
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/controller/AigcConversationController.java

@@ -7,11 +7,17 @@ import cn.tycoding.langchat.common.utils.MybatisUtil;
 import cn.tycoding.langchat.common.utils.QueryPage;
 import cn.tycoding.langchat.common.utils.R;
 import cn.tycoding.langchat.common.utils.ServletUtil;
+import java.util.List;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.web.bind.annotation.*;
-
-import java.util.List;
+import org.springframework.web.bind.annotation.DeleteMapping;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PathVariable;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.PutMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
 
 /**
  * @author tycoding
@@ -26,7 +32,7 @@ public class AigcConversationController {
     private final AigcMessageService aigcMessageService;
 
     /**
-     * 会话列表
+     * conversation list, filter by user
      */
     @GetMapping("/list")
     public R conversations() {
@@ -34,36 +40,26 @@ public class AigcConversationController {
     }
 
     /**
-     * 分页数据
+     * conversation page
      */
     @GetMapping("/page")
     public R list(AigcConversation data, QueryPage queryPage) {
         return R.ok(MybatisUtil.getData(aigcMessageService.conversationPages(data, queryPage)));
     }
 
-    /**
-     * 新增会话
-     */
     @PostMapping
     public R addConversation(@RequestBody AigcConversation conversation) {
         return R.ok(aigcMessageService.addConversation(conversation));
     }
-
-    /**
-     * 修改会话
-     */
     @PutMapping
     public R updateConversation(@RequestBody AigcConversation conversation) {
         if (conversation.getId() == null) {
-            return R.fail("参数错误");
+            return R.fail("conversation id is null");
         }
         aigcMessageService.updateConversation(conversation);
         return R.ok();
     }
 
-    /**
-     * 删除会话
-     */
     @DeleteMapping("/{conversationId}")
     public R delConversation(@PathVariable String conversationId) {
         aigcMessageService.delConversation(conversationId);
@@ -77,7 +73,7 @@ public class AigcConversationController {
     }
 
     /**
-     * 获取指定会话下的聊天记录
+     * get messages with conversationId
      */
     @GetMapping("/messages/{conversationId}")
     public R getMessages(@PathVariable String conversationId) {
@@ -85,6 +81,9 @@ public class AigcConversationController {
         return R.ok(list);
     }
 
+    /**
+     * add message in conversation
+     */
     @PostMapping("/message")
     public R addMessage(@RequestBody AigcMessage message) {
         message.setIp(ServletUtil.getIpAddr());

+ 25 - 8
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/controller/AigcModelController.java

@@ -1,7 +1,10 @@
 package cn.tycoding.langchat.aigc.controller;
 
+import cn.hutool.core.util.StrUtil;
+import cn.tycoding.langchat.aigc.component.ProviderRefreshEvent;
 import cn.tycoding.langchat.aigc.entity.AigcModel;
 import cn.tycoding.langchat.aigc.mapper.AigcModelMapper;
+import cn.tycoding.langchat.common.component.SpringContextHolder;
 import cn.tycoding.langchat.common.utils.MybatisUtil;
 import cn.tycoding.langchat.common.utils.QueryPage;
 import cn.tycoding.langchat.common.utils.R;
@@ -27,40 +30,54 @@ import org.springframework.web.bind.annotation.RestController;
 @RequestMapping("/aigc/model")
 public class AigcModelController {
 
-    private final AigcModelMapper docsMapper;
+    private final AigcModelMapper modelMapper;
+    private final SpringContextHolder contextHolder;
 
     @GetMapping("/list")
     public R<List<AigcModel>> list(AigcModel data) {
-        return R.ok(docsMapper.selectList(Wrappers.<AigcModel>lambdaQuery()));
+        List<AigcModel> list = modelMapper.selectList(Wrappers.<AigcModel>lambdaQuery().eq(AigcModel::getProvider, data.getProvider()));
+        list.forEach(this::hide);
+        return R.ok(list);
+    }
+
+    private void hide(AigcModel model) {
+        String key = StrUtil.hide(model.getApiKey(), 5, model.getApiKey().length());
+        model.setApiKey(key);
     }
 
     @GetMapping("/page")
     public R list(AigcModel data, QueryPage queryPage) {
         Page<AigcModel> page = new Page<>(queryPage.getPage(), queryPage.getLimit());
-        return R.ok(MybatisUtil.getData(docsMapper.selectPage(page, Wrappers.<AigcModel>lambdaQuery()
-        )));
+        Page<AigcModel> iPage = modelMapper.selectPage(page, Wrappers.<AigcModel>lambdaQuery().eq(AigcModel::getProvider, data.getProvider()));
+        iPage.getRecords().forEach(this::hide);
+        return R.ok(MybatisUtil.getData(iPage));
     }
 
     @GetMapping("/{id}")
     public R<AigcModel> findById(@PathVariable String id) {
-        return R.ok(docsMapper.selectById(id));
+        return R.ok(modelMapper.selectById(id));
     }
 
     @PostMapping
     public R add(@RequestBody AigcModel data) {
-        docsMapper.insert(data);
+        modelMapper.insert(data);
+        SpringContextHolder.publishEvent(new ProviderRefreshEvent(data));
         return R.ok();
     }
 
     @PutMapping
     public R update(@RequestBody AigcModel data) {
-        docsMapper.updateById(data);
+        modelMapper.updateById(data);
+        SpringContextHolder.publishEvent(new ProviderRefreshEvent(data));
         return R.ok();
     }
 
     @DeleteMapping("/{id}")
     public R delete(@PathVariable String id) {
-        docsMapper.deleteById(id);
+        modelMapper.deleteById(id);
+
+        // Delete dynamically registered beans, according to ID
+        contextHolder.unregisterBean(id);
         return R.ok();
     }
 }

+ 1 - 0
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/entity/AigcModel.java

@@ -29,5 +29,6 @@ public class AigcModel implements Serializable {
     private Double temperature;
     private Double topP;
     private String apiKey;
+    private String baseUrl;
 }
 

+ 26 - 2
langchat-common/src/main/java/cn/tycoding/langchat/common/component/SpringContextHolder.java

@@ -1,6 +1,9 @@
 package cn.tycoding.langchat.common.component;
 
 import org.springframework.beans.BeansException;
+import org.springframework.beans.factory.config.BeanDefinition;
+import org.springframework.beans.factory.support.BeanDefinitionBuilder;
+import org.springframework.beans.factory.support.BeanDefinitionRegistry;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.ApplicationContextAware;
 import org.springframework.context.ApplicationEvent;
@@ -22,12 +25,33 @@ public class SpringContextHolder implements ApplicationContextAware {
         applicationContext.publishEvent(event);
     }
 
+    public static <T> T getBean(Class<T> requiredType) {
+        return applicationContext.getBean(requiredType);
+    }
+
     @Override
     public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
         SpringContextHolder.applicationContext = applicationContext;
     }
 
-    public static <T> T getBean(Class<T> requiredType) {
-        return applicationContext.getBean(requiredType);
+    public void registerBean(String beanName, Object beanInstance) {
+        BeanDefinitionRegistry beanDefinitionRegistry =
+                (BeanDefinitionRegistry) applicationContext.getAutowireCapableBeanFactory();
+
+        BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
+                .genericBeanDefinition((Class<Object>) beanInstance.getClass(), () -> beanInstance);
+
+        BeanDefinition beanDefinition = beanDefinitionBuilder.getRawBeanDefinition();
+
+        beanDefinitionRegistry.registerBeanDefinition(beanName, beanDefinition);
+    }
+
+    public void unregisterBean(String beanName) {
+        BeanDefinitionRegistry beanDefinitionRegistry =
+                (BeanDefinitionRegistry) applicationContext.getAutowireCapableBeanFactory();
+
+        if (beanDefinitionRegistry.containsBeanDefinition(beanName)) {
+            beanDefinitionRegistry.removeBeanDefinition(beanName);
+        }
     }
 }

+ 30 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/consts/ProviderEnum.java

@@ -0,0 +1,30 @@
+package cn.tycoding.langchat.core.consts;
+
+import lombok.Getter;
+import lombok.Setter;
+
+/**
+ * @author tycoding
+ * @since 2024/6/16
+ */
+public enum ProviderEnum {
+
+    OPENAI("openai"),
+    GOOGLE("google"),
+    OLLAMA("ollama"),
+    BAIDU("baidu"),
+    ALIBABA("alibaba"),
+    ;
+
+    @Setter
+    @Getter
+    private String model;
+
+    @Setter
+    @Getter
+    private String streamClass;
+
+    ProviderEnum(String model) {
+        this.model = model;
+    }
+}

+ 4 - 3
langchat-core/src/main/java/cn/tycoding/langchat/core/properties/chat/OpenaiProps.java

@@ -2,19 +2,20 @@ package cn.tycoding.langchat.core.properties.chat;
 
 import cn.tycoding.langchat.core.properties.image.OpenaiImageProps;
 import dev.langchain4j.model.Tokenizer;
-import lombok.Data;
-import org.springframework.boot.context.properties.ConfigurationProperties;
-
 import java.net.Proxy;
 import java.time.Duration;
 import java.util.List;
 import java.util.Map;
+import lombok.Data;
+import lombok.experimental.Accessors;
+import org.springframework.boot.context.properties.ConfigurationProperties;
 
 /**
  * @author tycoding
  * @since 2024/4/15
  */
 @Data
+@Accessors(chain = true)
 @ConfigurationProperties(prefix = "langchat.openai")
 public class OpenaiProps {
 

+ 74 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderInitialize.java

@@ -0,0 +1,74 @@
+package cn.tycoding.langchat.core.provider;
+
+import cn.hutool.core.bean.BeanUtil;
+import cn.hutool.core.util.StrUtil;
+import cn.tycoding.langchat.aigc.entity.AigcModel;
+import cn.tycoding.langchat.aigc.service.AigcModelService;
+import cn.tycoding.langchat.common.component.SpringContextHolder;
+import cn.tycoding.langchat.core.consts.ProviderEnum;
+import cn.tycoding.langchat.core.properties.chat.OpenaiProps;
+import com.alibaba.fastjson2.JSON;
+import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
+import java.util.List;
+import lombok.AllArgsConstructor;
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
+import org.springframework.context.annotation.Configuration;
+
+/**
+ * @author tycoding
+ * @since 2024/6/16
+ */
+@Configuration
+@AllArgsConstructor
+public class ProviderInitialize implements ApplicationContextAware {
+
+    private final AigcModelService aigcModelService;
+    private final SpringContextHolder contextHolder;
+
+    @Override
+    public void setApplicationContext(ApplicationContext context) throws BeansException {
+        init();
+    }
+
+    public void init() {
+        List<AigcModel> list = aigcModelService.list();
+        list.forEach(model -> {
+            String provider = model.getProvider();
+            if (ProviderEnum.OPENAI.getModel().equals(provider)) {
+                if (StrUtil.isBlank(model.getApiKey())) {
+                    return;
+                }
+                OpenaiProps props = new OpenaiProps()
+                        .setApiKey(model.getApiKey())
+                        .setBaseUrl(model.getBaseUrl())
+                        .setModelName(model.getModel())
+                        .setMaxTokens(model.getResponseLimit())
+                        .setTemperature(model.getTemperature())
+                        .setTopP(model.getTopP());
+                OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder builder =
+                        JSON.parseObject(JSON.toJSONString(props), OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder.class);
+                BeanUtil.copyProperties(props, builder);
+                OpenAiStreamingChatModel build = builder.build();
+                contextHolder.registerBean(model.getId(), build);
+            }
+
+            if (ProviderEnum.GOOGLE.getModel().equals(provider)) {
+
+            }
+
+            if (ProviderEnum.OLLAMA.getModel().equals(provider)) {
+
+            }
+
+            if (ProviderEnum.BAIDU.getModel().equals(provider)) {
+
+            }
+
+            if (ProviderEnum.ALIBABA.getModel().equals(provider)) {
+
+            }
+        });
+    }
+}

+ 25 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderListener.java

@@ -0,0 +1,25 @@
+package cn.tycoding.langchat.core.provider;
+
+import cn.tycoding.langchat.aigc.component.ProviderRefreshEvent;
+import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.context.event.EventListener;
+import org.springframework.stereotype.Component;
+
+/**
+ * @author tycoding
+ * @since 2024/6/16
+ */
+@Slf4j
+@Component
+@AllArgsConstructor
+public class ProviderListener {
+
+    private final ProviderInitialize providerInitialize;
+
+    @EventListener
+    public void providerEvent(ProviderRefreshEvent event) {
+        log.info("refresh provider beans");
+        providerInitialize.init();
+    }
+}

+ 6 - 0
langchat-ui/src/styles/index.less

@@ -112,3 +112,9 @@ body {
 .n-slider-handle-indicator {
   padding: 0 5px !important;
 }
+
+.model-menu {
+  .n-menu-item-content-header {
+    font-weight: 600 !important;
+  }
+}

+ 7 - 13
langchat-ui/src/views/aigc/chat/components/Header.vue

@@ -1,9 +1,9 @@
-<script setup lang="ts">
-  import { modelList } from '@/api/models';
+<script lang="ts" setup>
   import SvgIcon from '@/components/SvgIcon/index.vue';
   import { useChatStore } from '@/views/aigc/chat/components/store/useChatStore';
   import { useDialog, useMessage } from 'naive-ui';
   import { clean } from '@/api/aigc/chat';
+  import ModelProvider from '@/views/aigc/common/ModelProvider.vue';
 
   defineProps<{
     title: string;
@@ -37,26 +37,20 @@
     </div>
     <n-space align="center">
       <n-tag
-        checkable
         v-model:checked="chatStore.isGoogleSearch"
         :bordered="false"
-        type="primary"
+        checkable
         class="border"
+        type="primary"
       >
         <div class="text-sm flex items-center gap-1">
           <SvgIcon icon="devicon:google" />
           <div>Google Search</div>
         </div>
       </n-tag>
-      <n-select
-        size="small"
-        v-model:value="chatStore.model"
-        :options="modelList"
-        :consistent-menu-width="false"
-        class="!w-32"
-      />
+      <ModelProvider />
 
-      <n-button @click="handleClear" size="small" type="warning" secondary>
+      <n-button secondary size="small" type="warning" @click="handleClear">
         <template #icon>
           <SvgIcon class="text-[14px]" icon="fluent:delete-12-regular" />
         </template>
@@ -66,4 +60,4 @@
   </div>
 </template>
 
-<style scoped lang="less"></style>
+<style lang="less" scoped></style>

+ 43 - 0
langchat-ui/src/views/aigc/common/ModelProvider.vue

@@ -0,0 +1,43 @@
+<script lang="ts" setup>
+  import { useChatStore } from '@/views/aigc/chat/components/store/useChatStore';
+  import { onMounted } from 'vue';
+  import { list } from '@/api/aigc/model';
+  import { LLMProviders } from '@/views/aigc/model/data';
+  import { ref } from 'vue-demi';
+
+  const chatStore = useChatStore();
+  const modelList = ref([]);
+
+  onMounted(async () => {
+    const providers = await list({});
+    const data: any = [];
+    LLMProviders.forEach((i) => {
+      const children = providers.filter((m) => m.provider == i.model);
+      if (children.length === 0) {
+        return;
+      }
+      console.log(children);
+      data.push({
+        type: 'group',
+        name: i.name,
+        id: i.id,
+        children: children,
+      });
+    });
+    modelList.value = data;
+  });
+</script>
+
+<template>
+  <n-select
+    v-model:value="chatStore.model"
+    :consistent-menu-width="false"
+    :label-field="'name'"
+    :options="modelList"
+    :value-field="'id'"
+    class="!w-32"
+    size="small"
+  />
+</template>
+
+<style lang="less" scoped></style>

+ 1 - 1
langchat-ui/src/views/aigc/model/data.ts

@@ -25,7 +25,7 @@ export const LLMProviders: any[] = [
 
 export const columns = [
   {
-    title: '供应商',
+    title: '模型别名',
     key: 'name',
   },
   {

+ 64 - 54
langchat-ui/src/views/aigc/model/index.vue

@@ -1,14 +1,17 @@
 <script lang="ts" setup>
-  import { h, reactive, ref, toRaw } from 'vue';
+  import { h, reactive, ref } from 'vue';
   import { BasicTable, TableAction } from '@/components/Table';
   import { DeleteOutlined, EditOutlined, PlusOutlined } from '@vicons/antd';
   import { columns, LLMProviders } from './data';
   import Edit from './edit.vue';
-  import { list } from '@/api/aigc/model';
+  import { del, list } from '@/api/aigc/model';
+  import { useDialog, useMessage } from 'naive-ui';
 
+  const message = useMessage();
+  const dialog = useDialog();
   const actionRef = ref();
   const editRef = ref();
-  const expands = ref([]);
+  const provider = ref('');
   const actionColumn = reactive({
     width: 100,
     title: '操作',
@@ -16,19 +19,6 @@
     fixed: 'right',
     align: 'center',
     render(record: any) {
-      const providers = LLMProviders.map((i) => i.model);
-      if (providers.includes(toRaw(record).model)) {
-        return h(TableAction as any, {
-          style: 'text',
-          actions: [
-            {
-              type: 'success',
-              icon: PlusOutlined,
-              onClick: handleAdd.bind(null, { provider: record.model }),
-            },
-          ],
-        });
-      }
       return h(TableAction as any, {
         style: 'text',
         actions: [
@@ -48,23 +38,14 @@
   });
 
   const loadDataTable = async (params: any) => {
-    const models = await list({ ...params });
-    const data: any[] = [];
-    LLMProviders.forEach((i) => {
-      const children = models.filter((m) => m.provider == i.model);
-      console.log(children);
-      data.push({
-        model: i.model,
-        name: i.name,
-        type: 'expand',
-        expandable: true,
-        children: children,
-      });
-    });
-    return data;
+    if (provider.value === '') {
+      provider.value = LLMProviders[0].model;
+    }
+
+    return await list({ ...params, provider: provider.value });
   };
-  function handleAdd(record: any) {
-    editRef.value.show(record);
+  function handleAdd() {
+    editRef.value.show({ provider: provider.value });
   }
 
   function handleEdit(record: any) {
@@ -75,37 +56,66 @@
     actionRef.value.reload();
   }
 
-  function handleDel(record: any) {}
+  function handleDel(record: any) {
+    dialog.warning({
+      title: '警告',
+      content: `你确定删除 [${record.name}] 模型吗?删除之后不可再用该模型对话`,
+      positiveText: '确定',
+      negativeText: '不确定',
+      onPositiveClick: async () => {
+        await del(record.id);
+        message.success('模型删除成功');
+      },
+    });
+  }
 </script>
 
 <template>
   <div>
     <div class="n-layout-page-header">
-      <n-card :bordered="false" title="模型配置"> 支持对常见的模型配置。 </n-card>
+      <n-card :bordered="false" title="模型配置">
+        支持动态配置LLM大模型参数,支持不同的模型使用不同的ApiKey。
+      </n-card>
     </div>
 
     <n-card :bordered="false" class="mt-4">
-      <BasicTable
-        ref="actionRef"
-        :actionColumn="actionColumn"
-        :columns="columns"
-        :pagination="false"
-        :request="loadDataTable"
-        :row-key="(row:any) => row.model"
-        :single-line="false"
-        default-expand-all
-      >
-        <template #tableTitle>
-          <n-button type="primary" @click="handleAdd">
-            <template #icon>
-              <n-icon>
-                <PlusOutlined />
-              </n-icon>
+      <div class="flex gap-5">
+        <div class="w-52 flex flex-col gap-4 py-1">
+          <div class="font-bold text-base">LLM Provider</div>
+          <n-menu
+            v-model:value="provider"
+            :key-field="'model'"
+            :label-field="'name'"
+            :options="LLMProviders"
+            class="model-menu"
+            @update:value="reloadTable"
+          />
+        </div>
+
+        <div class="w-full">
+          <BasicTable
+            ref="actionRef"
+            :actionColumn="actionColumn"
+            :columns="columns"
+            :pagination="false"
+            :request="loadDataTable"
+            :row-key="(row:any) => row.model"
+            :single-line="false"
+            default-expand-all
+          >
+            <template #tableTitle>
+              <n-button type="primary" @click="handleAdd">
+                <template #icon>
+                  <n-icon>
+                    <PlusOutlined />
+                  </n-icon>
+                </template>
+                新增模型
+              </n-button>
             </template>
-            新增模型
-          </n-button>
-        </template>
-      </BasicTable>
+          </BasicTable>
+        </div>
+      </div>
     </n-card>
 
     <Edit ref="editRef" @reload="reloadTable" />