【目标】通过 **自定义注解 + lua 脚本 + 多种维度** 实现接口限流。
自定义限流注解
/**
 * Enables rate limiting on the annotated method.
 *
 * <p>Annotated methods are intercepted by an aspect bound with {@code @annotation(methodRateLimit)}.
 * The values below are defaults that a runtime configuration may override.
 *
 * <p>Note: {@link Inherited} only has an effect for annotations placed on classes; on a
 * method-level annotation such as this one it does nothing.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Inherited
public @interface MethodRateLimit {

    /**
     * Dimension along which calls are counted.
     *
     * @return the limiting strategy; {@link CheckTypeEnum#ALL} unless specified
     *         (alternatives: IP, USER, CUSTOM)
     */
    CheckTypeEnum checkType() default CheckTypeEnum.ALL;

    /**
     * @return how many calls are admitted per window; 60 unless specified
     */
    long limit() default 60;

    /**
     * @return the length of the limiting window in seconds; 60 unless specified
     */
    long refreshInterval() default 60;

    /**
     * @return how long, in minutes, a caller stays blacklisted after exceeding the limit;
     *         10 unless specified
     */
    long blackExpire() default 10;
}
CheckTypeEnum 是限流策略的枚举类:
/**
 * Dimension along which a rate limit is applied. Declaration order is significant (ordinal).
 */
public enum CheckTypeEnum {

    /** One shared budget for the method, e.g. at most n calls per window in total. */
    ALL,

    /** Separate budget per client IP, e.g. at most n calls per window from one address. */
    IP,

    /** Separate budget per user, e.g. at most n calls per window by one user. */
    USER,

    /** Budget keyed by an application-supplied value. */
    CUSTOM;
}
定义切面:拦截带有 @MethodRateLimit 注解的方法,并判断是否需要限流
@Slf4j@Aspect@Componentpublic class MethodAnnotationAspect {private static final Logger BLACK_LOG = LoggerFactory.getLogger("BLACK_LOG");@Autowiredprivate RateLimiterAlgorithm rateLimiterAlgorithm;@Autowiredprivate RedisTemplate redisTemplate;@Autowiredprivate PropResource propResource;@Pointcut("@annotation(methodRateLimit)")public void annotationPointcut(MethodRateLimit methodRateLimit) {}@Before("annotationPointcut(methodRateLimit)")public void doBefore(JoinPoint joinPoint, MethodRateLimit methodRateLimit) {String blackKey = null;String key = null;// 从配置中心获取限流的配置Map<String, Object> config = getRateLimitConfig();UserVo user = HttpContext.getCurrentUser();// 默认的配置long limit = methodRateLimit.limit();long refreshInterval = methodRateLimit.refreshInterval();long blackExpire = methodRateLimit.blackExpire();CheckTypeEnum checkTypeEnum = methodRateLimit.checkType();try {if (config.containsKey("limit") && (Long) config.get("limit") > 0) {limit = (Long) config.get("limit");}if (config.containsKey("refreshInterval") && (Long) config.get("refreshInterval") > 0) {refreshInterval = (Long) config.get("refreshInterval");}// 获取黑名单的 redis key「propResource.getRedisIsCluster() 是否是集群,集群模式要指定 slot 的 hash tag」@moreblackKey = RateLimiterUtil.getBlackKey(joinPoint, checkTypeEnum, propResource.getRedisIsCluster());if (redisTemplate.hasKey(blackKey)) {log.info("{} 在黑名单上", blackKey);throw new LimitRateException(LimitRateErrorEnum.BLACK_REQUESTS);}// 从 redis 获取限流 keykey = RateLimiterUtil.getRateKey(joinPoint, checkTypeEnum, propResource.getRedisIsCluster());log.debug("限流key:{}", key);// 通过 lua 脚本判定是否限流和更新 redis 操作 @morerateLimiterAlgorithm.consume(key, limit, refreshInterval * 60);} catch (LimitRateException e) {if (LimitRateErrorEnum.TOO_MANY_REQUESTS.getCode().equals(e.getCode()) && blackKey != null) {log.info("{} 加入黑名单", blackKey);if (config.containsKey("blackExpire") && config.get("blackExpire") != null && (long) config.get("blackExpire") > 0) {blackExpire = (long) 
config.get("blackExpire");}// 设置黑名单redisTemplate.opsForValue().set(blackKey, 1, blackExpire, TimeUnit.MINUTES);// 创建日志Object blackLog = createBlackLog(checkTypeEnum, blackKey, key, user, joinPoint);if (blackLog != null) {BLACK_LOG.info(JSON.toJSONStringWithDateFormat(blackLog, "yyyy-MM-dd HH:mm:ss"));}throw e;} else if (LimitRateErrorEnum.BLACK_REQUESTS.getCode() == e.getCode()) {log.info("{} 在黑名单上", blackKey);throw e;}} catch (Exception e) {log.error("限流错误", e);}}public Object createBlackLog(CheckTypeEnum checkTypeEnum, String checkKey, String blackKey, UserVo user, JoinPoint joinPoint) {try {LogVo logVo = new LogVo();Map<String, Object> blackMap = new HashMap<>();blackMap.put("createTime", new Date());blackMap.put("checkType", checkTypeEnum.name());blackMap.put("rateKey", checkKey);blackMap.put("blackKey", blackKey);blackMap.put("checkKey", RateLimiterUtil.getCheckKey(joinPoint, checkTypeEnum));if (user != null) {blackMap.put("organizationCode", user.getOrganizationCode());}logVo.setSystemName("cloud-search-black");logVo.setSuccess(true);logVo.setCode(Result.DEFAULT_SUCCESS_CODE);logVo.setTraceId(StringUtils.getId());ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();HttpServletRequest request = requestAttributes.getRequest();logVo.setIp(IpUtils.getRemortIP(request));logVo.setExtendedParam(blackMap);logVo.setStartTime(new Date());logVo.setEndTime(new Date());return logVo;} catch (Exception e) {log.error("生成日志失败");return null;}}public Map getRateLimitConfig() {Map<String, Object> map = new HashMap<>();map.put("refreshInterval", 60L);map.put("limit", 3L);map.put("blackExpire", 10L);return map;}}
根据不同策略获取 redis key
@Slf4jpublic class RateLimiterUtil {/*** 获取黑名单** @param joinPoint* @param checkTypeEnum 策略类型* @param isCluster 是否是集群模式* @return*/public static String getBlackKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster) {StringBuffer key = new StringBuffer();if (isCluster) {key.append(LimitConstants.HASH_TAG_PRFIX).append(LimitConstants.HASH_TAG).append(LimitConstants.HASH_TAG_SUFFIX);} else {key.append(LimitConstants.HASH_TAG);}String appId = SpringContextHolder.getBean(Environment.class).getProperty("app.id");key.append(appId).append(":black:");key.append(checkTypeEnum.name().toLowerCase()).append(":");key.append(getCheckKey(joinPoint, checkTypeEnum));return key.toString();}/*** 获取唯一标识此次请求的key** @param joinPoint 切点* @param checkTypeEnum 枚举* @return key*/public static String getRateKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster) {StringBuffer key = new StringBuffer();if (isCluster) {key.append(LimitConstants.HASH_TAG_PRFIX).append(LimitConstants.HASH_TAG).append(LimitConstants.HASH_TAG_SUFFIX);} else {key.append(LimitConstants.HASH_TAG);}String appId = SpringContextHolder.getBean(Environment.class).getProperty("app.id");key.append(appId).append(":rate:key:");key.append(checkTypeEnum.name().toLowerCase()).append(":");key.append(getCheckKey(joinPoint, checkTypeEnum));return key.toString();}/*** 根据策略类型获取 key 值** @param joinPoint* @param checkTypeEnum* @return*/public static String getCheckKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum) {StringBuffer key = new StringBuffer();/*** 第一种:所有请求统一限流:以方法名加参数列表作为唯一标识方法的key*/if (CheckTypeEnum.ALL.equals(checkTypeEnum)) {MethodSignature signature = (MethodSignature) joinPoint.getSignature();key.append(signature.getMethod().getName());Class[] parameterTypes = signature.getParameterTypes();for (Class clazz : parameterTypes) {key.append(clazz.getName());}key.append(joinPoint.getTarget().getClass());}ServletRequestAttributes requestAttributes = (ServletRequestAttributes) 
RequestContextHolder.getRequestAttributes();HttpServletRequest request = requestAttributes.getRequest();/*** 第二种:根据用户限流:以用户信息作为key*/if (CheckTypeEnum.USER.equals(checkTypeEnum)) {UserVo userVo = HttpContext.getCurrentUser();if (request.getUserPrincipal() != null) {key.append(request.getUserPrincipal().getName());} else if (userVo != null) {if (StringUtils.isNotBlank(userVo.getOrganizationCode())) {key.append(userVo.getOrganizationCode()).append(":");}if (StringUtils.isNotBlank(userVo.getSceneCode())) {key.append(userVo.getSceneCode()).append(":");}if (StringUtils.isNotBlank(userVo.getIdentityCode())) {key.append(userVo.getIdentityCode()).append(":");}key.append(userVo.getId());} else {throw new LimitRateException(LimitRateErrorEnum.USER_NOT_DOUND);}}/*** 第三种:根据IP限流:以IP地址作为key*/if (CheckTypeEnum.IP.equals(checkTypeEnum)) {String ip = IpUtils.getRemortIP(request);if (ip != null) {key.append(getIpAddr(request));} else {throw new LimitRateException(LimitRateErrorEnum.IP_ERROR);}}/*** 第四种:自定义限流方法:以自定义内容作为key*/if (CheckTypeEnum.CUSTOM.equals(checkTypeEnum)) {if (request.getAttribute(LimitConstants.CUSTOM) != null) {key.append(request.getAttribute(LimitConstants.CUSTOM).toString());} else {throw new LimitRateException(LimitRateErrorEnum.CUSTOM_NOT_DOUND);}}return key.toString();}/*** 获取当前网络ip** @param request HttpServletRequest* @return ip*/public static String getIpAddr(HttpServletRequest request) {String ipAddress = request.getHeader("x-forwarded-for");if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {ipAddress = request.getHeader("Proxy-Client-IP");}if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {ipAddress = request.getHeader("WL-Proxy-Client-IP");}if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {ipAddress = request.getRemoteAddr();if (ipAddress.equals("127.0.0.1") || ipAddress.equals("0:0:0:0:0:0:0:1")) {//根据网卡取本机配置的IPInetAddress inet = 
null;try {inet = InetAddress.getLocalHost();} catch (UnknownHostException e) {log.error("获取IP地址失败", e);}if (inet != null) {ipAddress = inet.getHostAddress();}}}if (ipAddress != null && ipAddress.length() > 15) {if (ipAddress.indexOf(",") > 0) {ipAddress = ipAddress.substring(0, ipAddress.indexOf(","));}}return ipAddress;}}
自定义异常和异常枚举类
自定义限流异常
@Datapublic class LimitRateException extends RuntimeException{private String msg;private Integer code;public LimitRateException(LimitRateErrorEnum error) {super(error.getMsg());this.msg = error.getMsg();this.code = error.getCode();}public LimitRateException(String msg, Integer code){super(msg);this.msg = msg;this.code = code;}}
限流异常通用错误信息 - 枚举类
/**
 * Error codes (50001-50005) and messages used by the rate limiter.
 *
 * <p>Note: {@code USER_NOT_DOUND} and {@code CUSTOM_NOT_DOUND} contain a typo ("DOUND" for
 * "FOUND"); the names are kept because callers reference them.
 */
public enum LimitRateErrorEnum {

    /** The caller exceeded its request budget. */
    TOO_MANY_REQUESTS(50001, "personal-resource-rateLimit say: You have made too many requests,please try again later!!!"),

    /** USER strategy: no principal or user information was available. */
    USER_NOT_DOUND(50002, "personal-resource-rateLimit say: not found user info ,please check request.getUserPrincipal().getName()!!!"),

    /** CUSTOM strategy: the custom request attribute was missing. */
    CUSTOM_NOT_DOUND(50003, "personal-resource-rateLimit say: not found custom info ,please check request.getAttribute('syj-rateLimit-custom')!!!"),

    /** The caller is currently blacklisted. */
    BLACK_REQUESTS(50004, "personal-resource-rateLimit say: You are in black list"),

    /** IP strategy: the client address could not be determined. */
    IP_ERROR(50005, "personal-resource-rateLimit say: ip error");

    private final Integer code;
    private final String msg;

    LimitRateErrorEnum(Integer code, String msg) {
        this.code = code;
        this.msg = msg;
    }

    /** @return the numeric error code */
    public Integer getCode() {
        return code;
    }

    /** @return the human-readable message */
    public String getMsg() {
        return msg;
    }
}
限流算法
先定义一个接口
/**
 * Strategy that consumes one permit for a rate-limit key.
 *
 * <p>Implementations reject a call by throwing; the counter implementation delegates to a
 * limiter that throws {@code LimitRateException(TOO_MANY_REQUESTS)}.
 */
public interface RateLimiterAlgorithm {

    /**
     * Consumes one permit for {@code key}.
     *
     * @param key             the rate-limit key
     * @param limit           maximum number of calls per interval
     * @param refreshInterval length of the limiting interval; the Redis counter implementation
     *                        uses it as the key's TTL in seconds
     */
    void consume(String key, long limit, long refreshInterval);
}
定义一个抽象类
/**
 * Base type for limiter implementations.
 *
 * <p>Both operations are no-ops here. A subclass overrides the one it supports (the Redis
 * counter implementation overrides {@link #counterConsume}). Calling an operation that the
 * concrete subclass does not override therefore silently does nothing.
 */
public abstract class RateLimiter {

    /**
     * Fixed-window counter consumption. Does nothing unless overridden.
     *
     * @param key              the limiter key
     * @param limit            maximum calls per window
     * @param lrefreshInterval window length
     */
    public void counterConsume(String key, long limit, long lrefreshInterval) {
        // Intentionally empty: counter-based subclasses override this.
    }

    /**
     * Token-bucket consumption. Does nothing unless overridden.
     *
     * @param key                     the limiter key
     * @param limit                   bucket capacity (presumably; TODO confirm)
     * @param lrefreshInterval        window length
     * @param tokenBucketStepNum      presumably tokens added per refill; TODO confirm
     * @param tokenBucketTimeInterval presumably the time between refills; TODO confirm
     */
    public void tokenConsume(String key, long limit, long lrefreshInterval, long tokenBucketStepNum, long tokenBucketTimeInterval) {
        // Intentionally empty: token-bucket subclasses would override this.
    }
}
接口的实现类
@Service@DependsOn("rateLimiter")@RequiredArgsConstructor@ConditionalOnProperty(prefix = LimitConstants.PREFIX, name = "algorithm", havingValue = "counter", matchIfMissing = true)public class CounterAlgorithmImpl implements RateLimiterAlgorithm {@NonNullprivate RateLimiter rateLimiter;public void consume(String key, long limit, long lrefreshInterval){rateLimiter.counterConsume(key,limit,lrefreshInterval);}}
抽象类的子类
@Slf4jpublic class RedisRateLimiterCounterImpl extends RateLimiter {@Autowiredprivate RedisTemplate redisTemplate;// lua 脚本private DefaultRedisScript<Long> redisScript;// 通过构造器传入 luapublic RedisRateLimiterCounterImpl(DefaultRedisScript redisScript){this.redisScript = redisScript;}public void counterConsume(String key, long limit, long lrefreshInterval) throws LimitRateException {List<Object> keyList = new ArrayList<>();keyList.add(key);log.debug("限流传参:{}", JSON.toJSONString(keyList));// 参数分别为:redisScript | 参数编解码器 | 返回值编解码器 | key集合 | Object... args (在 lua 脚本中会以数组接收)String result = redisTemplate.execute(redisScript, new StringRedisSerializer(), new StringRedisSerializer(), keyList, limit+"", lrefreshInterval+"").toString();if(LimitConstants.REDIS_ERROR.equals(result)){throw new LimitRateException(LimitRateErrorEnum.TOO_MANY_REQUESTS);}}}
配置类,在容器启动时加载
/**
 * Spring configuration wiring the Redis-backed rate limiter.
 *
 * <p>Registers fallback {@code redisTemplate} (String keys, FastJSON values) and
 * {@code fastJson2JsonRedisSerializer} beans when the application does not define them. It also
 * registers the {@code rateLimiter} bean driven by the counter Lua script on the classpath.
 *
 * <p>NOTE(review): {@code @ConditionalOnMissingBean} only sees bean definitions processed so far.
 * Spring Boot recommends using it on auto-configuration classes only; confirm how this class
 * is registered.
 */
@Slf4j
@Configuration
public class EnableRateLimitConfiguration {

    /** Classpath location of the fixed-window counter Lua script. */
    private static final String COUNTER_SCRIPT_LOCATION = "script/redis-ratelimiter-counter.lua";

    /**
     * Fallback template with String keys and FastJSON-serialized values.
     *
     * @param connectionFactory            Redis connection factory
     * @param fastJson2JsonRedisSerializer value serializer
     * @return an initialized template
     */
    @Bean
    @ConditionalOnMissingBean(name = "redisTemplate")
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory, RedisSerializer fastJson2JsonRedisSerializer) {
        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setValueSerializer(fastJson2JsonRedisSerializer);
        redisTemplate.setConnectionFactory(connectionFactory);
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Fallback FastJSON value serializer.
     *
     * @return a serializer for arbitrary objects
     */
    @Bean
    @ConditionalOnMissingBean(name = "fastJson2JsonRedisSerializer")
    public RedisSerializer fastJson2JsonRedisSerializer() {
        return new FastJsonRedisSerializer<>(Object.class);
    }

    /**
     * The {@code rateLimiter} bean: a counter limiter running the script at
     * {@code script/redis-ratelimiter-counter.lua} with a {@code Long} result type.
     * NOTE(review): despite the method name, this returns the counter implementation, not a
     * token bucket.
     *
     * @return the counter rate limiter
     */
    @Bean(name = "rateLimiter")
    public RateLimiter tokenRateLimiter() {
        DefaultRedisScript<Long> counterScript = new DefaultRedisScript<>();
        counterScript.setScriptSource(new ResourceScriptSource(new ClassPathResource(COUNTER_SCRIPT_LOCATION)));
        counterScript.setResultType(Long.class);
        return new RedisRateLimiterCounterImpl(counterScript);
    }
}
lua 脚本
redis-ratelimiter-counter.lua(放在 classpath 的 script/ 目录下,由配置类中的 ClassPathResource("script/redis-ratelimiter-counter.lua") 加载)
-- Fixed-window counter rate limiter.
-- KEYS[1]: counter key for the caller
-- ARGV[1]: maximum number of admitted requests per window
-- ARGV[2]: window length, applied to the key with EXPIRE (seconds)
-- Returns 1 when the request is admitted (and counted), -1 when the limit is reached.
-- Rejected requests are not counted.
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire = tonumber(ARGV[2])

-- GET yields false for a missing key; treat that as zero requests so far.
-- Checking unconditionally (not only when the key exists) makes limit <= 0 reject every request.
local current = tonumber(redis.call('GET', key) or '0')
if current >= limit then
    return -1
end

redis.call('INCR', key)

-- TTL < 0 means the key has no expiry yet (freshly created by INCR): start the window now.
if redis.call('TTL', key) < 0 then
    redis.call('EXPIRE', key, expire)
end

return 1
