Forráskód Böngészése

fix chat memory provider

tycoding 1 éve
szülő
commit
52c49a4bf2

+ 1 - 3
langchat-app/src/main/java/cn/tycoding/langchat/app/endpoint/AppApiChatEndpoint.java

@@ -28,7 +28,6 @@ import cn.tycoding.langchat.app.store.AppChannelStore;
 import cn.tycoding.langchat.app.store.AppStore;
 import cn.tycoding.langchat.common.dto.ChatReq;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.common.utils.PromptUtil;
 import cn.tycoding.langchat.common.utils.StreamEmitter;
 import cn.tycoding.langchat.core.service.LangChatService;
 import lombok.RequiredArgsConstructor;
@@ -65,7 +64,7 @@ public class AppApiChatEndpoint {
 
     private SseEmitter handler(StreamEmitter emitter, String appId, String modelId, List<CompletionReq.Message> messages) {
         if (messages == null || messages.isEmpty() || StrUtil.isBlank(modelId)) {
-            throw new RuntimeException("Message is undefined. Or check the model configuration");
+            throw new RuntimeException("聊天消息为空,或者没有配置模型信息");
         }
         CompletionReq.Message message = messages.get(0);
         ChatReq req = new ChatReq()
@@ -85,7 +84,6 @@ public class AppApiChatEndpoint {
                 req.setKnowledgeIds(app.getKnowledgeIds());
             }
         }
-        req.setPrompt(PromptUtil.build(message.getContent(), req.getPromptText()));
 
         langChatService
                 .singleChat(req)

+ 0 - 8
langchat-client/src/main/java/cn/tycoding/langchat/client/controller/ClientChatEndpoint.java

@@ -18,7 +18,6 @@ package cn.tycoding.langchat.client.controller;
 
 import cn.hutool.core.lang.Dict;
 import cn.hutool.core.lang.UUID;
-import cn.hutool.core.util.StrUtil;
 import cn.tycoding.langchat.biz.entity.AigcOss;
 import cn.tycoding.langchat.biz.service.AigcOssService;
 import cn.tycoding.langchat.biz.utils.ClientAuthUtil;
@@ -60,13 +59,6 @@ public class ClientChatEndpoint {
         req.setEmitter(emitter);
         req.setUserId(ClientAuthUtil.getUserId());
         req.setUsername(ClientAuthUtil.getUsername());
-
-        if (StrUtil.isBlank(req.getPromptText())) {
-            req.setPrompt(PromptUtil.build(req.getMessage()));
-        } else {
-            req.setPrompt(PromptUtil.build(req.getMessage(), req.getPromptText()));
-        }
-
         clientChatService.chat(req);
         return emitter.get();
     }

+ 2 - 4
langchat-client/src/main/java/cn/tycoding/langchat/client/service/impl/ClientChatServiceImpl.java

@@ -72,8 +72,7 @@ public class ClientChatServiceImpl implements ClientChatService {
                         if (req.getConversationId() != null) {
                             req.setMessage(text.toString());
                             req.setRole(RoleEnum.ASSISTANT.getName());
-                            saveMessage(req, tokenUsage.inputTokenCount(),
-                                    tokenUsage.outputTokenCount());
+                            saveMessage(req, tokenUsage.inputTokenCount(), tokenUsage.outputTokenCount());
                         }
                     })
                     .onError((e) -> {
@@ -113,8 +112,7 @@ public class ClientChatServiceImpl implements ClientChatService {
                         if (req.getConversationId() != null) {
                             req.setMessage(text.toString());
                             req.setRole(RoleEnum.ASSISTANT.getName());
-                            saveMessage(req, tokenUsage.inputTokenCount(),
-                                    tokenUsage.outputTokenCount());
+                            saveMessage(req, tokenUsage.inputTokenCount(), tokenUsage.outputTokenCount());
                         }
                     })
                     .onError((e) -> {

+ 21 - 6
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangChatServiceImpl.java

@@ -74,10 +74,14 @@ public class LangChatServiceImpl implements LangChatService {
         if (StrUtil.isBlank(req.getConversationId())) {
             req.setConversationId(IdUtil.simpleUUID());
         }
-
         AiServices<Agent> aiServices = AiServices.builder(Agent.class)
                 .streamingChatLanguageModel(model)
-                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20));
+                .systemMessageProvider(memoryId -> req.getPromptText())
+                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
+                        .id(req.getConversationId())
+                        .chatMemoryStore(new PersistentChatMemoryStore())
+                        .maxMessages(20)
+                        .build());
 
         EmbeddingModel embeddingModel = embeddingProvider.embed();
 
@@ -111,7 +115,7 @@ public class LangChatServiceImpl implements LangChatService {
         }
 
         Agent agent = aiServices.build();
-        return agent.stream(req.getConversationId(), req.getPrompt().text());
+        return agent.stream(req.getConversationId(), req.getMessage());
     }
 
     @Override
@@ -123,8 +127,14 @@ public class LangChatServiceImpl implements LangChatService {
 
         Agent agent = AiServices.builder(Agent.class)
                 .streamingChatLanguageModel(model)
+                .systemMessageProvider(memoryId -> req.getPromptText())
+                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
+                        .id(req.getConversationId())
+                        .chatMemoryStore(new PersistentChatMemoryStore())
+                        .maxMessages(20)
+                        .build())
                 .build();
-        return agent.stream(req.getConversationId(), req.getPrompt().text());
+        return agent.stream(req.getConversationId(), req.getMessage());
     }
 
     @Override
@@ -138,11 +148,16 @@ public class LangChatServiceImpl implements LangChatService {
             StreamingChatLanguageModel model = provider.stream(req.getModelId());
             Agent agent = AiServices.builder(Agent.class)
                     .streamingChatLanguageModel(model)
-                    .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20))
+                    .systemMessageProvider(memoryId -> req.getPromptText())
+                    .chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
+                            .id(req.getConversationId())
+                            .chatMemoryStore(new PersistentChatMemoryStore())
+                            .maxMessages(20)
+                            .build())
                     .build();
 
             StringBuilder text = new StringBuilder();
-            agent.stream(req.getConversationId(), req.getPrompt().text())
+            agent.stream(req.getConversationId(), req.getMessage())
                     .onNext(text::append)
                     .onComplete((t) -> {
                         future.complete(null);

+ 53 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/PersistentChatMemoryStore.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2024 LangChat. TyCoding All Rights Reserved.
+ *
+ * Licensed under the GNU Affero General Public License, Version 3 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.gnu.org/licenses/agpl-3.0.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package cn.tycoding.langchat.core.service.impl;
+
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author tycoding
+ * @since 2024/8/15
+ */
+public class PersistentChatMemoryStore implements ChatMemoryStore {
+
+    private static final Map<Object, List<ChatMessage>> store = new HashMap<>();
+
+    @Override
+    public List<ChatMessage> getMessages(Object memoryId) {
+        List<ChatMessage> list = store.get(memoryId);
+        if (list == null) {
+            return new ArrayList<>();
+        }
+        return list;
+    }
+
+    @Override
+    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
+        store.put(memoryId, messages);
+    }
+
+    @Override
+    public void deleteMessages(Object memoryId) {
+        store.remove(memoryId);
+    }
+}

+ 2 - 4
langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/ChatServiceImpl.java

@@ -25,7 +25,6 @@ import cn.tycoding.langchat.common.constant.RoleEnum;
 import cn.tycoding.langchat.common.dto.ChatReq;
 import cn.tycoding.langchat.common.dto.ChatRes;
 import cn.tycoding.langchat.common.exception.ServiceException;
-import cn.tycoding.langchat.common.utils.PromptUtil;
 import cn.tycoding.langchat.common.utils.ServletUtil;
 import cn.tycoding.langchat.common.utils.StreamEmitter;
 import cn.tycoding.langchat.core.service.LangChatService;
@@ -64,7 +63,7 @@ public class ChatServiceImpl implements ChatService {
             req.setPromptText(app.getPrompt());
             req.setKnowledgeIds(app.getKnowledgeIds());
         }
-        req.setPrompt(PromptUtil.build(req.getMessage(), req.getPromptText()));
+//        req.setPrompt(PromptUtil.build(req.getMessage(), req.getPromptText()));
 
         // save user message
         req.setRole(RoleEnum.USER.getName());
@@ -85,8 +84,7 @@ public class ChatServiceImpl implements ChatService {
                         if (req.getConversationId() != null) {
                             req.setMessage(text.toString());
                             req.setRole(RoleEnum.ASSISTANT.getName());
-                            saveMessage(req, tokenUsage.inputTokenCount(),
-                                    tokenUsage.outputTokenCount());
+                            saveMessage(req, tokenUsage.inputTokenCount(), tokenUsage.outputTokenCount());
                         }
                     })
                     .onError((e) -> {

+ 2 - 2
langchat-ui/src/views/aigc/message/index.vue

@@ -66,7 +66,7 @@
   function handleDelete(record: Recordable) {
     dialog.info({
       title: '提示',
-      content: `您想删除 ${record.title}`,
+      content: `您确定删除这条数据吗`,
       positiveText: '确定',
       negativeText: '取消',
       onPositiveClick: async () => {
@@ -85,7 +85,7 @@
 
 <template>
   <n-card :bordered="false">
-    <n-tabs type="line" animated>
+    <n-tabs animated type="line">
       <n-tab-pane name="1" tab="会话消息列表">
         <div class="mt-2">
           <BasicForm @register="register" @reset="handleReset" @submit="reloadTable" />