瀏覽代碼

add knowledge docs & slice

tycoding 1 年之前
父節點
當前提交
63af454dd4

+ 10 - 1
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/endpoint/EmbeddingEndpoint.java

@@ -2,8 +2,10 @@ package cn.tycoding.langchat.aigc.endpoint;
 
 import cn.hutool.core.util.StrUtil;
 import cn.tycoding.langchat.aigc.entity.AigcDocs;
+import cn.tycoding.langchat.aigc.entity.AigcDocsSlice;
 import cn.tycoding.langchat.aigc.service.AigcKnowledgeService;
 import cn.tycoding.langchat.common.dto.DocR;
+import cn.tycoding.langchat.common.dto.EmbeddingR;
 import cn.tycoding.langchat.common.exception.ServiceException;
 import cn.tycoding.langchat.core.service.LangDocService;
 import lombok.AllArgsConstructor;
@@ -30,8 +32,15 @@ public class EmbeddingEndpoint {
             throw new ServiceException("文档内容不能为空");
         }
         aigcKnowledgeService.addDocs(data);
-        langDocService.embeddingText(new DocR().setMessage(data.getContent())
+        EmbeddingR embeddingR = langDocService.embeddingText(new DocR().setMessage(data.getContent())
                 .setId(data.getId())
                 .setKnowledgeId(data.getKnowledgeId()));
+        aigcKnowledgeService.addDocsSlice(new AigcDocsSlice()
+                .setKnowledgeId(data.getKnowledgeId())
+                .setDocsId(data.getId())
+                .setVectorId(embeddingR.getVectorId())
+                .setName(data.getName())
+                .setContent(embeddingR.getText())
+        );
     }
 }

+ 15 - 2
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/entity/AigcDocsSlice.java

@@ -3,14 +3,17 @@ package cn.tycoding.langchat.aigc.entity;
 import com.baomidou.mybatisplus.annotation.IdType;
 import com.baomidou.mybatisplus.annotation.TableId;
 import lombok.Data;
+import lombok.experimental.Accessors;
 
 import java.io.Serializable;
+import java.util.Date;
 
 /**
  * @author tycoding
  * @since 2024/4/15
  */
 @Data
+@Accessors(chain = true)
 public class AigcDocsSlice implements Serializable {
     private static final long serialVersionUID = -3093489071059867065L;
 
@@ -20,6 +23,11 @@ public class AigcDocsSlice implements Serializable {
     @TableId(type = IdType.ASSIGN_UUID)
     private String id;
 
+    /**
+     * 向量库的ID
+     */
+    private String vectorId;
+
     /**
      * 文档ID
      */
@@ -31,10 +39,15 @@ public class AigcDocsSlice implements Serializable {
     private String knowledgeId;
 
     /**
-     * 文名称
+     * 文名称
      */
     private String name;
 
+    /**
+     * 切片内容
+     */
+    private String content;
+
     /**
      * 字符数量
      */
@@ -48,6 +61,6 @@ public class AigcDocsSlice implements Serializable {
     /**
      * 创建时间
      */
-    private String createTime;
+    private Date createTime;
 }
 

+ 9 - 0
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/service/AigcKnowledgeService.java

@@ -1,6 +1,7 @@
 package cn.tycoding.langchat.aigc.service;
 
 import cn.tycoding.langchat.aigc.entity.AigcDocs;
+import cn.tycoding.langchat.aigc.entity.AigcDocsSlice;
 import cn.tycoding.langchat.aigc.entity.AigcKnowledge;
 import com.baomidou.mybatisplus.extension.service.IService;
 
@@ -10,6 +11,14 @@ import com.baomidou.mybatisplus.extension.service.IService;
  */
 public interface AigcKnowledgeService extends IService<AigcKnowledge> {
 
+    /**
+     * 添加文档数据
+     */
     void addDocs(AigcDocs data);
+
+    /**
+     * 在指定文档中添加Embedding后的切片数据
+     */
+    void addDocsSlice(AigcDocsSlice data);
 }
 

+ 12 - 0
langchat-aigc/src/main/java/cn/tycoding/langchat/aigc/service/impl/AigcKnowledgeServiceImpl.java

@@ -1,8 +1,10 @@
 package cn.tycoding.langchat.aigc.service.impl;
 
 import cn.tycoding.langchat.aigc.entity.AigcDocs;
+import cn.tycoding.langchat.aigc.entity.AigcDocsSlice;
 import cn.tycoding.langchat.aigc.entity.AigcKnowledge;
 import cn.tycoding.langchat.aigc.mapper.AigcDocsMapper;
+import cn.tycoding.langchat.aigc.mapper.AigcDocsSliceMapper;
 import cn.tycoding.langchat.aigc.mapper.AigcKnowledgeMapper;
 import cn.tycoding.langchat.aigc.service.AigcKnowledgeService;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
@@ -20,6 +22,7 @@ import java.util.Date;
 public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, AigcKnowledge> implements AigcKnowledgeService {
 
     private final AigcDocsMapper aigcDocsMapper;
+    private final AigcDocsSliceMapper aigcDocsSliceMapper;
 
     @Override
     public void addDocs(AigcDocs data) {
@@ -29,5 +32,14 @@ public class AigcKnowledgeServiceImpl extends ServiceImpl<AigcKnowledgeMapper, A
         data.setSliceStatus(false);
         aigcDocsMapper.insert(data);
     }
+
+    @Override
+    public void addDocsSlice(AigcDocsSlice data) {
+        data.setCreateTime(new Date())
+                .setWordNum(data.getContent().length())
+                .setStatus(true)
+        ;
+        aigcDocsSliceMapper.insert(data);
+    }
 }
 

+ 33 - 0
langchat-common/src/main/java/cn/tycoding/langchat/common/dto/EmbeddingR.java

@@ -0,0 +1,33 @@
+package cn.tycoding.langchat.common.dto;
+
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * @author tycoding
+ * @since 2024/4/26
+ */
+@Data
+@Accessors(chain = true)
+public class EmbeddingR {
+
+    /**
+     * 写入到vector store的ID
+     */
+    private String vectorId;
+
+    /**
+     * 文档ID
+     */
+    private String docsId;
+
+    /**
+     * 知识库ID
+     */
+    private String knowledgeId;
+
+    /**
+     * Embedding后切片的文本
+     */
+    private String text;
+}

+ 20 - 0
langchat-core/src/main/java/cn/tycoding/langchat/core/component/VectorStoreComponent.java

@@ -2,7 +2,10 @@ package cn.tycoding.langchat.core.component;
 
 import cn.tycoding.langchat.core.properties.LangChatProps;
 import cn.tycoding.langchat.core.properties.vectorstore.MilvusProps;
+import dev.langchain4j.internal.Utils;
 import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.param.ConnectParam;
 import lombok.AllArgsConstructor;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
 import org.springframework.context.annotation.Bean;
@@ -38,4 +41,21 @@ public class VectorStoreComponent {
                 .build();
     }
 
+    @Bean
+    @ConditionalOnProperty(value = "langchat.vectorstore.milvus.host", matchIfMissing = false)
+    public MilvusServiceClient milvusServiceClient() {
+        MilvusProps prop = props.getVectorstore().getMilvus();
+        ConnectParam.Builder connectBuilder = ConnectParam.newBuilder()
+                .withHost((String) Utils.getOrDefault(prop.getHost(), "localhost"))
+                .withPort((Integer) Utils.getOrDefault(prop.getPort(), 19530))
+                .withUri(prop.getUri())
+                .withToken(prop.getToken())
+                .withAuthorization(prop.getUsername(), prop.getPassword());
+        if (prop.getCollectionName() != null) {
+            connectBuilder.withDatabaseName(prop.getDatabaseName());
+        }
+
+        return new MilvusServiceClient(connectBuilder.build());
+    }
+
 }

+ 2 - 1
langchat-core/src/main/java/cn/tycoding/langchat/core/service/LangDocService.java

@@ -1,6 +1,7 @@
 package cn.tycoding.langchat.core.service;
 
 import cn.tycoding.langchat.common.dto.DocR;
+import cn.tycoding.langchat.common.dto.EmbeddingR;
 import dev.langchain4j.service.TokenStream;
 
 /**
@@ -12,7 +13,7 @@ public interface LangDocService {
     /**
      * 解析文本向量
      */
-    void embeddingText(DocR req);
+    EmbeddingR embeddingText(DocR req);
 
     /**
      * 解析文本文件向量

+ 5 - 2
langchat-core/src/main/java/cn/tycoding/langchat/core/service/impl/LangDocServiceImpl.java

@@ -1,6 +1,7 @@
 package cn.tycoding.langchat.core.service.impl;
 
 import cn.tycoding.langchat.common.dto.DocR;
+import cn.tycoding.langchat.common.dto.EmbeddingR;
 import cn.tycoding.langchat.core.enums.ModelConst;
 import cn.tycoding.langchat.core.provider.EmbedProvider;
 import cn.tycoding.langchat.core.provider.ModelProvider;
@@ -49,12 +50,14 @@ public class LangDocServiceImpl implements LangDocService {
     private final MilvusEmbeddingStore milvusEmbeddingStore;
 
     @Override
-    public void embeddingText(DocR req) {
+    public EmbeddingR embeddingText(DocR req) {
         TextSegment segment = TextSegment.from(req.getMessage(),
                 metadata("knowledgeId", req.getKnowledgeId()));
         EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
         Embedding embedding = embeddingModel.embed(segment).content();
-        milvusEmbeddingStore.add(embedding, segment);
+
+        String id = milvusEmbeddingStore.add(embedding, segment);
+        return new EmbeddingR().setVectorId(id).setText(segment.text());
     }
 
     @Override

+ 36 - 7
langchat-server/src/test/java/cn/tycoding/langchat/AppTest.java

@@ -1,10 +1,16 @@
 package cn.tycoding.langchat;
 
-import dev.langchain4j.data.segment.TextSegment;
-import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
-import dev.langchain4j.store.embedding.EmbeddingSearchResult;
 import dev.langchain4j.store.embedding.filter.Filter;
-import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.grpc.DescribeCollectionResponse;
+import io.milvus.grpc.GetCollectionStatisticsResponse;
+import io.milvus.grpc.SearchResults;
+import io.milvus.grpc.ShowCollectionsResponse;
+import io.milvus.param.R;
+import io.milvus.param.collection.DescribeCollectionParam;
+import io.milvus.param.collection.GetCollectionStatisticsParam;
+import io.milvus.param.collection.ShowCollectionsParam;
+import io.milvus.param.dml.SearchParam;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -22,13 +28,36 @@ import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metad
 public class AppTest {
 
     @Autowired
-    private MilvusEmbeddingStore milvusEmbeddingStore;
+    private MilvusServiceClient milvusClient;
 
     @Test
     public void t1() {
         Filter filter = metadataKey("knowledgeId").isEqualTo("f228b6c9-bce2-4fd0-239s8-fbc3b893e36e");
-        EmbeddingSearchResult<TextSegment> search =
-                milvusEmbeddingStore.search(EmbeddingSearchRequest.builder().filter(filter).build());
+
+        R<ShowCollectionsResponse> respShowCollections = milvusClient.showCollections(
+                ShowCollectionsParam.newBuilder().build()
+        );
+        R<DescribeCollectionResponse> respDescribeCollection = milvusClient.describeCollection(
+                // Return the name and schema of the collection.
+                DescribeCollectionParam.newBuilder()
+                        .withCollectionName("test3")
+                        .build()
+        );
+        R<GetCollectionStatisticsResponse> respCollectionStatistics = milvusClient.getCollectionStatistics(
+                // Return the statistics information of the collection.
+                GetCollectionStatisticsParam.newBuilder()
+                        .withCollectionName("test3")
+                        .build()
+        );
+
+
+        SearchParam searchParam = SearchParam.newBuilder()
+                .withCollectionName("test3")
+                .withVectorFieldName("text")
+                .withTopK(1000)
+                .build();
+        R<SearchResults> search = milvusClient.search(searchParam);
+
         System.out.println("-----");
     }
 }