Browse Source

修复部分SQL错误

tycoding 1 year ago
parent
commit
f1ee2bda24

+ 10 - 14
langchat-app/src/main/java/cn/tycoding/langchat/app/endpoint/AppApiChatEndpoint.java

@@ -16,10 +16,10 @@
 
 package cn.tycoding.langchat.app.endpoint;
 
-import cn.hutool.core.lang.Dict;
+import cn.tycoding.langchat.app.endpoint.auth.CompletionReq;
+import cn.tycoding.langchat.app.endpoint.auth.OpenapiAuth;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.http.MediaType;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestMapping;
@@ -35,24 +35,20 @@ import java.io.IOException;
 @Slf4j
 @RestController
 @RequiredArgsConstructor
-@RequestMapping("/langchat/openapi/v1")
+@RequestMapping("/v1")
 public class AppApiChatEndpoint {
 
-    @PostMapping("/test")
-    public Object test(@RequestBody Object obj) {
-        log.info("x: {}", obj);
-        return Dict.create().set("message", "你好呀").set("threadId", "111");
-    }
-
-    @PostMapping("/test2")
-    public Object test2(@RequestBody Object obj) throws InterruptedException, IOException {
-        log.info("Received: {}", obj);
+    @OpenapiAuth
+    @PostMapping("/chat/completions")
+    public Object test2(@RequestBody CompletionReq req) throws InterruptedException, IOException {
+        log.info("x: {}", req);
         ResponseBodyEmitter emitter = new ResponseBodyEmitter();
 
         new Thread(() -> {
             try {
-                for (int i = 0; i < 10; i++) {
-                    emitter.send("Data: " + i + "\n", MediaType.TEXT_PLAIN);
+                for (int i = 0; i < 5; i++) {
+//                    new ChatReq().set
+//                    emitter.send(JSON.toJSONString(res));
                     Thread.sleep(1000);
                 }
                 emitter.complete();

+ 51 - 0
langchat-app/src/main/java/cn/tycoding/langchat/app/endpoint/auth/CompletionReq.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2024 LangChat. TyCoding All Rights Reserved.
+ *
+ * Licensed under the GNU Affero General Public License, Version 3 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.gnu.org/licenses/agpl-3.0.html
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package cn.tycoding.langchat.app.endpoint.auth;
+
+import lombok.Builder;
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * @author tycoding
+ * @since 2024/7/30
+ */
+@Data
+@Builder
+public class CompletionReq {
+
+    private final String model;
+    private final List<Message> messages;
+    private final Double temperature;
+    private final Double topP;
+    private final Integer n;
+    private final Boolean stream;
+    private final List<String> stop;
+    private final Integer maxTokens;
+    private final Double presencePenalty;
+    private final Double frequencyPenalty;
+    private final String user;
+    private final Integer seed;
+
+    @Data
+    @Builder
+    static class Message {
+        String role;
+        String content;
+    }
+}

+ 11 - 6
langchat-app/src/main/java/cn/tycoding/langchat/app/endpoint/auth/OpenapiAuthAspect.java

@@ -16,25 +16,30 @@
 
 package cn.tycoding.langchat.app.endpoint.auth;
 
-import cn.dev33.satoken.exception.NotPermissionException;
-import cn.tycoding.langchat.biz.entity.AigcUser;
-import cn.tycoding.langchat.biz.utils.ClientAuthUtil;
+import cn.tycoding.langchat.common.exception.AuthException;
+import cn.tycoding.langchat.common.utils.ServletUtil;
+import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.aspectj.lang.ProceedingJoinPoint;
 import org.aspectj.lang.annotation.Around;
 import org.aspectj.lang.annotation.Aspect;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.data.redis.core.StringRedisTemplate;
 
 @Slf4j
 @Aspect
 @Configuration
+@AllArgsConstructor
 public class OpenapiAuthAspect {
 
+    private StringRedisTemplate redisTemplate;
+
     @Around("@annotation(openapiAuth)")
     public Object around(ProceedingJoinPoint point, OpenapiAuth openapiAuth) throws Throwable {
-        AigcUser userInfo = ClientAuthUtil.getUserInfo();
-        if (userInfo.getIsPerms() == null || !userInfo.getIsPerms()) {
-            throw new NotPermissionException("当前账号没有操作权限,请联系管理员");
+        String authorization = ServletUtil.getAuthorizationToken();
+
+        if (authorization == null) {
+            throw new AuthException("Authentication Token invalid");
         }
         return point.proceed();
     }

+ 9 - 9
langchat-biz/src/main/java/cn/tycoding/langchat/biz/mapper/AigcMessageMapper.java

@@ -58,8 +58,8 @@ public interface AigcMessageMapper extends BaseMapper<AigcMessage> {
 
     @Select("""
         SELECT
-            DATE_FORMAT(create_time, '%Y-%m') as month,
-            COUNT(*) as count
+            COALESCE(DATE_FORMAT(create_time, '%Y-%m'), 0) as month,
+            COALESCE(COUNT(*), 0) as count
         FROM
             aigc_message
         WHERE
@@ -74,8 +74,8 @@ public interface AigcMessageMapper extends BaseMapper<AigcMessage> {
 
     @Select("""
         SELECT
-            DATE_FORMAT(create_time, '%Y-%m') as month,
-            SUM(tokens) as count
+            COALESCE(DATE_FORMAT(create_time, '%Y-%m'), 0) as month,
+            COALESCE(SUM(tokens), 0) as count
         FROM
             aigc_message
         WHERE
@@ -90,8 +90,8 @@ public interface AigcMessageMapper extends BaseMapper<AigcMessage> {
 
     @Select("""
         SELECT
-            COUNT(*) AS totalReq,
-            SUM( CASE WHEN DATE ( create_time ) = CURDATE() THEN 1 ELSE 0 END ) AS curReq
+            COALESCE(COUNT(*), 0) AS totalReq,
+            COALESCE(SUM( CASE WHEN DATE ( create_time ) = CURDATE() THEN 1 ELSE 0 END ), 0) AS curReq
         FROM
             aigc_message
         WHERE
@@ -101,10 +101,10 @@ public interface AigcMessageMapper extends BaseMapper<AigcMessage> {
 
     @Select("""
         SELECT
-            SUM( tokens ) AS totalToken,
-            SUM( CASE WHEN DATE ( create_time ) = CURDATE() THEN tokens ELSE 0 END ) AS curToken
+            COALESCE(SUM(tokens), 0) AS totalToken,
+            COALESCE(SUM(CASE WHEN DATE(create_time) = CURDATE() THEN tokens ELSE 0 END), 0) AS curToken
         FROM
-            aigc_message
+            aigc_message;
     """)
     Dict getTotalSum();
 }

+ 15 - 0
langchat-common/src/main/java/cn/tycoding/langchat/common/utils/ServletUtil.java

@@ -54,6 +54,21 @@ public class ServletUtil {
         return null;
     }
 
+    public static String getAuthorizationToken() {
+        String token = getRequest().getHeader("Authorization");
+        if (token != null && token.toLowerCase().startsWith("bearer")) {
+            return token.replace("bearer", "").trim();
+        }
+        return null;
+    }
+
+    public static String getToken(String token) {
+        if (token != null && token.toLowerCase().startsWith("bearer")) {
+            return token.replace("bearer", "").trim();
+        }
+        return token;
+    }
+
     public static String getIpAddr() {
         HttpServletRequest request = getRequest();
         if (request == null) {

+ 0 - 12
langchat-core/src/main/java/cn/tycoding/langchat/core/provider/ProviderInitialize.java

@@ -170,18 +170,6 @@ public class ProviderInitialize implements ApplicationContextAware {
                 contextHolder.registerBean(model.getId(), build);
             }
 
-            if (ProviderEnum.ALIBABA.getModel().equals(provider)) {
-                QwenStreamingChatModel build = QwenStreamingChatModel
-                        .builder()
-                        .apiKey(model.getApiKey())
-                        .modelName(model.getModel())
-                        .maxTokens(model.getResponseLimit())
-                        .temperature(Float.parseFloat(model.getTemperature().toString()))
-                        .topP(model.getTopP())
-                        .build();
-                contextHolder.registerBean(model.getId(), build);
-            }
-
             if (ProviderEnum.ZHIPU.getModel().equals(provider)) {
                 ZhipuAiStreamingChatModel build = ZhipuAiStreamingChatModel
                         .builder()