EmbeddingEndpoint.java 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. /*
  2. * Copyright (c) 2024 LangChat. TyCoding All Rights Reserved.
  3. *
  4. * Licensed under the GNU Affero General Public License, Version 3 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * https://www.gnu.org/licenses/agpl-3.0.html
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package cn.tycoding.langchat.server.endpoint;
  17. import cn.dev33.satoken.annotation.SaCheckPermission;
  18. import cn.hutool.core.util.StrUtil;
  19. import cn.tycoding.langchat.biz.entity.AigcDocs;
  20. import cn.tycoding.langchat.biz.entity.AigcDocsSlice;
  21. import cn.tycoding.langchat.biz.entity.AigcOss;
  22. import cn.tycoding.langchat.biz.mapper.AigcDocsMapper;
  23. import cn.tycoding.langchat.biz.service.AigcKnowledgeService;
  24. import cn.tycoding.langchat.biz.service.AigcOssService;
  25. import cn.tycoding.langchat.common.dto.ChatReq;
  26. import cn.tycoding.langchat.common.dto.EmbeddingR;
  27. import cn.tycoding.langchat.common.exception.ServiceException;
  28. import cn.tycoding.langchat.common.task.TaskManager;
  29. import cn.tycoding.langchat.common.utils.R;
  30. import cn.tycoding.langchat.core.consts.EmbedConst;
  31. import cn.tycoding.langchat.core.service.LangEmbeddingService;
  32. import cn.tycoding.langchat.server.service.EmbeddingService;
  33. import cn.tycoding.langchat.upms.utils.AuthUtil;
  34. import lombok.AllArgsConstructor;
  35. import lombok.extern.slf4j.Slf4j;
  36. import org.springframework.web.bind.annotation.*;
  37. import org.springframework.web.multipart.MultipartFile;
  38. import java.util.concurrent.Executors;
  39. /**
  40. * @author tycoding
  41. * @since 2024/4/25
  42. */
  43. @Slf4j
  44. @RestController
  45. @AllArgsConstructor
  46. @RequestMapping("/aigc/embedding")
  47. public class EmbeddingEndpoint {
  48. private final LangEmbeddingService langEmbeddingService;
  49. private final AigcKnowledgeService aigcKnowledgeService;
  50. private final AigcDocsMapper aigcDocsMapper;
  51. private final AigcOssService aigcOssService;
  52. private final EmbeddingService embeddingService;
  53. @PostMapping("/text")
  54. @SaCheckPermission("aigc:embedding:text")
  55. public R text(@RequestBody AigcDocs data) {
  56. if (StrUtil.isBlankIfStr(data.getContent())) {
  57. throw new ServiceException("文档内容不能为空");
  58. }
  59. data.setType(EmbedConst.ORIGIN_TYPE_INPUT).setSliceStatus(false);
  60. if (StrUtil.isBlank(data.getId())) {
  61. aigcKnowledgeService.addDocs(data);
  62. }
  63. EmbeddingR embeddingR = langEmbeddingService.embeddingText(
  64. new ChatReq().setMessage(data.getContent())
  65. .setDocsName(data.getType())
  66. .setDocsId(data.getId())
  67. .setKnowledgeId(data.getKnowledgeId()));
  68. aigcKnowledgeService.addDocsSlice(new AigcDocsSlice()
  69. .setKnowledgeId(data.getKnowledgeId())
  70. .setDocsId(data.getId())
  71. .setVectorId(embeddingR.getVectorId())
  72. .setName(data.getName())
  73. .setContent(embeddingR.getText())
  74. );
  75. aigcKnowledgeService.updateDocs(new AigcDocs().setId(data.getId()).setSliceStatus(true).setSliceNum(1));
  76. return R.ok();
  77. }
  78. @PostMapping("/docs/{knowledgeId}")
  79. @SaCheckPermission("aigc:embedding:docs")
  80. public R docs(MultipartFile file, @PathVariable String knowledgeId) {
  81. String userId = String.valueOf(AuthUtil.getUserId());
  82. AigcOss oss = aigcOssService.upload(file, userId);
  83. AigcDocs data = new AigcDocs()
  84. .setName(oss.getOriginalFilename())
  85. .setSliceStatus(false)
  86. .setUrl(oss.getUrl())
  87. .setSize(file.getSize())
  88. .setType(EmbedConst.ORIGIN_TYPE_UPLOAD)
  89. .setKnowledgeId(knowledgeId);
  90. aigcKnowledgeService.addDocs(data);
  91. TaskManager.submitTask(userId,
  92. Executors.callable(() -> embeddingService.embedDocsSlice(data, oss.getUrl())));
  93. return R.ok();
  94. }
  95. @GetMapping("/re-embed/{docsId}")
  96. public R reEmbed(@PathVariable String docsId) {
  97. String userId = String.valueOf(AuthUtil.getUserId());
  98. AigcDocs docs = aigcDocsMapper.selectById(docsId);
  99. if (docs == null) {
  100. throw new ServiceException("没有查询到文档数据");
  101. }
  102. if (EmbedConst.ORIGIN_TYPE_INPUT.equals(docs.getType())) {
  103. text(docs);
  104. }
  105. if (EmbedConst.ORIGIN_TYPE_UPLOAD.equals(docs.getType())) {
  106. TaskManager.submitTask(userId,
  107. Executors.callable(() -> embeddingService.embedDocsSlice(docs, docs.getUrl())));
  108. }
  109. return R.ok();
  110. }
  111. @PostMapping("/search")
  112. public R search(@RequestBody AigcDocs data) {
  113. return R.ok(embeddingService.search(data));
  114. }
  115. }