Forráskód Böngészése

add chat sources info for knowledges

tycoding 1 éve
szülő
commit
eec359ea63

+ 0 - 1
langchat-client/src/main/java/cn/tycoding/langchat/client/service/impl/ClientChatServiceImpl.java

@@ -53,7 +53,6 @@ public class ClientChatServiceImpl implements ClientChatService {
         long startTime = System.currentTimeMillis();
         StringBuilder text = new StringBuilder();
 
-        // save user message
         req.setRole(RoleEnum.USER.getName());
         saveMessage(req, 0, 0);
 

+ 12 - 0
langchat-common/src/main/java/cn/tycoding/langchat/common/dto/ChatRes.java

@@ -19,6 +19,9 @@ package cn.tycoding.langchat.common.dto;
 import lombok.Data;
 import lombok.experimental.Accessors;
 
+import java.util.HashMap;
+import java.util.Map;
+
 /**
  * @author tycoding
  * @since 2024/1/29
@@ -35,6 +38,8 @@ public class ChatRes {
 
     private long time;
 
+    private Map<String, Object> metadata = new HashMap<>();
+
     public ChatRes(String message) {
         this.message = message;
     }
@@ -44,4 +49,11 @@ public class ChatRes {
         this.usedToken = usedToken;
         this.time = System.currentTimeMillis() - startTime;
     }
+
+    public ChatRes(Integer usedToken, long startTime, Map<String, Object> metadata) {
+        this.isDone = true;
+        this.usedToken = usedToken;
+        this.time = System.currentTimeMillis() - startTime;
+        this.metadata = metadata;
+    }
 }

+ 210 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/EmbeddingStoreContentRetrieverCustom.java

@@ -0,0 +1,210 @@
+/*
+ * Copyright (c) 2024 LangChat. TyCoding All Rights Reserved.
+ *
+ * Licensed under the GNU Affero General Public License, Version 3 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.gnu.org/licenses/agpl-3.0.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package cn.tycoding.langchat.core.service.impl;
+
+import dev.langchain4j.data.document.Metadata;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.rag.content.Content;
+import dev.langchain4j.rag.content.retriever.ContentRetriever;
+import dev.langchain4j.rag.query.Query;
+import dev.langchain4j.spi.model.embedding.EmbeddingModelFactory;
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingSearchResult;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.filter.Filter;
+import lombok.Builder;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.ValidationUtils.*;
+import static dev.langchain4j.spi.ServiceHelper.loadFactories;
+import static java.util.stream.Collectors.toList;
+
+/**
+ * @author tycoding
+ * @since 2024/8/16
+ */
+public class EmbeddingStoreContentRetrieverCustom implements ContentRetriever {
+
+
+    public static final Function<Query, Integer> DEFAULT_MAX_RESULTS = (query) -> 3;
+    public static final Function<Query, Double> DEFAULT_MIN_SCORE = (query) -> 0.0;
+    public static final Function<Query, Filter> DEFAULT_FILTER = (query) -> null;
+
+    public static final String DEFAULT_DISPLAY_NAME = "Default";
+    private static final Map<Object, List<EmbeddingMatch<TextSegment>>> sourceMap = new HashMap<>();
+    private final EmbeddingStore<TextSegment> embeddingStore;
+    private final EmbeddingModel embeddingModel;
+    private final Function<Query, Integer> maxResultsProvider;
+    private final Function<Query, Double> minScoreProvider;
+    private final Function<Query, Filter> filterProvider;
+    private final String displayName;
+    private final Object memoryId;
+
+    public EmbeddingStoreContentRetrieverCustom(Object memoryId,
+                                                EmbeddingStore<TextSegment> embeddingStore,
+                                                EmbeddingModel embeddingModel) {
+        this(
+                memoryId,
+                DEFAULT_DISPLAY_NAME,
+                embeddingStore,
+                embeddingModel,
+                DEFAULT_MAX_RESULTS,
+                DEFAULT_MIN_SCORE,
+                DEFAULT_FILTER
+        );
+    }
+
+    public EmbeddingStoreContentRetrieverCustom(Object memoryId,
+                                                EmbeddingStore<TextSegment> embeddingStore,
+                                                EmbeddingModel embeddingModel,
+                                                int maxResults) {
+        this(
+                memoryId,
+                DEFAULT_DISPLAY_NAME,
+                embeddingStore,
+                embeddingModel,
+                (query) -> maxResults,
+                DEFAULT_MIN_SCORE,
+                DEFAULT_FILTER
+        );
+    }
+
+    public EmbeddingStoreContentRetrieverCustom(Object memoryId,
+                                                EmbeddingStore<TextSegment> embeddingStore,
+                                                EmbeddingModel embeddingModel,
+                                                Integer maxResults,
+                                                Double minScore) {
+        this(
+                memoryId,
+                DEFAULT_DISPLAY_NAME,
+                embeddingStore,
+                embeddingModel,
+                (query) -> maxResults,
+                (query) -> minScore,
+                DEFAULT_FILTER
+        );
+    }
+
+    @Builder
+    private EmbeddingStoreContentRetrieverCustom(Object memoryId,
+                                                 String displayName,
+                                                 EmbeddingStore<TextSegment> embeddingStore,
+                                                 EmbeddingModel embeddingModel,
+                                                 Function<Query, Integer> dynamicMaxResults,
+                                                 Function<Query, Double> dynamicMinScore,
+                                                 Function<Query, Filter> dynamicFilter) {
+        this.memoryId = memoryId;
+        this.displayName = getOrDefault(displayName, DEFAULT_DISPLAY_NAME);
+        this.embeddingStore = ensureNotNull(embeddingStore, "embeddingStore");
+        this.embeddingModel = ensureNotNull(
+                getOrDefault(embeddingModel, EmbeddingStoreContentRetrieverCustom::loadEmbeddingModel),
+                "embeddingModel"
+        );
+        this.maxResultsProvider = getOrDefault(dynamicMaxResults, DEFAULT_MAX_RESULTS);
+        this.minScoreProvider = getOrDefault(dynamicMinScore, DEFAULT_MIN_SCORE);
+        this.filterProvider = getOrDefault(dynamicFilter, DEFAULT_FILTER);
+    }
+
+    private static EmbeddingModel loadEmbeddingModel() {
+        Collection<EmbeddingModelFactory> factories = loadFactories(EmbeddingModelFactory.class);
+        if (factories.size() > 1) {
+            throw new RuntimeException("Conflict: multiple embedding models have been found in the classpath. " +
+                    "Please explicitly specify the one you wish to use.");
+        }
+
+        for (EmbeddingModelFactory factory : factories) {
+            return factory.create();
+        }
+
+        return null;
+    }
+
+    /**
+     * Creates an instance of an {@code EmbeddingStoreContentRetrieverCustom} from the specified {@link EmbeddingStore}
+     * and {@link EmbeddingModel} found through SPI (see {@link EmbeddingModelFactory}).
+     */
+    public static EmbeddingStoreContentRetrieverCustom from(EmbeddingStore<TextSegment> embeddingStore) {
+        return builder().embeddingStore(embeddingStore).build();
+    }
+
+    public static Metadata getMetadata(String memoryId) {
+        List<EmbeddingMatch<TextSegment>> sources = sourceMap.get(memoryId);
+        if (sources == null || sources.isEmpty()) {
+            return null;
+        }
+        return sources.stream().findFirst().get().embedded().metadata();
+    }
+
+    @Override
+    public List<Content> retrieve(Query query) {
+
+        Embedding embeddedQuery = embeddingModel.embed(query.text()).content();
+
+        EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
+                .queryEmbedding(embeddedQuery)
+                .maxResults(maxResultsProvider.apply(query))
+                .minScore(minScoreProvider.apply(query))
+                .filter(filterProvider.apply(query))
+                .build();
+
+        EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);
+        sourceMap.put(memoryId, searchResult.matches());
+        return searchResult.matches().stream()
+                .map(EmbeddingMatch::embedded)
+                .map(Content::from)
+                .collect(toList());
+    }
+
+    @Override
+    public String toString() {
+        return "EmbeddingStoreContentRetrieverCustom{" +
+                "displayName='" + displayName + '\'' +
+                '}';
+    }
+
+    public static class EmbeddingStoreContentRetrieverCustomBuilder {
+        public EmbeddingStoreContentRetrieverCustomBuilder maxResults(Integer maxResults) {
+            if (maxResults != null) {
+                dynamicMaxResults = (query) -> ensureGreaterThanZero(maxResults, "maxResults");
+            }
+            return this;
+        }
+
+        public EmbeddingStoreContentRetrieverCustomBuilder minScore(Double minScore) {
+            if (minScore != null) {
+                dynamicMinScore = (query) -> ensureBetween(minScore, 0, 1, "minScore");
+            }
+            return this;
+        }
+
+        public EmbeddingStoreContentRetrieverCustomBuilder filter(Filter filter) {
+            if (filter != null) {
+                dynamicFilter = (query) -> filter;
+            }
+            return this;
+        }
+    }
+}

+ 3 - 3
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangChatServiceImpl.java

@@ -35,7 +35,6 @@ import dev.langchain4j.model.output.Response;
 import dev.langchain4j.rag.DefaultRetrievalAugmentor;
 import dev.langchain4j.rag.RetrievalAugmentor;
 import dev.langchain4j.rag.content.retriever.ContentRetriever;
-import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
 import dev.langchain4j.rag.content.retriever.WebSearchContentRetriever;
 import dev.langchain4j.rag.query.Query;
 import dev.langchain4j.rag.query.router.DefaultQueryRouter;
@@ -104,9 +103,10 @@ public class LangChatServiceImpl implements LangChatService {
             req.getKnowledgeIds().add(req.getKnowledgeId());
         }
 
-        if (StrUtil.isNotBlank(req.getKnowledgeId())) {
+        if (req.getKnowledgeIds() != null && !req.getKnowledgeIds().isEmpty()) {
             Function<Query, Filter> filter = (query) -> metadataKey(KNOWLEDGE).isIn(req.getKnowledgeIds());
-            ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
+            ContentRetriever contentRetriever = EmbeddingStoreContentRetrieverCustom.builder()
+                    .memoryId(req.getConversationId())
                     .embeddingStore(embeddingStore)
                     .embeddingModel(embeddingModel)
                     .dynamicFilter(filter)

+ 12 - 10
langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/ChatServiceImpl.java

@@ -24,11 +24,12 @@ import cn.tycoding.langchat.biz.service.AigcMessageService;
 import cn.tycoding.langchat.common.constant.RoleEnum;
 import cn.tycoding.langchat.common.dto.ChatReq;
 import cn.tycoding.langchat.common.dto.ChatRes;
-import cn.tycoding.langchat.common.exception.ServiceException;
 import cn.tycoding.langchat.common.utils.ServletUtil;
 import cn.tycoding.langchat.common.utils.StreamEmitter;
 import cn.tycoding.langchat.core.service.LangChatService;
+import cn.tycoding.langchat.core.service.impl.EmbeddingStoreContentRetrieverCustom;
 import cn.tycoding.langchat.server.service.ChatService;
+import dev.langchain4j.data.document.Metadata;
 import dev.langchain4j.model.output.TokenUsage;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
@@ -56,16 +57,13 @@ public class ChatServiceImpl implements ChatService {
 
         if (StrUtil.isNotBlank(req.getAppId())) {
             AigcApp app = appStore.get(req.getAppId());
-            if (app == null) {
-                throw new ServiceException("没有配置应用信息");
+            if (app != null) {
+                req.setModelId(app.getModelId());
+                req.setPromptText(app.getPrompt());
+                req.setKnowledgeIds(app.getKnowledgeIds());
             }
-            req.setModelId(app.getModelId());
-            req.setPromptText(app.getPrompt());
-            req.setKnowledgeIds(app.getKnowledgeIds());
         }
-//        req.setPrompt(PromptUtil.build(req.getMessage(), req.getPromptText()));
 
-        // save user message
         req.setRole(RoleEnum.USER.getName());
         saveMessage(req, 0, 0);
 
@@ -77,10 +75,14 @@ public class ChatServiceImpl implements ChatService {
                     })
                     .onComplete((e) -> {
                         TokenUsage tokenUsage = e.tokenUsage();
-                        emitter.send(new ChatRes(tokenUsage.totalTokenCount(), startTime));
+                        Metadata metadata = EmbeddingStoreContentRetrieverCustom.getMetadata(req.getConversationId());
+                        ChatRes res = new ChatRes(tokenUsage.totalTokenCount(), startTime);
+                        if (metadata != null) {
+                            res.setMetadata(metadata.toMap());
+                        }
+                        emitter.send(res);
                         emitter.complete();
 
-                        // save message
                         if (req.getConversationId() != null) {
                             req.setMessage(text.toString());
                             req.setRole(RoleEnum.ASSISTANT.getName());

+ 38 - 16
langchat-ui/src/views/chat/Chat.vue

@@ -73,6 +73,7 @@
       return;
     }
     controller = new AbortController();
+    chatStore.metadata = null;
 
     // user
     chatId.value = uuidv4();
@@ -121,8 +122,11 @@
               return;
             }
 
-            const { done, message } = JSON.parse(i.substring(5, i.length));
+            const { done, message, metadata } = JSON.parse(i.substring(5, i.length));
             if (done || message === null) {
+              if (metadata != null && metadata != {}) {
+                chatStore.metadata = metadata;
+              }
               return;
             }
             text += message;
@@ -191,6 +195,7 @@
           <Message
             v-for="(item, index) of dataSources"
             :key="index"
+            :class="dataSources.length - 1 == index ? '!mb-2' : ''"
             :date-time="item.createTime"
             :error="item.isError"
             :inversion="item.role !== 'assistant'"
@@ -198,13 +203,22 @@
             :text="item.message"
             @delete="handleDelete(item)"
           />
-          <div class="sticky bottom-0 left-0 flex justify-center">
-            <NButton v-if="loading" type="warning" @click="handleStop">
-              <template #icon>
-                <SvgIcon icon="ri:stop-circle-line" />
-              </template>
-              停止响应
-            </NButton>
+          <div v-if="chatStore.metadata != null && chatStore.metadata.length != 0" class="w-fit">
+            <div
+              class="bg-[#f4f6f8] rounded-lg p-2 flex px-4 text-[12px] items-center gap-1 text-gray-500"
+            >
+              <SvgIcon class="text-blue-500 text-[14px]" icon="mingcute:document-2-fill" />
+              <span>引用知识库:</span>
+              <div class="flex items-center gap-2">
+                <span
+                  v-for="meta in chatStore.metadata"
+                  :key="meta"
+                  class="hover:bg-gray-200 cursor-pointer rounded-lg"
+                >
+                  {{ meta.docsName }}
+                </span>
+              </div>
+            </div>
           </div>
         </div>
       </div>
@@ -217,18 +231,30 @@
             ref="inputRef"
             v-model:value="message"
             :autosize="{ minRows: 1, maxRows: isMobile ? 1 : 4 }"
-            class="!rounded-full px-2 py-1"
+            class="!rounded-full px-2 py-1 custom-input"
             placeholder="今天想聊些什么~"
             size="large"
             type="textarea"
             @keypress="handleEnter"
           >
             <template #suffix>
-              <n-button :loading="loading" text @click="handleSubmit">
+              <n-button
+                v-if="!loading"
+                class="!cursor-pointer"
+                size="large"
+                text
+                @click="handleSubmit"
+              >
                 <template #icon>
-                  <n-icon :component="SparklesOutline" />
+                  <n-icon :component="SparklesOutline" class="!cursor-pointer" />
                 </template>
               </n-button>
+              <div v-if="loading" class="!cursor-pointer" @click="handleStop">
+                <SvgIcon
+                  class="!text-3xl hover:text-gray-500 !cursor-pointer"
+                  icon="ri:stop-circle-line"
+                />
+              </div>
             </template>
           </n-input>
         </div>
@@ -240,11 +266,7 @@
 <style lang="less" scoped>
   ::v-deep(.custom-input) {
     .n-input-wrapper {
-      padding-right: 10px;
-    }
-    .n-input__suffix {
-      align-items: end;
-      padding-bottom: 6px;
+      padding-right: 6px !important;
     }
   }
 </style>

+ 12 - 12
langchat-ui/src/views/chat/message/Message.vue

@@ -14,7 +14,7 @@
   - limitations under the License.
   -->
 
-<script setup lang="ts">
+<script lang="ts" setup>
   import { computed, ref } from 'vue';
   import { useMessage } from 'naive-ui';
   import TextComponent from './TextComponent.vue';
@@ -103,33 +103,33 @@
 <template>
   <div
     ref="messageRef"
-    class="flex w-full mb-6 overflow-hidden"
     :class="[{ 'flex-row-reverse': inversion }]"
+    class="flex w-full mb-6 overflow-hidden"
   >
     <div
-      class="flex items-center justify-center flex-shrink-0 h-8 overflow-hidden rounded-full basis-8"
       :class="[inversion ? 'ml-2' : 'mr-2']"
+      class="flex items-center justify-center flex-shrink-0 h-8 overflow-hidden rounded-full basis-8"
     >
       <AvatarComponent :image="inversion" />
     </div>
-    <div class="overflow-hidden text-sm" :class="[inversion ? 'items-end' : 'items-start']">
-      <p class="text-xs text-[#b4bbc4]" :class="[inversion ? 'text-right' : 'text-left']">
+    <div :class="[inversion ? 'items-end' : 'items-start']" class="overflow-hidden text-sm">
+      <p :class="[inversion ? 'text-right' : 'text-left']" class="text-xs text-[#b4bbc4]">
         {{ dateTime }}
       </p>
-      <div class="flex items-end gap-1 mt-2" :class="[inversion ? 'flex-row-reverse' : 'flex-row']">
+      <div :class="[inversion ? 'flex-row-reverse' : 'flex-row']" class="flex items-end gap-1 mt-2">
         <TextComponent
           ref="textRef"
-          :inversion="inversion"
+          :as-raw-text="asRawText"
           :error="error"
-          :text="text"
+          :inversion="inversion"
           :loading="loading"
-          :as-raw-text="asRawText"
+          :text="text"
         />
         <div class="flex flex-col">
           <NDropdown
-            :trigger="isMobile ? 'click' : 'hover'"
-            :placement="!inversion ? 'right' : 'left'"
             :options="options"
+            :placement="!inversion ? 'right' : 'left'"
+            :trigger="isMobile ? 'click' : 'hover'"
             @select="handleSelect"
           >
             <button class="transition text-neutral-300 hover:text-neutral-800">
@@ -142,4 +142,4 @@
   </div>
 </template>
 
-<style scoped lang="less"></style>
+<style lang="less" scoped></style>

+ 1 - 0
langchat-ui/src/views/chat/store/chat.d.ts

@@ -21,5 +21,6 @@ export interface ChatState {
   modelProvider: string | null;
   conversationId: string | null;
   appId: any;
+  metadata: any;
   isGoogleSearch: boolean;
 }

+ 1 - 0
langchat-ui/src/views/chat/store/useChatStore.ts

@@ -26,6 +26,7 @@ export const useChatStore = defineStore('chat-store', {
       modelProvider: '',
       conversationId: null,
       messages: [],
+      metadata: null,
       appId: null,
       isGoogleSearch: false,
     },