浏览代码

change AbortController for streaming request

tycoding 1 年之前
父节点
当前提交
6d5b4c2dc0

+ 9 - 0
langchat-auth/src/main/java/cn/tycoding/langchat/auth/service/GlobalExceptionTranslator.java

@@ -34,6 +34,7 @@ import org.springframework.web.bind.annotation.RestControllerAdvice;
 import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException;
 import org.springframework.web.servlet.resource.NoResourceFoundException;
 
+import java.io.IOException;
 import java.nio.file.AccessDeniedException;
 
 /**
@@ -127,6 +128,14 @@ public class GlobalExceptionTranslator {
     @ExceptionHandler({NoResourceFoundException.class})
     @ResponseStatus(HttpStatus.BAD_REQUEST)
     public R handleError(NoResourceFoundException e) {
+        e.printStackTrace();
+        return R.fail(HttpStatus.UNAUTHORIZED);
+    }
+
+    @ExceptionHandler({IOException.class})
+    @ResponseStatus(HttpStatus.BAD_REQUEST)
+    public R handleError(IOException e) {
+        e.printStackTrace();
         return R.fail(HttpStatus.UNAUTHORIZED);
     }
 }

+ 2 - 0
langchat-common/src/main/java/cn/tycoding/langchat/common/dto/ChatReq.java

@@ -23,6 +23,7 @@ import lombok.experimental.Accessors;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.Executor;
 
 /**
  * @author tycoding
@@ -63,4 +64,5 @@ public class ChatReq {
     private Prompt prompt;
 
     private StreamEmitter emitter;
+    private Executor executor;
 }

+ 40 - 11
langchat-common/src/main/java/cn/tycoding/langchat/common/utils/StreamEmitter.java

@@ -18,6 +18,9 @@ package cn.tycoding.langchat.common.utils;
 
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+
 /**
  * @author tycoding
  * @since 2024/1/30
@@ -27,38 +30,64 @@ public class StreamEmitter {
     private final SseEmitter emitter;
 
     public StreamEmitter() {
-        emitter = new SseEmitter(600 * 1000L);
+        emitter = new SseEmitter(5 * 60 * 1000L);
     }
 
-    public StreamEmitter(Long timeout) {
-        emitter = new SseEmitter(timeout);
+    public SseEmitter get() {
+        return emitter;
     }
 
-    public SseEmitter get() {
+    public SseEmitter streaming(final ExecutorService executor, Runnable func) {
+//        ExecutorService executor = Executors.newSingleThreadExecutor();
+
+        emitter.onCompletion(() -> {
+            System.out.println("SseEmitter 完成");
+            executor.shutdownNow();
+        });
+
+        emitter.onError((e) -> {
+            System.out.println("SseEmitter 出现错误: " + e.getMessage());
+            executor.shutdownNow();
+        });
+
+        emitter.onTimeout(() -> {
+            System.out.println("SseEmitter 超时");
+            emitter.complete();
+            executor.shutdownNow();
+        });
+        executor.execute(() -> {
+            try {
+                func.run();
+            } catch (Exception e) {
+                System.out.println("捕获到异常: " + e.getMessage());
+                emitter.completeWithError(e);
+                Thread.currentThread().interrupt();
+            } finally {
+                if (!executor.isShutdown()) {
+                    executor.shutdownNow();
+                }
+            }
+        });
         return emitter;
     }
 
     public void send(Object obj) {
         try {
             emitter.send(obj);
-        } catch (Exception e) {
+        } catch (IOException e) {
             throw new RuntimeException(e.getMessage());
         }
     }
 
     public void complete() {
-        try {
-            emitter.complete();
-        } catch (Exception e) {
-            throw new RuntimeException(e.getMessage());
-        }
+        emitter.complete();
     }
 
     public void error(String message) {
         try {
             emitter.send("Error: " + message);
             emitter.complete();
-        } catch (Exception e) {
+        } catch (IOException e) {
             throw new RuntimeException(e.getMessage());
         }
     }

+ 8 - 1
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangChatServiceImpl.java

@@ -33,8 +33,10 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.image.ImageModel;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.rag.DefaultRetrievalAugmentor;
 import dev.langchain4j.rag.content.retriever.ContentRetriever;
 import dev.langchain4j.rag.query.Query;
+import dev.langchain4j.rag.query.router.DefaultQueryRouter;
 import dev.langchain4j.service.AiServices;
 import dev.langchain4j.service.TokenStream;
 import dev.langchain4j.store.embedding.filter.Filter;
@@ -102,8 +104,13 @@ public class LangChatServiceImpl implements LangChatService {
                     .dynamicFilter(filter)
                     .build();
             aiServices.contentRetriever(contentRetriever);
+            aiServices.retrievalAugmentor(DefaultRetrievalAugmentor
+                    .builder()
+                    .contentRetriever(contentRetriever)
+                    .queryRouter(new DefaultQueryRouter())
+                    .executor(req.getExecutor())
+                    .build());
         }
-
         Agent agent = aiServices.build();
         return agent.stream(req.getConversationId(), req.getMessage());
     }

+ 10 - 4
langchat-server/src/main/java/cn/tycoding/langchat/server/endpoint/ChatEndpoint.java

@@ -35,16 +35,20 @@ import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.data.message.SystemMessage;
 import dev.langchain4j.data.message.UserMessage;
 import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 /**
  * @author tycoding
  * @since 2024/1/30
  */
+@Slf4j
 @RequestMapping("/aigc")
 @RestController
 @AllArgsConstructor
@@ -60,11 +64,13 @@ public class ChatEndpoint {
     public SseEmitter chat(@RequestBody ChatReq req) {
         StreamEmitter emitter = new StreamEmitter();
         req.setEmitter(emitter);
-        req.setUserId(String.valueOf(AuthUtil.getUserId()));
+        req.setUserId(AuthUtil.getUserId());
         req.setUsername(AuthUtil.getUsername());
-
-        chatService.chat(req);
-        return emitter.get();
+        ExecutorService executor = Executors.newSingleThreadExecutor();
+        req.setExecutor(executor);
+        return emitter.streaming(executor, () -> {
+            chatService.chat(req);
+        });
     }
 
     @GetMapping("/app/info")

+ 2 - 1
langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/ChatServiceImpl.java

@@ -70,7 +70,8 @@ public class ChatServiceImpl implements ChatService {
         saveMessage(req, 0, 0);
 
         try {
-            langChatService.chat(req)
+            langChatService
+                    .chat(req)
                     .onNext(e -> {
                         text.append(e);
                         emitter.send(new ChatRes(e));

+ 6 - 1
langchat-ui/src/api/aigc/chat.ts

@@ -17,12 +17,17 @@
 import { http } from '@/utils/http/axios';
 import { AxiosProgressEvent } from 'axios';
 
-export function chat(data: any, onDownloadProgress?: (progressEvent: AxiosProgressEvent) => void) {
+export function chat(
+  data: any,
+  controller: AbortController,
+  onDownloadProgress?: (progressEvent: AxiosProgressEvent) => void
+) {
   return http.request(
     {
       method: 'post',
       url: '/aigc/chat/completions',
       data,
+      signal: controller.signal,
       onDownloadProgress: onDownloadProgress,
     },
     {

+ 2 - 1
langchat-ui/src/utils/http/axios/index.ts

@@ -247,7 +247,8 @@ const transform: AxiosTransform = {
         return Promise.reject(error);
       }
     } catch (error) {
-      throw new Error(error as any);
+      console.log(error);
+      // throw new Error(error as any);
     }
     // 请求是否被取消
     const isCancel = axios.isCancel(error);

+ 3 - 0
langchat-ui/src/views/chat/Chat.vue

@@ -104,7 +104,9 @@
           modelName: chatStore.modelName,
           modelProvider: chatStore.modelProvider,
         },
+        controller,
         async ({ event }) => {
+          console.log('消息:', event);
           const list = event.target.responseText.split('\n\n');
 
           let text = '';
@@ -161,6 +163,7 @@
   function handleStop() {
     if (loading.value) {
       controller.abort();
+      controller = new AbortController();
       loading.value = false;
     }
   }

+ 1 - 0
langchat-upms/src/main/java/cn/tycoding/langchat/upms/utils/AuthUtil.java

@@ -87,6 +87,7 @@ public class AuthUtil {
         try {
             return (UserInfo) StpUtil.getSession().get(CacheConst.AUTH_USER_INFO_KEY);
         } catch (Exception e) {
+            e.printStackTrace();
             throw new AuthException(403, "登录已失效,请重新登陆");
         }
     }