Przeglądaj źródła

feat: clear docSlices before re-embed

(cherry picked from commit f491f6cdffaf03664a02120b7bbf00ce1a763ed3)
muzhix 1 rok temu
rodzic
commit
6ae1957a4f

+ 6 - 0
langchat-biz/src/main/java/cn/tycoding/langchat/biz/service/AigcKnowledgeService.java

@@ -21,6 +21,8 @@ import cn.tycoding.langchat.biz.entity.AigcDocsSlice;
 import cn.tycoding.langchat.biz.entity.AigcKnowledge;
 import com.baomidou.mybatisplus.extension.service.IService;
 
+import java.util.List;
+
 /**
  * @author tycoding
  * @since 2024/4/15
@@ -34,5 +36,9 @@ public interface AigcKnowledgeService extends IService<AigcKnowledge> {
     void addDocsSlice(AigcDocsSlice data);
 
     void updateDocsSlice(AigcDocsSlice data);
+
+    /**
+     * Lists the vector ids recorded on the slices that belong to the given document.
+     *
+     * @param docsId id of the document whose slices are looked up
+     * @return the vector ids of the matching slices; empty when the document has no slices
+     */
+    List<String> listSliceVectorIdsOfDoc(String docsId);
+
+    /**
+     * Deletes every slice record that belongs to the given document.
+     *
+     * @param docsId id of the document whose slices are deleted
+     */
+    void removeSlicesOfDoc(String docsId);
 }
 

+ 26 - 0
langchat-biz/src/main/java/cn/tycoding/langchat/biz/service/impl/AigcKnowledgeServiceImpl.java

@@ -23,16 +23,21 @@ import cn.tycoding.langchat.biz.mapper.AigcDocsMapper;
 import cn.tycoding.langchat.biz.mapper.AigcDocsSliceMapper;
 import cn.tycoding.langchat.biz.mapper.AigcKnowledgeMapper;
 import cn.tycoding.langchat.biz.service.AigcKnowledgeService;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
 
 import java.util.Date;
+import java.util.List;
 
 /**
  * @author tycoding
  * @since 2024/4/15
  */
+@Slf4j
 @Service
 @RequiredArgsConstructor
 public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, AigcKnowledge> implements AigcKnowledgeService {
@@ -64,5 +69,26 @@ public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, A
     public void updateDocsSlice(AigcDocsSlice data) {
         aigcDocsSliceMapper.updateById(data);
     }
+
+    /**
+     * Lists the vector ids recorded on the slices that belong to the given document.
+     * <p>
+     * Only the {@code vectorId} column is selected. Slices without a vector id are left out of
+     * the result, so callers can hand the list to the embedding store without null checks.
+     *
+     * @param docsId id of the document whose slices are looked up
+     * @return the non-null vector ids of the matching slices; empty when there are none
+     */
+    @Override
+    public List<String> listSliceVectorIdsOfDoc(String docsId) {
+        LambdaQueryWrapper<AigcDocsSlice> selectWrapper = Wrappers.<AigcDocsSlice>lambdaQuery()
+                .select(AigcDocsSlice::getVectorId)
+                .eq(AigcDocsSlice::getDocsId, docsId);
+        List<String> vectorIds = aigcDocsSliceMapper.selectList(selectWrapper)
+                .stream()
+                // With a single selected column, a row whose value is NULL may be mapped to a
+                // null entity by MyBatis; drop those rows and any null ids before mapping.
+                .filter(slice -> slice != null && slice.getVectorId() != null)
+                .map(AigcDocsSlice::getVectorId)
+                .toList();
+        log.debug("slices of doc: [{}], count: [{}]", docsId, vectorIds.size());
+        return vectorIds;
+    }
+
+    /**
+     * Deletes every slice record whose {@code docsId} equals the given id and logs, at debug
+     * level, how many rows were removed.
+     *
+     * @param docsId id of the document whose slices are deleted
+     */
+    @Override
+    public void removeSlicesOfDoc(String docsId) {
+        int removed = aigcDocsSliceMapper.delete(
+                Wrappers.<AigcDocsSlice>lambdaQuery().eq(AigcDocsSlice::getDocsId, docsId));
+        log.debug("remove all slices of doc: [{}], count: [{}]", docsId, removed);
+    }
 }
 

+ 2 - 0
langchat-server/src/main/java/cn/tycoding/langchat/server/endpoint/EmbeddingEndpoint.java

@@ -110,6 +110,8 @@ public class EmbeddingEndpoint {
             text(docs);
         }
         if (EmbedConst.ORIGIN_TYPE_UPLOAD.equals(docs.getType())) {
+            // clear before re-embed
+            embeddingService.clearDocSlicesOfDoc(docsId);
             embeddingService.embedDocsSlice(docs, docs.getUrl());
         }
         return R.ok();

+ 8 - 0
langchat-server/src/main/java/cn/tycoding/langchat/server/service/EmbeddingService.java

@@ -27,6 +27,14 @@ import java.util.Map;
  */
 public interface EmbeddingService {
 
+    /**
+     * Deletes the slices that already exist for a document, so the document can be
+     * embedded again from a clean state.
+     *
+     * @param docsId document id
+     */
+    void clearDocSlicesOfDoc(String docsId);
+
     void embedDocsSlice(AigcDocs data, String url);
 
     List<Map<String, Object>> search(AigcDocs data);

+ 14 - 0
langchat-server/src/main/java/cn/tycoding/langchat/server/service/impl/EmbeddingServiceImpl.java

@@ -36,6 +36,7 @@ import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -58,6 +59,19 @@ public class EmbeddingServiceImpl implements EmbeddingService {
     private final AigcKnowledgeService aigcKnowledgeService;
     private final PgVectorEmbeddingStore embeddingStore;
 
+    /**
+     * Removes everything previously produced for a document so it can be embedded again:
+     * first the vectors held by the embedding store, then the slice records themselves.
+     * <p>
+     * A blank {@code docsId} is ignored. When no vector ids are found for the document
+     * (for example, it has never been embedded), the embedding store is not called, because
+     * its {@code removeAll} is expected to reject an empty id collection.
+     *
+     * @param docsId id of the document whose slices should be cleared
+     */
+    @Override
+    @Transactional
+    public void clearDocSlicesOfDoc(String docsId) {
+        if (StrUtil.isBlank(docsId)) {
+            return;
+        }
+        // remove from embedding store; skip the call when there is nothing to remove
+        List<String> vectorIds = aigcKnowledgeService.listSliceVectorIdsOfDoc(docsId);
+        if (vectorIds != null && !vectorIds.isEmpty()) {
+            embeddingStore.removeAll(vectorIds);
+        }
+        // remove from docSlice; always executed, since slice rows may exist without vectors
+        aigcKnowledgeService.removeSlicesOfDoc(docsId);
+    }
+
     @Async
     @Override
     public void embedDocsSlice(AigcDocs data, String url) {