|
@@ -1,9 +1,24 @@
|
|
|
package com.pavis.admin.aigc.service.impl;
|
|
|
|
|
|
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
|
|
+import com.pavis.admin.aigc.mapper.DocMapper;
|
|
|
+import com.pavis.admin.aigc.model.entity.DocDO;
|
|
|
+import com.pavis.admin.aigc.model.req.SimilaritySearchReq;
|
|
|
+import com.pavis.admin.aigc.model.resp.SimilaritySearchResp;
|
|
|
+import dev.langchain4j.data.embedding.Embedding;
|
|
|
+import dev.langchain4j.data.segment.TextSegment;
|
|
|
+import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
|
+import dev.langchain4j.store.embedding.CosineSimilarity;
|
|
|
+import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
|
+import dev.langchain4j.store.embedding.RelevanceScore;
|
|
|
+import jakarta.annotation.Resource;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.data.redis.core.RedisTemplate;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
+import top.continew.starter.cache.redisson.util.RedisUtils;
|
|
|
import top.continew.starter.extension.crud.service.BaseServiceImpl;
|
|
|
import com.pavis.admin.aigc.mapper.DocChunkMapper;
|
|
|
import com.pavis.admin.aigc.model.entity.DocChunkDO;
|
|
@@ -13,6 +28,14 @@ import com.pavis.admin.aigc.model.resp.DocChunkDetailResp;
|
|
|
import com.pavis.admin.aigc.model.resp.DocChunkResp;
|
|
|
import com.pavis.admin.aigc.service.DocChunkService;
|
|
|
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Comparator;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.PriorityQueue;
|
|
|
+import java.util.stream.Collectors;
|
|
|
+
|
|
|
/**
|
|
|
* 文档切片业务实现
|
|
|
*
|
|
@@ -21,4 +44,96 @@ import com.pavis.admin.aigc.service.DocChunkService;
|
|
|
*/
|
|
|
@Service
|
|
|
@RequiredArgsConstructor
|
|
|
-public class DocChunkServiceImpl extends BaseServiceImpl<DocChunkMapper, DocChunkDO, DocChunkResp, DocChunkDetailResp, DocChunkQuery, DocChunkReq> implements DocChunkService {}
|
|
|
+@Slf4j
|
|
|
+public class DocChunkServiceImpl extends BaseServiceImpl<DocChunkMapper, DocChunkDO, DocChunkResp, DocChunkDetailResp, DocChunkQuery, DocChunkReq> implements DocChunkService {
|
|
|
+ // @Resource
|
|
|
+ // private EmbeddingModel embeddingModel;
|
|
|
+
|
|
|
+ @Resource
|
|
|
+ private DocMapper docMapper;
|
|
|
+
|
|
|
+ @Resource
|
|
|
+ private DocServiceImpl docService;
|
|
|
+
|
|
|
+ public List<SimilaritySearchResp> SimilaritySearch(SimilaritySearchReq req) {
|
|
|
+ EmbeddingModel embeddingModel = docService.getEmbeddingModel(req.getKnowledgeId());
|
|
|
+ // 向量化搜索文本
|
|
|
+ Embedding keywordEmbedding = embeddingModel.embed(req.getKeyword()).content();
|
|
|
+
|
|
|
+ // 从数据库查询嵌入数据
|
|
|
+ List<DocChunkDO> embeddings;
|
|
|
+ LambdaQueryWrapper<DocChunkDO> queryWrapper = new LambdaQueryWrapper<>();
|
|
|
+ if (req.getDocumentId() != null) {
|
|
|
+ queryWrapper.eq(DocChunkDO::getDocId, req.getDocumentId());
|
|
|
+ } else if (req.getKnowledgeId() != null) {
|
|
|
+ LambdaQueryWrapper<DocDO> docQueryWrapper = new LambdaQueryWrapper<>();
|
|
|
+ docQueryWrapper.eq(DocDO::getKnowledgeId, req.getKnowledgeId());
|
|
|
+ List<DocDO> documents = docMapper.selectList(docQueryWrapper);
|
|
|
+ List<Long> documentIds = documents.stream().map(DocDO::getId).collect(Collectors.toList());
|
|
|
+ if (!documentIds.isEmpty()) {
|
|
|
+ queryWrapper.in(DocChunkDO::getDocId, documentIds);
|
|
|
+ } else {
|
|
|
+ queryWrapper.eq(DocChunkDO::getId, -1L); // 无匹配的文档 ID 时查询无结果
|
|
|
+ }
|
|
|
+ }
|
|
|
+ embeddings = baseMapper.selectList(queryWrapper);
|
|
|
+
|
|
|
+ Map<Long, DocDO> documentMap = new HashMap<>();
|
|
|
+ if (!embeddings.isEmpty()) {
|
|
|
+ List<Long> documentIds = embeddings.stream()
|
|
|
+ .map(DocChunkDO::getDocId)
|
|
|
+ .distinct()
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ if (!documentIds.isEmpty()) {
|
|
|
+ List<DocDO> documents = docMapper.selectByIds(documentIds);
|
|
|
+ documentMap = documents.stream().collect(Collectors.toMap(DocDO::getId, document -> document));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 比较器
|
|
|
+ Comparator<SimilaritySearchResp> comparator = Comparator.comparingDouble(SimilaritySearchResp::getScore);
|
|
|
+ PriorityQueue<SimilaritySearchResp> matches = new PriorityQueue<>(comparator);
|
|
|
+
|
|
|
+ // 初步构建查询结果
|
|
|
+ for (DocChunkDO embedding : embeddings) {
|
|
|
+ // log.info("向量库存储内容:{}",RedisUtils.get("embedding:" + embedding.getEmbedStoreId()));
|
|
|
+ // 计算相似度
|
|
|
+ float[] vector = stringToFloatArray(embedding.getVector());
|
|
|
+ double cosineSimilarity = CosineSimilarity.between(new Embedding(vector), keywordEmbedding);
|
|
|
+ // double cosineSimilarity = CosineSimilarity.between(new Embedding(embedding.getVector()), keywordEmbedding);
|
|
|
+ double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
|
|
+
|
|
|
+ DocDO document = documentMap.get(embedding.getDocId());
|
|
|
+ if (score >= req.getMinScore() && document != null) {
|
|
|
+ SimilaritySearchResp resp = new SimilaritySearchResp(embedding.getId(), score, embedding
|
|
|
+ .getContent(), document);
|
|
|
+ if (matches.size() < req.getMaxResults()) {
|
|
|
+ matches.add(resp);
|
|
|
+ } else if (matches.peek() != null && score > matches.peek().getScore()) {
|
|
|
+ matches.poll();
|
|
|
+ matches.add(resp);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 进一步构建查询结果
|
|
|
+ List<SimilaritySearchResp> resps = new ArrayList<>(matches);
|
|
|
+ resps.sort(comparator.reversed());
|
|
|
+ return resps;
|
|
|
+ }
|
|
|
+
|
|
|
+ public static float[] stringToFloatArray(String input) {
|
|
|
+ String[] parts = input.split(",");
|
|
|
+ float[] result = new float[parts.length];
|
|
|
+
|
|
|
+ try {
|
|
|
+ for (int i = 0; i < parts.length; i++) {
|
|
|
+ result[i] = Float.parseFloat(parts[i].trim());
|
|
|
+ }
|
|
|
+ } catch (NumberFormatException e) {
|
|
|
+ System.err.println("输入的字符串格式不正确: " + input);
|
|
|
+ throw e;
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+}
|