Ver Fonte

fix Current limit

gbhblt há 11 meses atrás
pai
commit
c9fd3fb054

+ 4 - 1
langchat-app/src/main/java/cn/tycoding/langchat/app/endpoint/auth/OpenapiAuthAspect.java

@@ -43,8 +43,11 @@ public class OpenapiAuthAspect {
             throw new AuthException(401, "Authentication Token invalid");
         }
 
+
         try {
-            channelStore.isExpired(openapiAuth.value());
+            String value = openapiAuth.value();
+            channelStore.isExpired(value);
+            channelStore.currentLimiting(value);
         } catch (Exception e) {
             throw new ServiceException(e.getMessage());
         }

+ 4 - 1
langchat-app/src/main/java/cn/tycoding/langchat/app/entity/AigcAppApi.java

@@ -17,7 +17,9 @@
 
 package cn.tycoding.langchat.app.entity;
 
+import com.baomidou.mybatisplus.annotation.FieldStrategy;
 import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableField;
 import com.baomidou.mybatisplus.annotation.TableId;
 import lombok.Data;
 import lombok.experimental.Accessors;
@@ -44,7 +46,8 @@ public class AigcAppApi implements Serializable {
 
     private String channel;
     private String apiKey;
-    private Integer reqLimit = 100;
+    @TableField(updateStrategy = FieldStrategy.ALWAYS)
+    private Integer reqLimit;
     private String name;
     private String des;
     private Date expired = null;

+ 3 - 1
langchat-app/src/main/java/cn/tycoding/langchat/app/entity/AigcAppWeb.java

@@ -17,6 +17,7 @@
 
 package cn.tycoding.langchat.app.entity;
 
+import com.baomidou.mybatisplus.annotation.FieldStrategy;
 import com.baomidou.mybatisplus.annotation.IdType;
 import com.baomidou.mybatisplus.annotation.TableField;
 import com.baomidou.mybatisplus.annotation.TableId;
@@ -45,7 +46,8 @@ public class AigcAppWeb implements Serializable {
 
     private String channel;
     private String apiKey;
-    private Integer reqLimit = 100;
+    @TableField(updateStrategy = FieldStrategy.ALWAYS)
+    private Integer reqLimit;
     private String name;
     private String des;
     private Date expired = null;

+ 62 - 4
langchat-app/src/main/java/cn/tycoding/langchat/app/store/AppChannelStore.java

@@ -17,6 +17,7 @@
 package cn.tycoding.langchat.app.store;
 
 import cn.hutool.core.date.DateUtil;
+import cn.hutool.core.util.ObjectUtil;
 import cn.tycoding.langchat.app.consts.AppConst;
 import cn.tycoding.langchat.app.entity.AigcAppApi;
 import cn.tycoding.langchat.app.entity.AigcAppWeb;
@@ -27,12 +28,12 @@ import cn.tycoding.langchat.common.utils.ServletUtil;
 import jakarta.annotation.PostConstruct;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.core.script.DefaultRedisScript;
+import org.springframework.data.redis.core.script.RedisScript;
 import org.springframework.stereotype.Component;
 
-import java.util.Date;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 /**
  * @author tycoding
@@ -48,6 +49,7 @@ public class AppChannelStore {
 
     private final AigcAppApiService appApiService;
     private final AigcAppWebService appWebService;
+    private final RedisTemplate stringRedisTemplate;
 
     public static AigcAppApi getApiChannel() {
         String token = ServletUtil.getAuthorizationToken();
@@ -69,6 +71,62 @@ public class AppChannelStore {
         webs.forEach(web -> WEB_MAP.put(web.getApiKey(), web));
     }
 
+
+    /**
+     * 限流校验
+     *
+     * @param channel
+     */
+    public void currentLimiting(String channel) {
+        try {
+            String token = ServletUtil.getAuthorizationToken();
+            Integer reqLimit = 0;
+            if (AppConst.CHANNEL_API.equals(channel)) {
+                AigcAppApi data = API_MAP.get(token);
+                if (data == null) {
+                    throw new RuntimeException("The ApiKey is empty");
+                }
+                reqLimit = data.getReqLimit();
+            }
+
+            if (AppConst.CHANNEL_WEB.equals(channel)) {
+                AigcAppWeb data = WEB_MAP.get(token);
+                if (data == null) {
+                    throw new RuntimeException("The ApiKey is empty");
+                }
+                reqLimit = data.getReqLimit();
+            }
+            if (ObjectUtil.isEmpty(reqLimit)) {
+                return;
+            }
+
+            long currentTime = System.currentTimeMillis();
+            String lua = "local window_start = ARGV[1] - 60000\n" +
+                    "redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', window_start)\n" +
+                    "local current_requests = redis.call('ZCARD', KEYS[1])\n" +
+                    "if current_requests < tonumber(ARGV[2]) then\n" +
+                    "  redis.call('ZADD', KEYS[1], ARGV[1], ARGV[1])\n" +
+                    "  return 1\n" +
+                    "else\n" +
+                    "  return 0\n" +
+                    "end";
+            RedisScript<Long> integerRedisScript = new DefaultRedisScript<>(lua, Long.class);
+            Long result = (Long) stringRedisTemplate.execute(integerRedisScript,
+                    Collections.singletonList(token),
+                    String.valueOf(currentTime),
+                    String.valueOf(reqLimit)
+            );
+            if (result != 1) {
+                throw new ServiceException("The api is restricted");
+            }
+        } catch (ServiceException e) {
+            throw e;
+        } catch (Exception e) {
+            log.error("校验限流出错", e);
+        }
+
+    }
+
     public void isExpired(String channel) {
         String token = ServletUtil.getAuthorizationToken();