tycoding 11 місяців тому
батько
коміт
679bcdf664

+ 6 - 0
langchat-biz/src/main/java/cn/tycoding/langchat/biz/service/impl/AigcKnowledgeServiceImpl.java

@@ -29,6 +29,7 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 import java.util.Date;
 import java.util.List;
@@ -46,17 +47,20 @@ public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, A
     private final AigcDocsSliceMapper aigcDocsSliceMapper;
 
     @Override
+    @Transactional
     public void addDocs(AigcDocs data) {
         data.setCreateTime(new Date());
         aigcDocsMapper.insert(data);
     }
 
     @Override
+    @Transactional
     public void updateDocs(AigcDocs data) {
         aigcDocsMapper.updateById(data);
     }
 
     @Override
+    @Transactional
     public void addDocsSlice(AigcDocsSlice data) {
         data.setCreateTime(new Date())
                 .setWordNum(data.getContent().length())
@@ -66,6 +70,7 @@ public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, A
     }
 
     @Override
+    @Transactional
     public void updateDocsSlice(AigcDocsSlice data) {
         aigcDocsSliceMapper.updateById(data);
     }
@@ -84,6 +89,7 @@ public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, A
     }
 
     @Override
+    @Transactional
     public void removeSlicesOfDoc(String docsId) {
         LambdaQueryWrapper<AigcDocsSlice> deleteWrapper = Wrappers.<AigcDocsSlice>lambdaQuery()
                 .eq(AigcDocsSlice::getDocsId, docsId);

+ 6 - 0
langchat-biz/src/main/java/cn/tycoding/langchat/biz/service/impl/AigcMessageServiceImpl.java

@@ -31,6 +31,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import lombok.RequiredArgsConstructor;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 import java.util.Date;
 import java.util.List;
@@ -89,6 +90,7 @@ public class AigcMessageServiceImpl extends ServiceImpl<AigcMessageMapper, AigcM
     }
 
     @Override
+    @Transactional
     public AigcConversation addConversation(AigcConversation conversation) {
         conversation.setCreateTime(new Date());
         aigcConversationMapper.insert(conversation);
@@ -96,6 +98,7 @@ public class AigcMessageServiceImpl extends ServiceImpl<AigcMessageMapper, AigcM
     }
 
     @Override
+    @Transactional
     public void updateConversation(AigcConversation conversation) {
         aigcConversationMapper.updateById(
                 new AigcConversation().setId(conversation.getId())
@@ -103,6 +106,7 @@ public class AigcMessageServiceImpl extends ServiceImpl<AigcMessageMapper, AigcM
     }
 
     @Override
+    @Transactional
     public void delConversation(String conversationId) {
         aigcConversationMapper.deleteById(conversationId);
         baseMapper.delete(
@@ -111,6 +115,7 @@ public class AigcMessageServiceImpl extends ServiceImpl<AigcMessageMapper, AigcM
     }
 
     @Override
+    @Transactional
     public AigcMessage addMessage(AigcMessage message) {
         message.setCreateTime(new Date());
         baseMapper.insert(message);
@@ -118,6 +123,7 @@ public class AigcMessageServiceImpl extends ServiceImpl<AigcMessageMapper, AigcM
     }
 
     @Override
+    @Transactional
     public void clearMessage(String conversationId) {
         baseMapper.delete(
                 Wrappers.<AigcMessage>lambdaQuery()

+ 1 - 0
langchat-client/src/main/java/cn/tycoding/langchat/client/controller/ClientChatEndpoint.java

@@ -81,6 +81,7 @@ public class ClientChatEndpoint {
         AigcOss oss = aigcOssService.upload(file, ClientAuthUtil.getUserId());
         clientEmbeddingService.embedDocs(
                 new ChatReq()
+                        .setUserId(ClientAuthUtil.getUserId())
                         .setDocsName(oss.getOriginalFilename())
                         .setKnowledgeId(oss.getId())
                         .setUrl(oss.getUrl()));

+ 7 - 4
langchat-client/src/main/java/cn/tycoding/langchat/client/service/impl/ClientEmbeddingServiceImpl.java

@@ -18,14 +18,16 @@ package cn.tycoding.langchat.client.service.impl;
 
 import cn.tycoding.langchat.client.service.ClientEmbeddingService;
 import cn.tycoding.langchat.common.dto.ChatReq;
+import cn.tycoding.langchat.common.task.TaskManager;
 import cn.tycoding.langchat.core.service.LangEmbeddingService;
 import dev.langchain4j.store.embedding.filter.Filter;
 import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 
+import java.util.concurrent.Executors;
+
 import static cn.tycoding.langchat.core.consts.EmbedConst.KNOWLEDGE;
 import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
 
@@ -38,13 +40,14 @@ import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metad
 @AllArgsConstructor
 public class ClientEmbeddingServiceImpl implements ClientEmbeddingService {
 
-    private final LangEmbeddingService langEmbeddingService;
+    private final LangEmbeddingService embeddingService;
     private final PgVectorEmbeddingStore embeddingStore;
 
-    @Async
     @Override
     public void embedDocs(ChatReq data) {
-        langEmbeddingService.embeddingDocs(data);
+        TaskManager.submitTask(data.getUserId(), Executors.callable(() -> {
+            embeddingService.embeddingDocs(data);
+        }));
     }
 
     @Override

+ 6 - 4
langchat-server/src/main/java/cn/tycoding/langchat/server/endpoint/EmbeddingEndpoint.java

@@ -98,8 +98,9 @@ public class EmbeddingEndpoint {
                 .setType(EmbedConst.ORIGIN_TYPE_UPLOAD)
                 .setKnowledgeId(knowledgeId);
         aigcKnowledgeService.addDocs(data);
-        TaskManager.submitTask(userId,
-                Executors.callable(() -> embeddingService.embedDocsSlice(data, oss.getUrl())));
+        TaskManager.submitTask(userId, Executors.callable(() -> {
+            embeddingService.embedDocsSlice(data, oss.getUrl());
+        }));
         return R.ok();
     }
 
@@ -116,8 +117,9 @@ public class EmbeddingEndpoint {
         if (EmbedConst.ORIGIN_TYPE_UPLOAD.equals(docs.getType())) {
             // clear before re-embed
             embeddingService.clearDocSlices(docsId);
-            TaskManager.submitTask(userId,
-                    Executors.callable(() -> embeddingService.embedDocsSlice(docs, docs.getUrl())));
+            TaskManager.submitTask(userId, Executors.callable(() -> {
+                embeddingService.embedDocsSlice(docs, docs.getUrl());
+            }));
         }
         return R.ok();
     }