瀏覽代碼

update: 实现对embedding model & store 动态配置

tycoding 9 月之前
父節點
當前提交
2d78145694
共有 31 個文件被更改,包括 471 次插入153 次删除
  1. 1 1
      docs/langchat.sql
  2. 8 0
      langchat-biz/src/main/java/cn/tycoding/langchat/biz/entity/AigcKnowledge.java
  3. 1 1
      langchat-biz/src/main/java/cn/tycoding/langchat/biz/entity/AigcModel.java
  4. 73 0
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/EmbeddingProvider.java
  5. 1 0
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/EmbeddingStoreInitialize.java
  6. 72 0
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/KnowledgeStore.java
  7. 5 14
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderInitialize.java
  8. 2 3
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/ModelBuildHandler.java
  9. 3 6
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/OllamaModelBuildHandler.java
  10. 4 6
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/OpenAIModelBuildHandler.java
  11. 3 6
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/QFanModelBuildHandler.java
  12. 3 6
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/QWenModelBuildHandler.java
  13. 3 6
      langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/ZhipuModelBuildHandler.java
  14. 2 5
      langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangChatServiceImpl.java
  15. 10 8
      langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangEmbeddingServiceImpl.java
  16. 35 5
      langchat-server/src/main/java/cn/tycoding/langchat/server/controller/AigcKnowledgeController.java
  17. 22 14
      langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/EmbeddingServiceImpl.java
  18. 6 6
      langchat-ui/src/components/Table/src/Table.vue
  19. 2 0
      langchat-ui/src/components/Table/src/hooks/useDataSource.ts
  20. 1 1
      langchat-ui/src/components/Table/src/hooks/usePagination.ts
  21. 19 6
      langchat-ui/src/views/aigc/embed-store/columns.ts
  22. 7 2
      langchat-ui/src/views/aigc/embed-store/edit.vue
  23. 5 0
      langchat-ui/src/views/aigc/embed-store/index.vue
  24. 17 2
      langchat-ui/src/views/aigc/knowledge/columns.ts
  25. 0 1
      langchat-ui/src/views/aigc/knowledge/components/DocsList/columns.ts
  26. 58 40
      langchat-ui/src/views/aigc/knowledge/components/index.vue
  27. 42 3
      langchat-ui/src/views/aigc/knowledge/edit.vue
  28. 20 6
      langchat-ui/src/views/aigc/knowledge/index.vue
  29. 23 3
      langchat-ui/src/views/aigc/model/components/embedding/columns.ts
  30. 2 2
      langchat-ui/src/views/aigc/model/components/embedding/index.vue
  31. 21 0
      langchat-ui/src/views/aigc/model/components/embedding/schemas.ts

+ 1 - 1
docs/langchat.sql

@@ -166,7 +166,7 @@ CREATE TABLE `aigc_model` (
                               `image_size` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '图片大小',
                               `image_quality` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '图片质量',
                               `image_style` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '图片风格',
-                              `dimensions` int DEFAULT NULL COMMENT '向量维数',
+                              `dimension` int DEFAULT NULL COMMENT '向量维数',
                               PRIMARY KEY (`id`)
 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='LLM模型配置表';
 

+ 8 - 0
langchat-biz/src/main/java/cn/tycoding/langchat/biz/entity/AigcKnowledge.java

@@ -41,6 +41,9 @@ public class AigcKnowledge implements Serializable {
     @TableId(type = IdType.ASSIGN_UUID)
     private String id;
 
+    private String embedStoreId;
+    private String embedModelId;
+
     /**
      * 知识库名称
      */
@@ -67,5 +70,10 @@ public class AigcKnowledge implements Serializable {
     private Long totalSize = 0L;
     @TableField(exist = false)
     private List<AigcDocs> docs = new ArrayList<>();
+
+    @TableField(exist = false)
+    private AigcEmbedStore embedStore;
+    @TableField(exist = false)
+    private AigcModel embedModel;
 }
 

+ 1 - 1
langchat-biz/src/main/java/cn/tycoding/langchat/biz/entity/AigcModel.java

@@ -52,6 +52,6 @@ public class AigcModel implements Serializable {
     private String imageSize;
     private String imageQuality;
     private String imageStyle;
-    private Integer dimensions;
+    private Integer dimension;
 }
 

+ 73 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/EmbeddingProvider.java

@@ -16,17 +16,25 @@
 
 package cn.tycoding.langchat.core.provider;
 
+import cn.tycoding.langchat.biz.entity.AigcKnowledge;
+import cn.tycoding.langchat.common.exception.ServiceException;
 import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ProviderEnum;
 import dev.langchain4j.data.document.DocumentSplitter;
 import dev.langchain4j.data.document.splitter.DocumentSplitters;
+import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiTokenizer;
+import dev.langchain4j.store.embedding.EmbeddingStore;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.context.ApplicationContext;
 import org.springframework.stereotype.Component;
 
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+
 /**
  * @author tycoding
  * @since 2024/3/8
@@ -60,4 +68,69 @@ public class EmbeddingProvider {
         }
         return null;
     }
+
+    public EmbeddingModel getEmbeddingModel(List<String> knowledgeIds) {
+        List<String> storeIds = new ArrayList<>();
+        knowledgeIds.forEach(id -> {
+            if (context.containsBean(id)) {
+                AigcKnowledge data = (AigcKnowledge) context.getBean(id);
+                if (data.getEmbedModelId() != null) {
+                    storeIds.add(data.getEmbedModelId());
+                }
+            }
+        });
+        if (storeIds.isEmpty()) {
+            throw new ServiceException("知识库缺少Embedding Model配置,请先检查配置");
+        }
+
+        HashSet<String> filterIds = new HashSet<>(storeIds);
+        if (filterIds.size() > 1) {
+            throw new ServiceException("存在多个不同Embedding Model的知识库,请先检查配置");
+        }
+
+        return (EmbeddingModel) context.getBean(storeIds.get(0));
+    }
+
+    public EmbeddingModel getEmbeddingModel(String knowledgeId) {
+        if (context.containsBean(knowledgeId)) {
+            AigcKnowledge data = (AigcKnowledge) context.getBean(knowledgeId);
+            if (context.containsBean(data.getEmbedModelId())) {
+                return (EmbeddingModel) context.getBean(data.getEmbedModelId());
+            }
+        }
+        throw new ServiceException("没有找到匹配的Embedding向量数据库");
+    }
+
+    public EmbeddingStore<TextSegment> getEmbeddingStore(String knowledgeId) {
+        if (context.containsBean(knowledgeId)) {
+            AigcKnowledge data = (AigcKnowledge) context.getBean(knowledgeId);
+            if (context.containsBean(data.getEmbedStoreId())) {
+                return (EmbeddingStore<TextSegment>) context.getBean(data.getEmbedStoreId());
+            }
+        }
+        throw new ServiceException("没有找到匹配的Embedding向量数据库");
+    }
+
+    public EmbeddingStore<TextSegment> getEmbeddingStore(List<String> knowledgeIds) {
+        List<String> storeIds = new ArrayList<>();
+        knowledgeIds.forEach(id -> {
+            if (context.containsBean(id)) {
+                AigcKnowledge data = (AigcKnowledge) context.getBean(id);
+                if (data.getEmbedStoreId() != null) {
+                    storeIds.add(data.getEmbedStoreId());
+                }
+            }
+        });
+        if (storeIds.isEmpty()) {
+            throw new ServiceException("知识库缺少Embedding Store配置,请先检查配置");
+        }
+
+        HashSet<String> filterIds = new HashSet<>(storeIds);
+        if (filterIds.size() > 1) {
+            throw new ServiceException("存在多个不同Embedding Store数据源的知识库,请先检查配置");
+        }
+
+        return (EmbeddingStore<TextSegment>) context.getBean(storeIds.get(0));
+    }
+
 }

+ 1 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/EmbeddingStoreInitialize.java

@@ -84,6 +84,7 @@ public class EmbeddingStoreInitialize implements ApplicationContextAware {
                             .user(embed.getUsername())
                             .password(embed.getPassword())
                             .table(embed.getTableName())
+                            .indexListSize(1)
                             .useIndex(true)
                             .createTable(true)
                             .dropTableFirst(false)

+ 72 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/KnowledgeStore.java

@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2024 LangChat. TyCoding All Rights Reserved.
+ *
+ * Licensed under the GNU Affero General Public License, Version 3 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    https://www.gnu.org/licenses/agpl-3.0.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package cn.tycoding.langchat.core.provider;
+
+import cn.tycoding.langchat.biz.entity.AigcEmbedStore;
+import cn.tycoding.langchat.biz.entity.AigcKnowledge;
+import cn.tycoding.langchat.biz.entity.AigcModel;
+import cn.tycoding.langchat.biz.service.AigcEmbedStoreService;
+import cn.tycoding.langchat.biz.service.AigcKnowledgeService;
+import cn.tycoding.langchat.biz.service.AigcModelService;
+import cn.tycoding.langchat.common.component.SpringContextHolder;
+import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * @author tycoding
+ * @since 2024/10/29
+ */
+@Slf4j
+@Component
+@AllArgsConstructor
+public class KnowledgeStore implements ApplicationContextAware {
+
+    private final AigcKnowledgeService knowledgeService;
+    private final SpringContextHolder contextHolder;
+    private final AigcModelService modelService;
+    private final AigcEmbedStoreService embedStoreService;
+
+    @Override
+    public void setApplicationContext(ApplicationContext context) throws BeansException {
+        init();
+    }
+
+    public void init() {
+        List<AigcKnowledge> list = knowledgeService.list();
+        Map<String, List<AigcModel>> modelMap = modelService.list().stream().collect(Collectors.groupingBy(AigcModel::getId));
+        Map<String, List<AigcEmbedStore>> storeMap = embedStoreService.list().stream().collect(Collectors.groupingBy(AigcEmbedStore::getId));
+        list.forEach(know -> {
+            if (know.getEmbedModelId() != null) {
+                List<AigcModel> models = modelMap.get(know.getEmbedModelId());
+                know.setEmbedModel(models == null ? null : models.get(0));
+            }
+            if (know.getEmbedStoreId() != null) {
+                List<AigcEmbedStore> stores = storeMap.get(know.getEmbedStoreId());
+                know.setEmbedStore(stores == null ? null : stores.get(0));
+            }
+            contextHolder.registerBean(know.getId(), know);
+        });
+    }
+}

+ 5 - 14
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderInitialize.java

@@ -16,18 +16,16 @@
 
 package cn.tycoding.langchat.core.provider;
 
-import cn.hutool.core.lang.Pair;
 import cn.hutool.core.util.ObjectUtil;
 import cn.tycoding.langchat.biz.component.ModelTypeEnum;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.biz.service.AigcModelService;
 import cn.tycoding.langchat.common.component.SpringContextHolder;
-import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ModelConst;
 import cn.tycoding.langchat.core.provider.build.ModelBuildHandler;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
@@ -61,13 +59,6 @@ public class ProviderInitialize implements ApplicationContextAware {
     public void init() {
         modelStore = new ArrayList<>();
 
-        // un register embedding model
-        contextHolder.unregisterBean(EmbedConst.CLAZZ_NAME_OPENAI);
-        contextHolder.unregisterBean(EmbedConst.CLAZZ_NAME_QIANFAN);
-        contextHolder.unregisterBean(EmbedConst.CLAZZ_NAME_ZHIPU);
-        contextHolder.unregisterBean(EmbedConst.CLAZZ_NAME_QIANWEN);
-        contextHolder.unregisterBean(EmbedConst.CLAZZ_NAME_OLLAMA);
-
         List<AigcModel> list = aigcModelService.list();
         list.forEach(model -> {
             if (Objects.equals(model.getBaseUrl(), "")) {
@@ -81,7 +72,7 @@ public class ProviderInitialize implements ApplicationContextAware {
             imageHandler(model);
         });
 
-        modelStore.forEach(i -> log.info("已成功注册模型:{}, 模型配置:{}", i.getProvider(), i));
+        modelStore.forEach(i -> log.info("已成功注册模型:{} -- {}, 模型配置:{}", i.getProvider(), i.getType(),  i));
     }
 
     private void chatHandler(AigcModel model) {
@@ -114,9 +105,9 @@ public class ProviderInitialize implements ApplicationContextAware {
                 return;
             }
             modelBuildHandlers.forEach(x -> {
-                Pair<String, DimensionAwareEmbeddingModel> embeddingModelPair = x.buildEmbedding(model);
-                if (ObjectUtil.isNotEmpty(embeddingModelPair)) {
-                    contextHolder.registerBean(embeddingModelPair.getKey(), embeddingModelPair.getValue());
+                EmbeddingModel embeddingModel = x.buildEmbedding(model);
+                if (ObjectUtil.isNotEmpty(embeddingModel)) {
+                    contextHolder.registerBean(model.getId(), embeddingModel);
                     modelStore.add(model);
                 }
             });

+ 2 - 3
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/ModelBuildHandler.java

@@ -16,11 +16,10 @@
 
 package cn.tycoding.langchat.core.provider.build;
 
-import cn.hutool.core.lang.Pair;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 
 /**
@@ -52,7 +51,7 @@ public interface ModelBuildHandler {
     /**
      * embedding config
      */
-    Pair<String, DimensionAwareEmbeddingModel> buildEmbedding(AigcModel model);
+    EmbeddingModel buildEmbedding(AigcModel model);
 
     /**
      * image config

+ 3 - 6
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/OllamaModelBuildHandler.java

@@ -16,15 +16,13 @@
 
 package cn.tycoding.langchat.core.provider.build;
 
-import cn.hutool.core.lang.Pair;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.common.enums.ChatErrorEnum;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ProviderEnum;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import dev.langchain4j.model.ollama.OllamaChatModel;
 import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
@@ -110,7 +108,7 @@ public class OllamaModelBuildHandler implements ModelBuildHandler {
     }
 
     @Override
-    public Pair<String, DimensionAwareEmbeddingModel> buildEmbedding(AigcModel model) {
+    public EmbeddingModel buildEmbedding(AigcModel model) {
         try {
             if (!whetherCurrentModel(model)) {
                 return null;
@@ -118,14 +116,13 @@ public class OllamaModelBuildHandler implements ModelBuildHandler {
             if (!basicCheck(model)) {
                 return null;
             }
-            OllamaEmbeddingModel ollamaEmbeddingModel = OllamaEmbeddingModel
+            return OllamaEmbeddingModel
                     .builder()
                     .baseUrl(model.getBaseUrl())
                     .modelName(model.getModel())
                     .logRequests(true)
                     .logResponses(true)
                     .build();
-            return Pair.of(EmbedConst.CLAZZ_NAME_OLLAMA, ollamaEmbeddingModel);
         } catch (ServiceException e) {
             log.error(e.getMessage());
             throw e;

+ 4 - 6
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/OpenAIModelBuildHandler.java

@@ -16,17 +16,15 @@
 
 package cn.tycoding.langchat.core.provider.build;
 
-import cn.hutool.core.lang.Pair;
 import cn.hutool.core.util.StrUtil;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.common.enums.ChatErrorEnum;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ProviderEnum;
 import cn.tycoding.langchat.core.properties.LangChatProps;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import dev.langchain4j.model.openai.OpenAiChatModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
@@ -134,7 +132,7 @@ public class OpenAIModelBuildHandler implements ModelBuildHandler {
     }
 
     @Override
-    public Pair<String, DimensionAwareEmbeddingModel> buildEmbedding(AigcModel model) {
+    public EmbeddingModel buildEmbedding(AigcModel model) {
         try {
             if (!whetherCurrentModel(model)) {
                 return null;
@@ -147,12 +145,12 @@ public class OpenAIModelBuildHandler implements ModelBuildHandler {
                     .apiKey(model.getApiKey())
                     .baseUrl(model.getBaseUrl())
                     .modelName(model.getModel())
-                    .dimensions(model.getDimensions())
+                    .dimensions(model.getDimension())
                     .logRequests(true)
                     .logResponses(true)
                     .dimensions(1024)
                     .build();
-            return Pair.of(EmbedConst.CLAZZ_NAME_OPENAI, openAiEmbeddingModel);
+            return openAiEmbeddingModel;
         } catch (ServiceException e) {
             log.error(e.getMessage());
             throw e;

+ 3 - 6
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/QFanModelBuildHandler.java

@@ -16,15 +16,13 @@
 
 package cn.tycoding.langchat.core.provider.build;
 
-import cn.hutool.core.lang.Pair;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.common.enums.ChatErrorEnum;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ProviderEnum;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import dev.langchain4j.model.qianfan.QianfanChatModel;
 import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
@@ -120,7 +118,7 @@ public class QFanModelBuildHandler implements ModelBuildHandler {
     }
 
     @Override
-    public Pair<String, DimensionAwareEmbeddingModel> buildEmbedding(AigcModel model) {
+    public EmbeddingModel buildEmbedding(AigcModel model) {
         try {
             if (!whetherCurrentModel(model)) {
                 return null;
@@ -128,7 +126,7 @@ public class QFanModelBuildHandler implements ModelBuildHandler {
             if (!basicCheck(model)) {
                 return null;
             }
-            QianfanEmbeddingModel qianfanEmbeddingModel = QianfanEmbeddingModel
+            return QianfanEmbeddingModel
                     .builder()
                     .apiKey(model.getApiKey())
                     .modelName(model.getModel())
@@ -136,7 +134,6 @@ public class QFanModelBuildHandler implements ModelBuildHandler {
                     .logRequests(true)
                     .logResponses(true)
                     .build();
-            return Pair.of(EmbedConst.CLAZZ_NAME_QIANFAN, qianfanEmbeddingModel);
         } catch (ServiceException e) {
             log.error(e.getMessage());
             throw e;

+ 3 - 6
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/QWenModelBuildHandler.java

@@ -16,18 +16,16 @@
 
 package cn.tycoding.langchat.core.provider.build;
 
-import cn.hutool.core.lang.Pair;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.common.enums.ChatErrorEnum;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ProviderEnum;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.dashscope.QwenChatModel;
 import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
 import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;
@@ -111,7 +109,7 @@ public class QWenModelBuildHandler implements ModelBuildHandler {
     }
 
     @Override
-    public Pair<String, DimensionAwareEmbeddingModel> buildEmbedding(AigcModel model) {
+    public EmbeddingModel buildEmbedding(AigcModel model) {
         try {
             if (!whetherCurrentModel(model)) {
                 return null;
@@ -119,12 +117,11 @@ public class QWenModelBuildHandler implements ModelBuildHandler {
             if (!basicCheck(model)) {
                 return null;
             }
-            QwenEmbeddingModel qwenEmbeddingModel = QwenEmbeddingModel
+            return QwenEmbeddingModel
                     .builder()
                     .apiKey(model.getApiKey())
                     .modelName(model.getModel())
                     .build();
-            return Pair.of(EmbedConst.CLAZZ_NAME_QIANWEN, qwenEmbeddingModel);
         } catch (ServiceException e) {
             log.error(e.getMessage());
             throw e;

+ 3 - 6
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/build/ZhipuModelBuildHandler.java

@@ -16,16 +16,14 @@
 
 package cn.tycoding.langchat.core.provider.build;
 
-import cn.hutool.core.lang.Pair;
 import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.common.enums.ChatErrorEnum;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.consts.ProviderEnum;
 import cn.tycoding.langchat.core.properties.LangChatProps;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
 import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
@@ -132,7 +130,7 @@ public class ZhipuModelBuildHandler implements ModelBuildHandler {
     }
 
     @Override
-    public Pair<String, DimensionAwareEmbeddingModel> buildEmbedding(AigcModel model) {
+    public EmbeddingModel buildEmbedding(AigcModel model) {
         try {
             if (!whetherCurrentModel(model)) {
                 return null;
@@ -140,7 +138,7 @@ public class ZhipuModelBuildHandler implements ModelBuildHandler {
             if (!basicCheck(model)) {
                 return null;
             }
-            ZhipuAiEmbeddingModel zhipuAiEmbeddingModel = ZhipuAiEmbeddingModel
+            return ZhipuAiEmbeddingModel
                     .builder()
                     .apiKey(model.getApiKey())
                     .model(model.getModel())
@@ -153,7 +151,6 @@ public class ZhipuModelBuildHandler implements ModelBuildHandler {
                     .readTimeout(Duration.ofMinutes(2))
                     .dimensions(1024)
                     .build();
-            return Pair.of(EmbedConst.CLAZZ_NAME_ZHIPU, zhipuAiEmbeddingModel);
         } catch (ServiceException e) {
             log.error(e.getMessage());
             throw e;

+ 2 - 5
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangChatServiceImpl.java

@@ -30,7 +30,6 @@ import dev.langchain4j.data.image.Image;
 import dev.langchain4j.memory.chat.MessageWindowChatMemory;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
-import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.rag.DefaultRetrievalAugmentor;
@@ -59,7 +58,6 @@ public class LangChatServiceImpl implements LangChatService {
 
     private final ModelProvider provider;
     private final EmbeddingProvider embeddingProvider;
-//    private final PgVectorEmbeddingStore embeddingStore;
     private final ChatProps chatProps;
 
     private AiServices<Agent> build(StreamingChatLanguageModel streamModel, ChatLanguageModel model, ChatReq req) {
@@ -87,7 +85,6 @@ public class LangChatServiceImpl implements LangChatService {
         }
 
         AiServices<Agent> aiServices = build(model, null, req);
-        EmbeddingModel embeddingModel = embeddingProvider.embed();
 
         if (StrUtil.isNotBlank(req.getKnowledgeId())) {
             req.getKnowledgeIds().add(req.getKnowledgeId());
@@ -97,8 +94,8 @@ public class LangChatServiceImpl implements LangChatService {
             Function<Query, Filter> filter = (query) -> metadataKey(KNOWLEDGE).isIn(req.getKnowledgeIds());
             ContentRetriever contentRetriever = EmbeddingStoreContentRetrieverCustom.builder()
                     .memoryId(req.getConversationId())
-//                    .embeddingStore(embeddingStore)
-                    .embeddingModel(embeddingModel)
+                    .embeddingStore(embeddingProvider.getEmbeddingStore(req.getKnowledgeIds()))
+                    .embeddingModel(embeddingProvider.getEmbeddingModel(req.getKnowledgeIds()))
                     .dynamicFilter(filter)
                     .build();
             aiServices.retrievalAugmentor(DefaultRetrievalAugmentor

+ 10 - 8
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangEmbeddingServiceImpl.java

@@ -28,6 +28,7 @@ import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.store.embedding.EmbeddingStore;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
@@ -49,17 +50,17 @@ import static dev.langchain4j.data.document.Metadata.metadata;
 public class LangEmbeddingServiceImpl implements LangEmbeddingService {
 
     private final EmbeddingProvider embeddingProvider;
-//    private final PgVectorEmbeddingStore embeddingStore;
 
     @Override
     public EmbeddingR embeddingText(ChatReq req) {
         log.info(">>>>>>>>>>>>>> Text文本向量解析开始,KnowledgeId={}, DocsName={}", req.getKnowledgeId(), req.getDocsName());
         TextSegment segment = TextSegment.from(req.getMessage(),
                 metadata(KNOWLEDGE, req.getKnowledgeId()).put(FILENAME, req.getDocsName()));
-        EmbeddingModel embeddingModel = embeddingProvider.embed();
+
+        EmbeddingModel embeddingModel = embeddingProvider.getEmbeddingModel(req.getKnowledgeId());
+        EmbeddingStore<TextSegment> embeddingStore = embeddingProvider.getEmbeddingStore(req.getKnowledgeId());
         Embedding embedding = embeddingModel.embed(segment).content();
-//        String id = embeddingStore.add(embedding, segment);
-        String id = "";
+        String id = embeddingStore.add(embedding, segment);
 
         log.info(">>>>>>>>>>>>>> Text文本向量解析结束,KnowledgeId={}, DocsName={}", req.getKnowledgeId(), req.getDocsName());
         return new EmbeddingR().setVectorId(id).setText(segment.text());
@@ -67,7 +68,6 @@ public class LangEmbeddingServiceImpl implements LangEmbeddingService {
 
     @Override
     public List<EmbeddingR> embeddingDocs(ChatReq req) {
-        EmbeddingModel model = embeddingProvider.embed();
 
         log.info(">>>>>>>>>>>>>> Docs文档向量解析开始,KnowledgeId={}, DocsName={}", req.getKnowledgeId(), req.getDocsName());
         Document document;
@@ -80,9 +80,11 @@ public class LangEmbeddingServiceImpl implements LangEmbeddingService {
 
         DocumentSplitter splitter = EmbeddingProvider.splitter(req.getModelName(), req.getModelProvider());
         List<TextSegment> segments = splitter.split(document);
-        List<Embedding> embeddings = model.embedAll(segments).content();
-//        List<String> ids = embeddingStore.addAll(embeddings, segments);
-        List<String> ids = new ArrayList<>();
+
+        EmbeddingModel embeddingModel = embeddingProvider.getEmbeddingModel(req.getKnowledgeId());
+        EmbeddingStore<TextSegment> embeddingStore = embeddingProvider.getEmbeddingStore(req.getKnowledgeId());
+        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
+        List<String> ids = embeddingStore.addAll(embeddings, segments);
 
         List<EmbeddingR> list = new ArrayList<>();
         for (int i = 0; i < ids.size(); i++) {

+ 35 - 5
langchat-server/src/main/java/cn/tycoding/langchat/server/controller/AigcKnowledgeController.java

@@ -19,13 +19,19 @@ package cn.tycoding.langchat.server.controller;
 import cn.dev33.satoken.annotation.SaCheckPermission;
 import cn.hutool.core.util.StrUtil;
 import cn.tycoding.langchat.biz.entity.AigcDocs;
+import cn.tycoding.langchat.biz.entity.AigcEmbedStore;
 import cn.tycoding.langchat.biz.entity.AigcKnowledge;
+import cn.tycoding.langchat.biz.entity.AigcModel;
 import cn.tycoding.langchat.biz.mapper.AigcDocsMapper;
+import cn.tycoding.langchat.biz.service.AigcEmbedStoreService;
 import cn.tycoding.langchat.biz.service.AigcKnowledgeService;
+import cn.tycoding.langchat.biz.service.AigcModelService;
 import cn.tycoding.langchat.common.annotation.ApiLog;
 import cn.tycoding.langchat.common.utils.MybatisUtil;
 import cn.tycoding.langchat.common.utils.QueryPage;
 import cn.tycoding.langchat.common.utils.R;
+import cn.tycoding.langchat.core.provider.EmbeddingProvider;
+import cn.tycoding.langchat.core.provider.KnowledgeStore;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
@@ -48,6 +54,10 @@ public class AigcKnowledgeController {
 
     private final AigcKnowledgeService kbService;
     private final AigcDocsMapper docsMapper;
+    private final AigcEmbedStoreService embedStoreService;
+    private final AigcModelService modelService;
+    private final EmbeddingProvider embeddingProvider;
+    private final KnowledgeStore knowledgeStore;
 
     @GetMapping("/list")
     public R<List<AigcKnowledge>> list(AigcKnowledge data) {
@@ -77,12 +87,22 @@ public class AigcKnowledgeController {
                 .orderByDesc(AigcKnowledge::getCreateTime);
         Page<AigcKnowledge> iPage = kbService.page(page, queryWrapper);
 
+        Map<String, List<AigcEmbedStore>> embedStoreMap = embedStoreService.list().stream().collect(Collectors.groupingBy(AigcEmbedStore::getId));
+        Map<String, List<AigcModel>> embedModelMap = modelService.list().stream().collect(Collectors.groupingBy(AigcModel::getId));
         Map<String, List<AigcDocs>> docsMap = docsMapper.selectList(Wrappers.lambdaQuery()).stream().collect(Collectors.groupingBy(AigcDocs::getKnowledgeId));
-        iPage.getRecords().forEach(i -> {
-            List<AigcDocs> docs = docsMap.get(i.getId());
+        iPage.getRecords().forEach(item -> {
+            List<AigcDocs> docs = docsMap.get(item.getId());
             if (docs != null) {
-                i.setDocsNum(docs.size());
-                i.setTotalSize(docs.stream().filter(d -> d.getSize() != null).mapToLong(AigcDocs::getSize).sum());
+                item.setDocsNum(docs.size());
+                item.setTotalSize(docs.stream().filter(d -> d.getSize() != null).mapToLong(AigcDocs::getSize).sum());
+            }
+            if (item.getEmbedModelId() != null) {
+                List<AigcModel> list = embedModelMap.get(item.getEmbedModelId());
+                item.setEmbedModel(list == null ? null : list.get(0));
+            }
+            if (item.getEmbedStoreId() != null) {
+                List<AigcEmbedStore> list = embedStoreMap.get(item.getEmbedStoreId());
+                item.setEmbedStore(list == null ? null : list.get(0));
             }
         });
 
@@ -91,7 +111,14 @@ public class AigcKnowledgeController {
 
     @GetMapping("/{id}")
     public R<AigcKnowledge> findById(@PathVariable String id) {
-        return R.ok(kbService.getById(id));
+        AigcKnowledge knowledge = kbService.getById(id);
+        if (knowledge.getEmbedStoreId() != null) {
+            knowledge.setEmbedStore(embedStoreService.getById(knowledge.getEmbedStoreId()));
+        }
+        if (knowledge.getEmbedModelId() != null) {
+            knowledge.setEmbedModel(modelService.getById(knowledge.getEmbedModelId()));
+        }
+        return R.ok(knowledge);
     }
 
     @PostMapping
@@ -100,6 +127,7 @@ public class AigcKnowledgeController {
     public R add(@RequestBody AigcKnowledge data) {
         data.setCreateTime(String.valueOf(System.currentTimeMillis()));
         kbService.save(data);
+        knowledgeStore.init();
         return R.ok();
     }
 
@@ -108,6 +136,7 @@ public class AigcKnowledgeController {
     @SaCheckPermission("aigc:knowledge:update")
     public R update(@RequestBody AigcKnowledge data) {
         kbService.updateById(data);
+        knowledgeStore.init();
         return R.ok();
     }
 
@@ -116,6 +145,7 @@ public class AigcKnowledgeController {
     @SaCheckPermission("aigc:knowledge:delete")
     public R delete(@PathVariable String id) {
         kbService.removeKnowledge(id);
+        knowledgeStore.init();
         return R.ok();
     }
 }

+ 22 - 14
langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/EmbeddingServiceImpl.java

@@ -19,6 +19,7 @@ package cn.tycoding.langchat.server.service.impl;
 import cn.hutool.core.util.StrUtil;
 import cn.tycoding.langchat.biz.entity.AigcDocs;
 import cn.tycoding.langchat.biz.entity.AigcDocsSlice;
+import cn.tycoding.langchat.biz.mapper.AigcDocsMapper;
 import cn.tycoding.langchat.biz.service.AigcKnowledgeService;
 import cn.tycoding.langchat.common.dto.ChatReq;
 import cn.tycoding.langchat.common.dto.EmbeddingR;
@@ -26,7 +27,11 @@ import cn.tycoding.langchat.core.provider.EmbeddingProvider;
 import cn.tycoding.langchat.core.service.LangEmbeddingService;
 import cn.tycoding.langchat.server.service.EmbeddingService;
 import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingSearchResult;
+import dev.langchain4j.store.embedding.EmbeddingStore;
 import dev.langchain4j.store.embedding.filter.Filter;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
@@ -52,7 +57,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
     private final EmbeddingProvider embeddingProvider;
     private final LangEmbeddingService langEmbeddingService;
     private final AigcKnowledgeService aigcKnowledgeService;
-//    private final PgVectorEmbeddingStore embeddingStore;
+    private final AigcDocsMapper aigcDocsMapper;
 
     @Override
     @Transactional
@@ -65,7 +70,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
         if (vectorIds.isEmpty()) {
             return;
         }
-//        embeddingStore.removeAll(vectorIds);
+        AigcDocs docs = aigcDocsMapper.selectById(docsId);
+        EmbeddingStore<TextSegment> embeddingStore = embeddingProvider.getEmbeddingStore(docs.getKnowledgeId());
+        embeddingStore.removeAll(vectorIds);
         // remove from docSlice
         aigcKnowledgeService.removeSlicesOfDoc(docsId);
     }
@@ -96,22 +103,23 @@ public class EmbeddingServiceImpl implements EmbeddingService {
             return List.of();
         }
 
-        EmbeddingModel embeddingModel = embeddingProvider.embed();
+        EmbeddingModel embeddingModel = embeddingProvider.getEmbeddingModel(data.getKnowledgeId());
+        EmbeddingStore<TextSegment> embeddingStore = embeddingProvider.getEmbeddingStore(data.getKnowledgeId());
         Embedding queryEmbedding = embeddingModel.embed(data.getContent()).content();
         Filter filter = metadataKey(KNOWLEDGE).isEqualTo(data.getKnowledgeId());
-//        EmbeddingSearchResult<TextSegment> list = embeddingStore.search(EmbeddingSearchRequest
-//                .builder()
-//                .queryEmbedding(queryEmbedding)
-//                .filter(filter)
-//                .build());
+        EmbeddingSearchResult<TextSegment> list = embeddingStore.search(EmbeddingSearchRequest
+                .builder()
+                .queryEmbedding(queryEmbedding)
+                .filter(filter)
+                .build());
 
         List<Map<String, Object>> result = new ArrayList<>();
-//        list.matches().forEach(i -> {
-//            TextSegment embedded = i.embedded();
-//            Map<String, Object> map = embedded.metadata().toMap();
-//            map.put("text", embedded.text());
-//            result.add(map);
-//        });
+        list.matches().forEach(i -> {
+            TextSegment embedded = i.embedded();
+            Map<String, Object> map = embedded.metadata().toMap();
+            map.put("text", embedded.text());
+            result.add(map);
+        });
         return result;
     }
 }

+ 6 - 6
langchat-ui/src/components/Table/src/Table.vue

@@ -86,17 +86,17 @@
 
 <script lang="ts">
   import {
-    ref,
+    computed,
     defineComponent,
+    nextTick,
+    onMounted,
     reactive,
-    unref,
+    ref,
     toRaw,
-    computed,
     toRefs,
-    onMounted,
-    nextTick,
+    unref,
   } from 'vue';
-  import { ReloadOutlined, ColumnHeightOutlined, QuestionCircleOutlined } from '@vicons/antd';
+  import { ColumnHeightOutlined, QuestionCircleOutlined, ReloadOutlined } from '@vicons/antd';
   import { createTableContext } from './hooks/useTableContext';
 
   import ColumnSetting from './components/settings/ColumnSetting.vue';

+ 2 - 0
langchat-ui/src/components/Table/src/hooks/useDataSource.ts

@@ -64,6 +64,8 @@ export function useDataSource(
         pageParams[sizeField] = pageSize;
       }
 
+      setPagination({ page: pageParams?.[pageField] });
+
       let params = {
         ...pageParams,
       };

+ 1 - 1
langchat-ui/src/components/Table/src/hooks/usePagination.ts

@@ -1,6 +1,6 @@
 import type { PaginationProps } from '../types/pagination';
 import type { BasicTableProps } from '../types/table';
-import { computed, unref, ref, ComputedRef, watch } from 'vue';
+import { computed, ComputedRef, ref, unref, watch } from 'vue';
 
 import { isBoolean } from '@/utils/is';
 import { DEFAULTPAGESIZE, PAGESIZES } from '../const';

+ 19 - 6
langchat-ui/src/views/aigc/embed-store/columns.ts

@@ -65,6 +65,24 @@ export const columns: BasicColumn[] = [
       );
     },
   },
+  {
+    title: '向量纬度',
+    key: 'dimension',
+    align: 'center',
+    width: '80',
+    render(row) {
+      return h(
+        NTag,
+        {
+          size: 'small',
+          type: 'error',
+        },
+        {
+          default: () => row.dimension,
+        }
+      );
+    },
+  },
   {
     title: '数据库地址',
     key: 'host',
@@ -97,12 +115,6 @@ export const columns: BasicColumn[] = [
     key: 'tableName',
     align: 'center',
   },
-  {
-    title: '向量纬度',
-    key: 'dimension',
-    align: 'center',
-    width: '80',
-  },
 ];
 
 export const searchSchemas: FormSchema[] = [
@@ -191,6 +203,7 @@ export function getSchemas(provider: string) {
     label: '向量纬度',
     component: 'NSelect',
     defaultValue: 1024,
+    labelMessage: '慎重修改此参数,纬度高会消耗更多的算力,但纬度高并不代表搜索更精确',
     componentProps: {
       placeholder: '请输入向量纬度',
       options: [

+ 7 - 2
langchat-ui/src/views/aigc/embed-store/edit.vue

@@ -18,7 +18,7 @@
   import { computed, nextTick } from 'vue';
   import { add, getById, update } from '@/api/aigc/embed-store';
   import { useMessage } from 'naive-ui';
-  import { getSchemas } from './columns';
+  import { getProviderLabel, getSchemas } from './columns';
   import { basicModal, useModal } from '@/components/Modal';
   import { BasicForm, useForm } from '@/components/Form';
   import { isNullOrWhitespace } from '@/utils/is';
@@ -32,7 +32,7 @@
     modalRegister,
     { openModal: openModal, closeModal: closeModal, setSubLoading: setSubLoading },
   ] = useModal({
-    title: props.provider + ' 新增/编辑',
+    title: getProviderLabel(props.provider) + ' 新增/编辑',
     closable: true,
     maskClosable: false,
     showCloseBtn: false,
@@ -85,6 +85,11 @@
 <template>
   <basicModal style="width: 45%" @register="modalRegister">
     <template #default>
+      <n-alert
+        class="w-full mb-4 mt-2 min-alert"
+        title="注意:请慎重修改模型的向量纬度参数(Dimension),此参数需要和向量库匹配(错误修改可能将影响已有的向量数据)"
+        type="info"
+      />
       <BasicForm :schemas="schemas" class="mt-5" @register="register" @submit="handleSubmit" />
     </template>
   </basicModal>

+ 5 - 0
langchat-ui/src/views/aigc/embed-store/index.vue

@@ -106,6 +106,11 @@
   <div class="h-full">
     <n-card :bordered="false">
       <BasicForm @register="register" @reset="handleReset" @submit="reloadTable" />
+      <n-alert
+        class="w-full mb-4 mt-2 min-alert"
+        title="注意:请慎重修改模型的向量纬度参数(Dimension),此参数需要和向量库匹配(错误修改可能将影响已有的向量数据)"
+        type="info"
+      />
 
       <BasicTable
         ref="actionRef"

+ 17 - 2
langchat-ui/src/views/aigc/knowledge/columns.ts

@@ -43,6 +43,20 @@ export const formSchemas: FormSchema[] = [
     },
     rules: [{ required: true, message: '请输入知识库名称', trigger: ['blur'] }],
   },
+  {
+    field: 'embedStoreId',
+    label: '向量数据库',
+    component: 'NInput',
+    slot: 'embedStoreSlot',
+    rules: [{ required: true, message: '请选择关联向量数据库', trigger: ['blur'] }],
+  },
+  {
+    field: 'embedModelId',
+    label: '向量模型',
+    component: 'NInput',
+    slot: 'embedModelSlot',
+    rules: [{ required: true, message: '请选择关联向量模型', trigger: ['blur'] }],
+  },
   {
     field: 'des',
     component: 'NInput',
@@ -51,9 +65,10 @@ export const formSchemas: FormSchema[] = [
       placeholder: '请输入知识库描述',
       type: 'textarea',
       autosize: {
-        minRows: 5,
-        maxRows: 8,
+        minRows: 3,
+        maxRows: 3,
       },
     },
+    rules: [{ required: true, message: '请输入知识库描述', trigger: ['blur'] }],
   },
 ];

+ 0 - 1
langchat-ui/src/views/aigc/knowledge/components/DocsList/columns.ts

@@ -120,6 +120,5 @@ export const formSchemas: FormSchema[] = [
         maxRows: 12,
       },
     },
-    rules: [{ required: true, message: '请输入文档内容', trigger: ['blur'] }],
   },
 ];

+ 58 - 40
langchat-ui/src/views/aigc/knowledge/components/index.vue

@@ -17,10 +17,10 @@
 <script lang="ts" setup>
   import DocList from './DocsList/index.vue';
   import DocsSlice from './DocsSlice/index.vue';
+  import SvgIcon from '@/components/SvgIcon/index.vue';
   import DocsSliceSearch from './DocsSliceSearch/index.vue';
   import ImportFile from './ImportFile/index.vue';
   import { onMounted, ref } from 'vue';
-  import type { MenuOption } from 'naive-ui';
   import { NIcon } from 'naive-ui';
   import { useRouter } from 'vue-router';
   import { renderIcon } from '@/utils';
@@ -34,8 +34,7 @@
   import { getById } from '@/api/aigc/knowledge';
 
   const router = useRouter();
-
-  const menu = ref();
+  const active = ref('import-file');
   const menuOptions = ref([
     {
       label: '数据导入',
@@ -53,7 +52,7 @@
   onMounted(async () => {
     const id = router.currentRoute.value.params.id;
     knowledge.value = await getById(String(id));
-    menu.value = menuOptions.value[0].key;
+    active.value = menuOptions.value[0].key;
 
     menuOptions.value.push(
       {
@@ -69,8 +68,8 @@
     );
   });
 
-  function handleSelect(key: string, item: MenuOption) {
-    menu.value = key;
+  function handleSelect(key: string) {
+    active.value = key;
   }
 
   function handleReturn() {
@@ -79,41 +78,60 @@
 </script>
 
 <template>
-  <div>
-    <div class="n-layout-page-header">
-      <n-card :bordered="false" size="medium">
-        <template #header>
-          <n-space class="flex items-center">
-            <n-button dashed type="primary" @click="handleReturn">
-              知识库列表
-              <template #icon>
-                <n-icon>
-                  <ArrowUndoOutline />
-                </n-icon>
-              </template>
-            </n-button>
-            <span>
-              {{ knowledge.name }}
-            </span>
-          </n-space>
-        </template>
+  <div class="mt-2" style="height: calc(100vh - 130px) !important">
+    <n-grid :x-gap="13" class="h-full" cols="2 s:2 m:2 l:24 xl:24 2xl:24" responsive="screen">
+      <n-gi class="bg-white p-4 rounded-md" span="5">
+        <n-button block class="mb-4" dashed size="small" type="primary" @click="handleReturn">
+          知识库列表
+          <template #icon>
+            <n-icon>
+              <ArrowUndoOutline />
+            </n-icon>
+          </template>
+        </n-button>
+
+        <div class="flex items-center gap-2">
+          <div class="relative bg-blue-100 p-2 rounded">
+            <SvgIcon class="text-lg" icon="ep:document" />
+          </div>
+          <span class="font-semibold text-[16px]">{{ knowledge.name }}</span>
+        </div>
+        <div class="text-[13px] text-gray-400 mt-3">{{ knowledge.des }}</div>
+        <n-divider class="my-3" />
+        <div class="my-3 flex flex-col gap-2">
+          <div class="text-xs">知识库ID</div>
+          <n-input v-model:value="knowledge.id" />
+        </div>
+        <div class="my-3 flex flex-col gap-2">
+          <div class="text-xs">关联向量数据库</div>
+          <div v-if="knowledge.embedStore == null" class="py-2 text-gray-400"
+            >没有配置关联向量数据库</div
+          >
+          <n-input v-else v-model:value="knowledge.embedStore.name" />
+        </div>
+        <div class="my-3 flex flex-col gap-2">
+          <div class="text-xs">关联向量化模型</div>
+          <div v-if="knowledge.embedModel == null" class="py-2 text-gray-400"
+            >没有配置关联向量化模型</div
+          >
+          <n-input v-else v-model:value="knowledge.embedModel.name" />
+        </div>
+      </n-gi>
+      <n-gi class="h-full bg-white p-4 overflow-y-auto rounded-md" span="19">
+        <n-tabs v-model:value="active" class="flex items-center mb-6" @update:value="handleSelect">
+          <n-tab v-for="item in menuOptions" :key="item.key" :name="item.key">
+            <component :is="item.icon" />
+            <span class="pl-2 font-bold">{{ item.label }}</span>
+          </n-tab>
+        </n-tabs>
 
-        {{ knowledge.des }}
-      </n-card>
-    </div>
-    <div class="mt-2 h-full mx-4" style="height: calc(100vh - 242px) !important">
-      <n-grid :x-gap="10" class="h-full" cols="2 s:2 m:2 l:24 xl:24 2xl:24" responsive="screen">
-        <n-gi class="bg-white pt-2" span="3">
-          <n-menu v-model:value="menu" :options="menuOptions" @update:value="handleSelect" />
-        </n-gi>
-        <n-gi class="h-full overflow-y-auto" span="21">
-          <DocList v-if="menu == 'doc-list'" />
-          <DocsSlice v-if="menu == 'slice-list'" />
-          <DocsSliceSearch v-if="menu == 'slice-search'" />
-          <ImportFile v-if="menu == 'import-file'" :data="knowledge" />
-        </n-gi>
-      </n-grid>
-    </div>
+        <n-tabs />
+        <DocList v-if="active == 'doc-list'" />
+        <DocsSlice v-if="active == 'slice-list'" />
+        <DocsSliceSearch v-if="active == 'slice-search'" />
+        <ImportFile v-if="active == 'import-file'" :data="knowledge" />
+      </n-gi>
+    </n-grid>
   </div>
 </template>
 

+ 42 - 3
langchat-ui/src/views/aigc/knowledge/edit.vue

@@ -15,16 +15,21 @@
   -->
 
 <script lang="ts" setup>
-  import { nextTick } from 'vue';
+  import { nextTick, ref } from 'vue';
   import { add, getById, update } from '@/api/aigc/knowledge';
+  import { list as getModelStores } from '@/api/aigc/embed-store';
+  import { list as getEmbedModels } from '@/api/aigc/model';
   import { useMessage } from 'naive-ui';
   import { formSchemas } from './columns';
   import { BasicForm, useForm } from '@/components/Form';
   import { basicModal, useModal } from '@/components/Modal';
   import { isNullOrWhitespace } from '@/utils/is';
+  import { ModelTypeEnum } from '@/api/models';
 
   const emit = defineEmits(['reload']);
   const message = useMessage();
+  const embedStoreList = ref([]);
+  const embedModelList = ref([]);
 
   const [modalRegister, { openModal, closeModal }] = useModal({
     title: '新增/编辑知识库',
@@ -43,6 +48,25 @@
 
   async function show(id: string) {
     openModal();
+    const stores = await getModelStores({});
+    if (stores != null) {
+      embedStoreList.value = stores.map((item: any) => {
+        return {
+          label: item.name,
+          value: item.id,
+        };
+      });
+    }
+    const models = await getEmbedModels({ type: ModelTypeEnum.EMBEDDING });
+    if (models != null) {
+      embedModelList.value = models.map((item: any) => {
+        return {
+          label: item.name,
+          value: item.id,
+        };
+      });
+    }
+
     await nextTick();
     if (id) {
       setFieldsValue(await getById(id));
@@ -68,8 +92,23 @@
 </script>
 
 <template>
-  <basicModal style="width: 45%" @register="modalRegister">
-    <BasicForm class="mt-5" @register="register" @submit="handleSubmit" />
+  <basicModal style="width: 35%" @register="modalRegister">
+    <BasicForm class="mt-5" @register="register" @submit="handleSubmit">
+      <template #embedStoreSlot="{ model, field }">
+        <n-select
+          v-model:value="model[field]"
+          :options="embedStoreList"
+          placeholder="请选择关联向量数据库"
+        />
+      </template>
+      <template #embedModelSlot="{ model, field }">
+        <n-select
+          v-model:value="model[field]"
+          :options="embedModelList"
+          placeholder="请选择关联向量化模型"
+        />
+      </template>
+    </BasicForm>
   </basicModal>
 </template>
 

+ 20 - 6
langchat-ui/src/views/aigc/knowledge/index.vue

@@ -105,7 +105,7 @@
     <BasicForm @register="register" @reset="fetch" @submit="fetch" />
 
     <n-spin :show="loading">
-      <div class="grid grid-cols-1 gap-6 md:grid-cols-2 xl:grid-cols-4 mb-8">
+      <div class="grid grid-cols-1 gap-4 md:grid-cols-2 xl:grid-cols-4 mb-8">
         <div
           class="bg-gray-100 py-3 pt-4 transition-all duration-300 px-2 transform border cursor-pointer rounded-xl group"
         >
@@ -124,7 +124,7 @@
         <div
           v-for="item in list"
           :key="item.id"
-          class="bg-white px-4 py-3 pt-4 transition-all duration-300 transform border cursor-pointer rounded-xl hover:border-transparent group hover:shadow-lg"
+          class="bg-white px-4 py-3 pt-4 relative transition-all duration-300 transform border cursor-pointer rounded-xl hover:border-transparent group hover:shadow-lg"
           @click="handleInfo(item)"
         >
           <div class="flex flex-col sm:-mx-4 sm:flex-row">
@@ -134,8 +134,17 @@
               </div>
             </div>
 
-            <div class="pr-4">
-              <h1 class="text-lg font-semibold text-gray-700 capitalize"> {{ item.name }} </h1>
+            <div class="pr-4 w-full">
+              <div class="flex items-center justify-between">
+                <h1 class="text-lg font-semibold text-gray-700 capitalize"> {{ item.name }} </h1>
+                <div
+                  v-if="item.embedModel != null"
+                  class="absolute right-0 px-2 flex items-center gap-1 py-1 bg-gray-200 text-gray-500 rounded-l-md text-xs"
+                >
+                  <SvgIcon icon="octicon:ai-model-24" />
+                  {{ item.embedModel.model }}
+                </div>
+              </div>
 
               <p class="mt-2 text-gray-500 capitalize text-xs">
                 {{ item.des }}
@@ -143,12 +152,17 @@
             </div>
           </div>
 
-          <div class="flex mt-4 -mx-2 px-2 text-gray-400 justify-between items-center">
+          <div class="flex mt-6 px-2 text-gray-400 justify-between items-center">
             <div class="flex items-center gap-1">
               <SvgIcon class="" icon="mdi:tag-outline" />
               <span class="text-xs">文档数:{{ item.docsNum }}</span>
-              <n-divider vertical />
+              <n-divider class="!m-0.5" vertical />
               <span class="text-xs">{{ (Number(item.totalSize) / 1000000).toFixed(2) }} MB</span>
+              <n-divider class="!m-0.5" vertical />
+              <SvgIcon icon="material-symbols:database-outline" />
+              <span v-if="item.embedStore != null" class="!text-xs">
+                {{ item.embedStore.name }}
+              </span>
             </div>
             <div class="flex items-center">
               <n-popselect

+ 23 - 3
langchat-ui/src/views/aigc/model/components/embedding/columns.ts

@@ -14,6 +14,9 @@
  * limitations under the License.
  */
 
+import { h } from 'vue';
+import { NTag } from 'naive-ui';
+
 export const baseColumns = [
   {
     title: '模型别名',
@@ -24,9 +27,26 @@ export const baseColumns = [
     key: 'model',
   },
   {
-    title: '回复上限',
-    key: 'responseLimit',
-    width: '120',
+    title: '向量纬度',
+    key: 'dimension',
+    align: 'center',
+    width: '100',
+    render(row) {
+      return h(
+        NTag,
+        {
+          size: 'small',
+          type: 'error',
+        },
+        {
+          default: () => row.dimension,
+        }
+      );
+    },
+  },
+  {
+    title: 'Base Url',
+    key: 'baseUrl',
   },
 ];
 

+ 2 - 2
langchat-ui/src/views/aigc/model/components/embedding/index.vue

@@ -58,13 +58,13 @@
 
   const columns = computed(() => {
     nextTick();
-    return getColumns(provider.value);
+    return getColumns();
   });
   const loadDataTable = async (params: any) => {
     if (provider.value === '') {
       provider.value = LLMProviders[0].model;
     }
-    return await getModels({ ...params, provider: provider.value, type: ModelTypeEnum.TEXT_IMAGE });
+    return await getModels({ ...params, provider: provider.value, type: ModelTypeEnum.EMBEDDING });
   };
   async function handleAdd() {
     editRef.value.show({ provider: provider.value });

+ 21 - 0
langchat-ui/src/views/aigc/model/components/embedding/schemas.ts

@@ -59,6 +59,27 @@ const baseSchemas: FormSchema[] = [
       placeholder: '请输入ApiKey',
     },
   },
+  {
+    field: 'dimension',
+    label: '向量纬度',
+    component: 'NSelect',
+    defaultValue: 1024,
+    labelMessage: '慎重修改此参数,纬度高会消耗更多的算力,但纬度高并不代表搜索更精确',
+    componentProps: {
+      placeholder: '请输入向量纬度',
+      options: [
+        {
+          label: '1024',
+          value: 1024,
+        },
+        {
+          label: '1536',
+          value: 1536,
+        },
+      ],
+    },
+    rules: [{ type: 'number', required: true, message: '请输入向量纬度', trigger: ['blur'] }],
+  },
 ];
 
 export function getSchemas(provider: string) {