Browse Source

Merge pull request #7 from GBHBY/embedding_thread_pool

fix async analysis doc
TyCoding 11 tháng trước cách đây
mục cha
commit
3e5e6273ce

+ 51 - 0
langchat-common/src/main/java/cn/tycoding/langchat/common/task/TaskManager.java

@@ -0,0 +1,51 @@
+package cn.tycoding.langchat.common.task;
+
+import cn.tycoding.langchat.common.threadpool.AnalysisThreadPool;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Future;
+
+
+/**
+ * Tracks the {@link Future}s of tasks submitted to the shared analysis thread pool,
+ * grouped by a caller-supplied id.
+ *
+ * <p>Thread-safety: a list stored in the map is never mutated after it has been
+ * published. Each submission replaces the list atomically with an extended copy, so
+ * concurrent submissions for the same id cannot lose a future and readers never observe
+ * a list that is being modified.
+ *
+ * @author GB
+ * @since 2024-08-22
+ */
+public class TaskManager {
+    /** Futures of all submitted tasks, keyed by the id they were submitted under. */
+    private static final ConcurrentHashMap<String, List<Future<?>>> TASK_MAP = new ConcurrentHashMap<>();
+
+    /**
+     * Submits a task to the analysis thread pool and records its future under {@code id}.
+     *
+     * <p>Any exception thrown by the pool's {@code submit} call (for example when the pool
+     * rejects the task) propagates to the caller; in that case nothing is recorded.
+     *
+     * @param id       key the future is grouped under; must not be {@code null}
+     * @param function the task to run
+     * @throws NullPointerException if {@code id} is {@code null}
+     */
+    public static void submitTask(String id, Callable<?> function) {
+        // Reject a null id before submitting: the map cannot hold a null key, and failing
+        // after submit() would leave a running task that is tracked nowhere.
+        if (id == null) {
+            throw new NullPointerException("id must not be null");
+        }
+        Future<?> future = AnalysisThreadPool.getThreadPool().submit(function);
+        // compute() runs atomically per key. Building a fresh list instead of mutating the
+        // stored one keeps every published list effectively immutable for readers.
+        TASK_MAP.compute(id, (key, existing) -> {
+            List<Future<?>> updated = new ArrayList<>();
+            if (existing != null) {
+                updated.addAll(existing);
+            }
+            updated.add(future);
+            return updated;
+        });
+    }
+
+    /**
+     * Removes every future recorded under {@code id}.
+     *
+     * <p>The futures are only dropped from the registry; they are not cancelled here.
+     *
+     * @param id key whose futures should be forgotten
+     */
+    public void popTaskResult(String id) {
+        TASK_MAP.remove(id);
+    }
+
+    /**
+     * Returns how many futures are currently recorded under {@code id}.
+     *
+     * <p>The count includes futures that have already completed, because entries are only
+     * removed by {@link #popTaskResult(String)}.
+     *
+     * @param id key to look up
+     * @return the number of recorded futures, or {@code 0} when the id is unknown
+     */
+    public int getCount(String id) {
+        // Single lookup: a separate containsKey() check could race with a concurrent remove.
+        List<Future<?>> futures = TASK_MAP.get(id);
+        return futures != null ? futures.size() : 0;
+    }
+}

+ 66 - 0
langchat-common/src/main/java/cn/tycoding/langchat/common/threadpool/AnalysisThreadPool.java

@@ -0,0 +1,66 @@
+package cn.tycoding.langchat.common.threadpool;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Holder for a single, lazily created {@link ThreadPoolExecutor} shared by all callers.
+ *
+ * <p>Pool sizing is derived from the number of available processors. Tasks wait in a
+ * bounded queue; once the queue is full and the maximum thread count is reached, further
+ * submissions are rejected with a {@code RejectedExecutionException}.
+ *
+ * @author GB
+ * @since 2024-08-22
+ */
+public class AnalysisThreadPool {
+    /** The shared executor; volatile so the lazily created instance is safely published. */
+    private static volatile ThreadPoolExecutor EXECUTOR = null;
+
+    /**
+     * Number of available processors, used to derive the core and maximum pool sizes.
+     */
+    private static final int CPU_COUNT = Runtime.getRuntime().availableProcessors();
+    /**
+     * Core pool size.
+     */
+    private static final int CORE_POOL_SIZE = CPU_COUNT + 1;
+    /**
+     * Maximum pool size.
+     */
+    private static final int MAX_POOL_SIZE = 2 * CPU_COUNT + 1;
+    /**
+     * Capacity of the task queue.
+     */
+    private static final int MAX_LIMIT_JOB_SIZE = 1000;
+    /**
+     * Idle time, in seconds, after which threads above the core size are terminated.
+     */
+    private static final int KEEP_ALIVE_SECONDS = 1;
+
+    /**
+     * Returns the shared executor, creating it on first use.
+     *
+     * @return the process-wide analysis thread pool; the same instance on every call
+     */
+    public static ThreadPoolExecutor getThreadPool() {
+        // Double-checked locking: the unsynchronized check keeps the common path lock-free,
+        // the second check under the class lock guarantees a single instance.
+        if (null == EXECUTOR) {
+            synchronized (AnalysisThreadPool.class) {
+                if (null == EXECUTOR) {
+                    EXECUTOR = new ThreadPoolExecutor(
+                            CORE_POOL_SIZE,
+                            MAX_POOL_SIZE,
+                            KEEP_ALIVE_SECONDS,
+                            // Keep-alive is specified in seconds; a microsecond unit would
+                            // tear down idle non-core threads almost immediately.
+                            TimeUnit.SECONDS,
+                            new LinkedBlockingDeque<>(MAX_LIMIT_JOB_SIZE),
+                            Executors.defaultThreadFactory(),
+                            // Reject with RejectedExecutionException when saturated.
+                            new ThreadPoolExecutor.AbortPolicy()
+                    );
+                }
+            }
+        }
+        return EXECUTOR;
+    }
+
+    /**
+     * Runs the given task on the shared executor.
+     *
+     * @param runable the task to execute
+     */
+    public static void execute(Runnable runable) {
+        getThreadPool().execute(runable);
+    }
+
+}

+ 10 - 5
langchat-server/src/main/java/cn/tycoding/langchat/server/endpoint/EmbeddingEndpoint.java

@@ -27,6 +27,7 @@ import cn.tycoding.langchat.biz.service.AigcOssService;
 import cn.tycoding.langchat.common.dto.ChatReq;
 import cn.tycoding.langchat.common.dto.EmbeddingR;
 import cn.tycoding.langchat.common.exception.ServiceException;
+import cn.tycoding.langchat.common.task.TaskManager;
 import cn.tycoding.langchat.common.utils.R;
 import cn.tycoding.langchat.core.consts.EmbedConst;
 import cn.tycoding.langchat.core.service.LangEmbeddingService;
@@ -37,6 +38,8 @@ import lombok.extern.slf4j.Slf4j;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.multipart.MultipartFile;
 
+import java.util.concurrent.Executors;
+
 /**
  * @author tycoding
  * @since 2024/4/25
@@ -85,7 +88,8 @@ public class EmbeddingEndpoint {
     @PostMapping("/docs/{knowledgeId}")
     @SaCheckPermission("aigc:embedding:docs")
     public R docs(MultipartFile file, @PathVariable String knowledgeId) {
-        AigcOss oss = aigcOssService.upload(file, String.valueOf(AuthUtil.getUserId()));
+        String userId = String.valueOf(AuthUtil.getUserId());
+        AigcOss oss = aigcOssService.upload(file, userId);
         AigcDocs data = new AigcDocs()
                 .setName(oss.getOriginalFilename())
                 .setSliceStatus(false)
@@ -94,14 +98,14 @@ public class EmbeddingEndpoint {
                 .setType(EmbedConst.ORIGIN_TYPE_UPLOAD)
                 .setKnowledgeId(knowledgeId);
         aigcKnowledgeService.addDocs(data);
-
-        // embedding docs
-        embeddingService.embedDocsSlice(data, oss.getUrl());
+        TaskManager.submitTask(userId,
+                Executors.callable(() -> embeddingService.embedDocsSlice(data, oss.getUrl())));
         return R.ok();
     }
 
     @GetMapping("/re-embed/{docsId}")
     public R reEmbed(@PathVariable String docsId) {
+        String userId = String.valueOf(AuthUtil.getUserId());
         AigcDocs docs = aigcDocsMapper.selectById(docsId);
         if (docs == null) {
             throw new ServiceException("没有查询到文档数据");
@@ -110,7 +114,8 @@ public class EmbeddingEndpoint {
             text(docs);
         }
         if (EmbedConst.ORIGIN_TYPE_UPLOAD.equals(docs.getType())) {
-            embeddingService.embedDocsSlice(docs, docs.getUrl());
+            TaskManager.submitTask(userId,
+                    Executors.callable(() -> embeddingService.embedDocsSlice(docs, docs.getUrl())));
         }
         return R.ok();
     }

+ 0 - 2
langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/EmbeddingServiceImpl.java

@@ -34,7 +34,6 @@ import dev.langchain4j.store.embedding.filter.Filter;
 import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 
 import java.util.ArrayList;
@@ -58,7 +57,6 @@ public class EmbeddingServiceImpl implements EmbeddingService {
     private final AigcKnowledgeService aigcKnowledgeService;
     private final PgVectorEmbeddingStore embeddingStore;
 
-    @Async
     @Override
     public void embedDocsSlice(AigcDocs data, String url) {
         List<EmbeddingR> list = langEmbeddingService.embeddingDocs(