【目标】通过 **自定义注解 + Lua 脚本 + 多种限流维度** 实现接口限流。

自定义限流注解

  1. @Target(ElementType.METHOD)
  2. @Retention(RetentionPolicy.RUNTIME)
  3. @Documented
  4. @Inherited
  5. public @interface MethodRateLimit {
  6. /**
  7. * @return CheckTypeEnum 限流类型。默认值:ALL。可选值:ALL,IP,USER,CUSTOM
  8. */
  9. CheckTypeEnum checkType() default CheckTypeEnum.ALL;
  10. /**
  11. * @return 限流次数。默认值60
  12. */
  13. long limit() default 60;
  14. /**
  15. * @return 限流时间间隔,以秒为单位。默认值60
  16. */
  17. long refreshInterval() default 60;
  18. /**
  19. * 黑名单时间, 单位分钟
  20. */
  21. long blackExpire() default 10;
  22. }

CheckTypeEnum 是限流策略的枚举类:

  /**
   * Dimension along which a method's rate limit is counted.
   */
  public enum CheckTypeEnum {
      /** One shared counter for all callers, e.g. the method may be called n times per minute in total. */
      ALL,
      /** One counter per client IP, e.g. each IP may call the method n times per minute. */
      IP,
      /** One counter per user, e.g. each user may call the method n times per minute. */
      USER,
      /** Counter keyed by a caller-supplied value (read from a request attribute). */
      CUSTOM
  }

定义切面：拦截带 @MethodRateLimit 注解的方法，并判断本次调用是否需要限流

  1. @Slf4j
  2. @Aspect
  3. @Component
  4. public class MethodAnnotationAspect {
  5. private static final Logger BLACK_LOG = LoggerFactory.getLogger("BLACK_LOG");
  6. @Autowired
  7. private RateLimiterAlgorithm rateLimiterAlgorithm;
  8. @Autowired
  9. private RedisTemplate redisTemplate;
  10. @Autowired
  11. private PropResource propResource;
  12. @Pointcut("@annotation(methodRateLimit)")
  13. public void annotationPointcut(MethodRateLimit methodRateLimit) {
  14. }
  15. @Before("annotationPointcut(methodRateLimit)")
  16. public void doBefore(JoinPoint joinPoint, MethodRateLimit methodRateLimit) {
  17. String blackKey = null;
  18. String key = null;
  19. // 从配置中心获取限流的配置
  20. Map<String, Object> config = getRateLimitConfig();
  21. UserVo user = HttpContext.getCurrentUser();
  22. // 默认的配置
  23. long limit = methodRateLimit.limit();
  24. long refreshInterval = methodRateLimit.refreshInterval();
  25. long blackExpire = methodRateLimit.blackExpire();
  26. CheckTypeEnum checkTypeEnum = methodRateLimit.checkType();
  27. try {
  28. if (config.containsKey("limit") && (Long) config.get("limit") > 0) {
  29. limit = (Long) config.get("limit");
  30. }
  31. if (config.containsKey("refreshInterval") && (Long) config.get("refreshInterval") > 0) {
  32. refreshInterval = (Long) config.get("refreshInterval");
  33. }
  34. // 获取黑名单的 redis key「propResource.getRedisIsCluster() 是否是集群,集群模式要指定 slot 的 hash tag」@more
  35. blackKey = RateLimiterUtil.getBlackKey(joinPoint, checkTypeEnum, propResource.getRedisIsCluster());
  36. if (redisTemplate.hasKey(blackKey)) {
  37. log.info("{} 在黑名单上", blackKey);
  38. throw new LimitRateException(LimitRateErrorEnum.BLACK_REQUESTS);
  39. }
  40. // 从 redis 获取限流 key
  41. key = RateLimiterUtil.getRateKey(joinPoint, checkTypeEnum, propResource.getRedisIsCluster());
  42. log.debug("限流key:{}", key);
  43. // 通过 lua 脚本判定是否限流和更新 redis 操作 @more
  44. rateLimiterAlgorithm.consume(key, limit, refreshInterval * 60);
  45. } catch (LimitRateException e) {
  46. if (LimitRateErrorEnum.TOO_MANY_REQUESTS.getCode().equals(e.getCode()) && blackKey != null) {
  47. log.info("{} 加入黑名单", blackKey);
  48. if (config.containsKey("blackExpire") && config.get("blackExpire") != null && (long) config.get("blackExpire") > 0) {
  49. blackExpire = (long) config.get("blackExpire");
  50. }
  51. // 设置黑名单
  52. redisTemplate.opsForValue().set(blackKey, 1, blackExpire, TimeUnit.MINUTES);
  53. // 创建日志
  54. Object blackLog = createBlackLog(checkTypeEnum, blackKey, key, user, joinPoint);
  55. if (blackLog != null) {
  56. BLACK_LOG.info(JSON.toJSONStringWithDateFormat(blackLog, "yyyy-MM-dd HH:mm:ss"));
  57. }
  58. throw e;
  59. } else if (LimitRateErrorEnum.BLACK_REQUESTS.getCode() == e.getCode()) {
  60. log.info("{} 在黑名单上", blackKey);
  61. throw e;
  62. }
  63. } catch (Exception e) {
  64. log.error("限流错误", e);
  65. }
  66. }
  67. public Object createBlackLog(CheckTypeEnum checkTypeEnum, String checkKey, String blackKey, UserVo user, JoinPoint joinPoint) {
  68. try {
  69. LogVo logVo = new LogVo();
  70. Map<String, Object> blackMap = new HashMap<>();
  71. blackMap.put("createTime", new Date());
  72. blackMap.put("checkType", checkTypeEnum.name());
  73. blackMap.put("rateKey", checkKey);
  74. blackMap.put("blackKey", blackKey);
  75. blackMap.put("checkKey", RateLimiterUtil.getCheckKey(joinPoint, checkTypeEnum));
  76. if (user != null) {
  77. blackMap.put("organizationCode", user.getOrganizationCode());
  78. }
  79. logVo.setSystemName("cloud-search-black");
  80. logVo.setSuccess(true);
  81. logVo.setCode(Result.DEFAULT_SUCCESS_CODE);
  82. logVo.setTraceId(StringUtils.getId());
  83. ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
  84. HttpServletRequest request = requestAttributes.getRequest();
  85. logVo.setIp(IpUtils.getRemortIP(request));
  86. logVo.setExtendedParam(blackMap);
  87. logVo.setStartTime(new Date());
  88. logVo.setEndTime(new Date());
  89. return logVo;
  90. } catch (Exception e) {
  91. log.error("生成日志失败");
  92. return null;
  93. }
  94. }
  95. public Map getRateLimitConfig() {
  96. Map<String, Object> map = new HashMap<>();
  97. map.put("refreshInterval", 60L);
  98. map.put("limit", 3L);
  99. map.put("blackExpire", 10L);
  100. return map;
  101. }
  102. }

根据不同策略获取 redis key

  1. @Slf4j
  2. public class RateLimiterUtil {
  3. /**
  4. * 获取黑名单
  5. *
  6. * @param joinPoint
  7. * @param checkTypeEnum 策略类型
  8. * @param isCluster 是否是集群模式
  9. * @return
  10. */
  11. public static String getBlackKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster) {
  12. StringBuffer key = new StringBuffer();
  13. if (isCluster) {
  14. key.append(LimitConstants.HASH_TAG_PRFIX).append(LimitConstants.HASH_TAG).append(LimitConstants.HASH_TAG_SUFFIX);
  15. } else {
  16. key.append(LimitConstants.HASH_TAG);
  17. }
  18. String appId = SpringContextHolder.getBean(Environment.class).getProperty("app.id");
  19. key.append(appId).append(":black:");
  20. key.append(checkTypeEnum.name().toLowerCase()).append(":");
  21. key.append(getCheckKey(joinPoint, checkTypeEnum));
  22. return key.toString();
  23. }
  24. /**
  25. * 获取唯一标识此次请求的key
  26. *
  27. * @param joinPoint 切点
  28. * @param checkTypeEnum 枚举
  29. * @return key
  30. */
  31. public static String getRateKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster) {
  32. StringBuffer key = new StringBuffer();
  33. if (isCluster) {
  34. key.append(LimitConstants.HASH_TAG_PRFIX).append(LimitConstants.HASH_TAG).append(LimitConstants.HASH_TAG_SUFFIX);
  35. } else {
  36. key.append(LimitConstants.HASH_TAG);
  37. }
  38. String appId = SpringContextHolder.getBean(Environment.class).getProperty("app.id");
  39. key.append(appId).append(":rate:key:");
  40. key.append(checkTypeEnum.name().toLowerCase()).append(":");
  41. key.append(getCheckKey(joinPoint, checkTypeEnum));
  42. return key.toString();
  43. }
  44. /**
  45. * 根据策略类型获取 key 值
  46. *
  47. * @param joinPoint
  48. * @param checkTypeEnum
  49. * @return
  50. */
  51. public static String getCheckKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum) {
  52. StringBuffer key = new StringBuffer();
  53. /**
  54. * 第一种:所有请求统一限流:以方法名加参数列表作为唯一标识方法的key
  55. */
  56. if (CheckTypeEnum.ALL.equals(checkTypeEnum)) {
  57. MethodSignature signature = (MethodSignature) joinPoint.getSignature();
  58. key.append(signature.getMethod().getName());
  59. Class[] parameterTypes = signature.getParameterTypes();
  60. for (Class clazz : parameterTypes) {
  61. key.append(clazz.getName());
  62. }
  63. key.append(joinPoint.getTarget().getClass());
  64. }
  65. ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
  66. HttpServletRequest request = requestAttributes.getRequest();
  67. /**
  68. * 第二种:根据用户限流:以用户信息作为key
  69. */
  70. if (CheckTypeEnum.USER.equals(checkTypeEnum)) {
  71. UserVo userVo = HttpContext.getCurrentUser();
  72. if (request.getUserPrincipal() != null) {
  73. key.append(request.getUserPrincipal().getName());
  74. } else if (userVo != null) {
  75. if (StringUtils.isNotBlank(userVo.getOrganizationCode())) {
  76. key.append(userVo.getOrganizationCode()).append(":");
  77. }
  78. if (StringUtils.isNotBlank(userVo.getSceneCode())) {
  79. key.append(userVo.getSceneCode()).append(":");
  80. }
  81. if (StringUtils.isNotBlank(userVo.getIdentityCode())) {
  82. key.append(userVo.getIdentityCode()).append(":");
  83. }
  84. key.append(userVo.getId());
  85. } else {
  86. throw new LimitRateException(LimitRateErrorEnum.USER_NOT_DOUND);
  87. }
  88. }
  89. /**
  90. * 第三种:根据IP限流:以IP地址作为key
  91. */
  92. if (CheckTypeEnum.IP.equals(checkTypeEnum)) {
  93. String ip = IpUtils.getRemortIP(request);
  94. if (ip != null) {
  95. key.append(getIpAddr(request));
  96. } else {
  97. throw new LimitRateException(LimitRateErrorEnum.IP_ERROR);
  98. }
  99. }
  100. /**
  101. * 第四种:自定义限流方法:以自定义内容作为key
  102. */
  103. if (CheckTypeEnum.CUSTOM.equals(checkTypeEnum)) {
  104. if (request.getAttribute(LimitConstants.CUSTOM) != null) {
  105. key.append(request.getAttribute(LimitConstants.CUSTOM).toString());
  106. } else {
  107. throw new LimitRateException(LimitRateErrorEnum.CUSTOM_NOT_DOUND);
  108. }
  109. }
  110. return key.toString();
  111. }
  112. /**
  113. * 获取当前网络ip
  114. *
  115. * @param request HttpServletRequest
  116. * @return ip
  117. */
  118. public static String getIpAddr(HttpServletRequest request) {
  119. String ipAddress = request.getHeader("x-forwarded-for");
  120. if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
  121. ipAddress = request.getHeader("Proxy-Client-IP");
  122. }
  123. if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
  124. ipAddress = request.getHeader("WL-Proxy-Client-IP");
  125. }
  126. if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
  127. ipAddress = request.getRemoteAddr();
  128. if (ipAddress.equals("127.0.0.1") || ipAddress.equals("0:0:0:0:0:0:0:1")) {
  129. //根据网卡取本机配置的IP
  130. InetAddress inet = null;
  131. try {
  132. inet = InetAddress.getLocalHost();
  133. } catch (UnknownHostException e) {
  134. log.error("获取IP地址失败", e);
  135. }
  136. if (inet != null) {
  137. ipAddress = inet.getHostAddress();
  138. }
  139. }
  140. }
  141. if (ipAddress != null && ipAddress.length() > 15) {
  142. if (ipAddress.indexOf(",") > 0) {
  143. ipAddress = ipAddress.substring(0, ipAddress.indexOf(","));
  144. }
  145. }
  146. return ipAddress;
  147. }
  148. }

自定义异常和异常枚举类

自定义限流异常

  1. @Data
  2. public class LimitRateException extends RuntimeException{
  3. private String msg;
  4. private Integer code;
  5. public LimitRateException(LimitRateErrorEnum error) {
  6. super(error.getMsg());
  7. this.msg = error.getMsg();
  8. this.code = error.getCode();
  9. }
  10. public LimitRateException(String msg, Integer code){
  11. super(msg);
  12. this.msg = msg;
  13. this.code = code;
  14. }
  15. }

限流异常通用错误信息 - 枚举类

  /**
   * Error codes and messages raised by the rate limiter.
   * (Constant names keep the original "DOUND" spelling for compatibility.)
   */
  public enum LimitRateErrorEnum {
      /** The call exceeded the configured limit. */
      TOO_MANY_REQUESTS(50001, "personal-resource-rateLimit say: You have made too many requests,please try again later!!!"),
      /** USER strategy: no principal and no current user available. */
      USER_NOT_DOUND(50002, "personal-resource-rateLimit say: not found user info ,please check request.getUserPrincipal().getName()!!!"),
      /** CUSTOM strategy: the custom request attribute is missing. */
      CUSTOM_NOT_DOUND(50003, "personal-resource-rateLimit say: not found custom info ,please check request.getAttribute('syj-rateLimit-custom')!!!"),
      /** The caller is currently blacklisted. */
      BLACK_REQUESTS(50004, "personal-resource-rateLimit say: You are in black list"),
      /** IP strategy: the client IP could not be determined. */
      IP_ERROR(50005, "personal-resource-rateLimit say: ip error");

      private final Integer code;
      private final String msg;

      LimitRateErrorEnum(Integer errorCode, String errorMsg) {
          this.code = errorCode;
          this.msg = errorMsg;
      }

      /** @return the numeric error code */
      public Integer getCode() {
          return code;
      }

      /** @return the error message */
      public String getMsg() {
          return msg;
      }
  }

限速算法

先定义一个接口

  /**
   * Records one call against a key and rejects it once the limit for the current interval is
   * reached (the counter implementation signals rejection by throwing LimitRateException).
   */
  @FunctionalInterface
  public interface RateLimiterAlgorithm {

      /**
       * Consumes one permit for {@code key}.
       *
       * @param key             rate-limit key
       * @param limit           maximum number of calls allowed per interval
       * @param refreshInterval length of the limiting interval; the Redis counter implementation
       *                        passes it to EXPIRE, i.e. it is interpreted in seconds
       */
      void consume(String key, long limit, long refreshInterval);
  }

定义一个抽象类

  /**
   * Base class for limiter backends. Both hooks do nothing by default; a subclass overrides the
   * algorithm it implements.
   *
   * <p>NOTE(review): invoking a hook the subclass does not override silently lets the call
   * through — confirm the configured bean overrides the hook that is used.
   */
  public abstract class RateLimiter {

      /**
       * Counter-based hook.
       *
       * @param key             rate-limit key
       * @param limit           calls allowed per interval
       * @param refreshInterval interval length
       */
      public void counterConsume(String key, long limit, long refreshInterval) {
          // intentionally empty: no-op unless overridden
      }

      /**
       * Token-bucket style hook; parameter semantics are defined by the overriding subclass
       * (presumably refill amount and refill period — confirm).
       */
      public void tokenConsume(String key, long limit, long refreshInterval, long stepNum, long timeInterval) {
          // intentionally empty: no-op unless overridden
      }
  }

接口的实现类

  1. @Service
  2. @DependsOn("rateLimiter")
  3. @RequiredArgsConstructor
  4. @ConditionalOnProperty(prefix = LimitConstants.PREFIX, name = "algorithm", havingValue = "counter", matchIfMissing = true)
  5. public class CounterAlgorithmImpl implements RateLimiterAlgorithm {
  6. @NonNull
  7. private RateLimiter rateLimiter;
  8. public void consume(String key, long limit, long lrefreshInterval){
  9. rateLimiter.counterConsume(key,limit,lrefreshInterval);
  10. }
  11. }

抽象类的子类

  1. @Slf4j
  2. public class RedisRateLimiterCounterImpl extends RateLimiter {
  3. @Autowired
  4. private RedisTemplate redisTemplate;
  5. // lua 脚本
  6. private DefaultRedisScript<Long> redisScript;
  7. // 通过构造器传入 lua
  8. public RedisRateLimiterCounterImpl(DefaultRedisScript redisScript){
  9. this.redisScript = redisScript;
  10. }
  11. public void counterConsume(String key, long limit, long lrefreshInterval) throws LimitRateException {
  12. List<Object> keyList = new ArrayList<>();
  13. keyList.add(key);
  14. log.debug("限流传参:{}", JSON.toJSONString(keyList));
  15. // 参数分别为:redisScript | 参数编解码器 | 返回值编解码器 | key集合 | Object... args (在 lua 脚本中会以数组接收)
  16. String result = redisTemplate.execute(redisScript, new StringRedisSerializer(), new StringRedisSerializer(), keyList, limit+"", lrefreshInterval+"").toString();
  17. if(LimitConstants.REDIS_ERROR.equals(result)){
  18. throw new LimitRateException(LimitRateErrorEnum.TOO_MANY_REQUESTS);
  19. }
  20. }
  21. }

配置类,在容器启动时加载

  1. @Slf4j
  2. @Configuration
  3. public class EnableRateLimitConfiguration {
  4. @Bean
  5. @ConditionalOnMissingBean(name = "redisTemplate")
  6. public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory, RedisSerializer fastJson2JsonRedisSerializer){
  7. RedisTemplate<Object, Object> template = new RedisTemplate<>();
  8. template.setConnectionFactory(connectionFactory);
  9. template.setKeySerializer(new StringRedisSerializer());
  10. template.setValueSerializer(fastJson2JsonRedisSerializer);
  11. template.afterPropertiesSet();
  12. return template;
  13. }
  14. @Bean
  15. @ConditionalOnMissingBean(name = "fastJson2JsonRedisSerializer")
  16. public RedisSerializer fastJson2JsonRedisSerializer() {
  17. return new FastJsonRedisSerializer<>(Object.class);
  18. }
  19. @Bean(name = "rateLimiter")
  20. public RateLimiter tokenRateLimiter(){
  21. DefaultRedisScript<Long> consumeRedisScript = new DefaultRedisScript<>();
  22. consumeRedisScript.setResultType(Long.class);
  23. consumeRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("script/redis-ratelimiter-counter.lua")));
  24. return new RedisRateLimiterCounterImpl(consumeRedisScript);
  25. }
  26. }

lua 脚本

redis-ratelimiter-counter.lua

  -- Fixed-window rate limiter: one counter per key, the window starts at the first counted call.
  -- KEYS[1]: counter key
  -- ARGV[1]: maximum number of calls allowed in the window
  -- ARGV[2]: window length in seconds (used as the key's EXPIRE)
  -- Returns 1 when the call is allowed (and counted), -1 when the limit has been reached.
  local key = KEYS[1]
  local limit = tonumber(ARGV[1])
  local expire = tonumber(ARGV[2])

  -- GET returns false for a missing key; treat that as 0 calls so far (no separate EXISTS call).
  -- This also makes limit <= 0 reject every call instead of letting the first one through.
  local current = tonumber(redis.call('GET', key) or '0')
  if current >= limit then
      -- A counter without a TTL (e.g. persisted or set externally) would block forever; restore it.
      if redis.call('TTL', key) == -1 then
          redis.call('EXPIRE', key, expire)
      end
      return -1
  end

  redis.call('INCR', key)
  -- INCR on a missing key creates it without a TTL; that call opens the window.
  if redis.call('TTL', key) < 0 then
      redis.call('EXPIRE', key, expire)
  end
  return 1