【目标】通过 **自定义注解 + Lua 脚本 + 多种限流维度(全局 / IP / 用户 / 自定义)** 实现接口限流。
自定义限流注解
/**
 * Marks a method as rate limited.
 *
 * <p>The attributes describe which dimension the limit is counted on, how many
 * calls are allowed, the length of the counting window, and how long an
 * offender stays on the blacklist.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface MethodRateLimit {

    /**
     * Dimension the limit is applied on.
     *
     * @return one of {@code ALL}, {@code IP}, {@code USER}, {@code CUSTOM}; defaults to {@code ALL}
     */
    CheckTypeEnum checkType() default CheckTypeEnum.ALL;

    /**
     * Maximum number of calls allowed per window.
     *
     * @return the call limit; defaults to 60
     */
    long limit() default 60L;

    /**
     * Length of the counting window.
     *
     * @return the window length in seconds; defaults to 60
     */
    long refreshInterval() default 60L;

    /**
     * How long a caller stays blacklisted.
     *
     * @return the blacklist duration in minutes; defaults to 10
     */
    long blackExpire() default 10L;
}
`CheckTypeEnum` 是限流策略(限流维度)的枚举类:
/**
 * Dimension on which a rate limit is counted.
 */
public enum CheckTypeEnum {

    /** One shared counter for every caller, e.g. "this method may be called n times per minute". */
    ALL,

    /** One counter per client IP, e.g. "this IP may call this method n times per minute". */
    IP,

    /** One counter per user, e.g. "this user may call this method n times per minute". */
    USER,

    /** Counter keyed by a caller-supplied custom value. */
    CUSTOM
}
定义切面进行拦截和判断限流
/**
 * Aspect that enforces {@link MethodRateLimit} before the annotated method runs.
 *
 * <p>Flow: reject callers whose blacklist key exists in Redis, otherwise ask the
 * configured {@link RateLimiterAlgorithm} to consume one call. When the algorithm
 * reports {@code TOO_MANY_REQUESTS} the caller is blacklisted, an audit entry is
 * written to the {@code BLACK_LOG} logger and the exception is rethrown.
 *
 * <p>Any other failure (Redis unavailable, key could not be derived, ...) is logged
 * and the call is allowed through, i.e. the limiter fails open.
 */
@Slf4j
@Aspect
@Component
public class MethodAnnotationAspect {

    private static final Logger BLACK_LOG = LoggerFactory.getLogger("BLACK_LOG");

    @Autowired
    private RateLimiterAlgorithm rateLimiterAlgorithm;

    @Autowired
    private RedisTemplate redisTemplate;

    @Autowired
    private PropResource propResource;

    /** Matches every method carrying {@link MethodRateLimit} and binds the annotation instance. */
    @Pointcut("@annotation(methodRateLimit)")
    public void annotationPointcut(MethodRateLimit methodRateLimit) {
    }

    /**
     * Applies the blacklist check and the rate limit to the intercepted call.
     *
     * @param joinPoint       the intercepted call
     * @param methodRateLimit the annotation on the intercepted method
     * @throws LimitRateException with code {@code BLACK_REQUESTS} when the caller is blacklisted,
     *                            or {@code TOO_MANY_REQUESTS} when the limit was just exceeded
     */
    @Before("annotationPointcut(methodRateLimit)")
    public void doBefore(JoinPoint joinPoint, MethodRateLimit methodRateLimit) {
        String blackKey = null;
        String key = null;
        // Overrides for the annotation values (see getRateLimitConfig).
        Map<String, Object> config = getRateLimitConfig();
        UserVo user = HttpContext.getCurrentUser();
        CheckTypeEnum checkTypeEnum = methodRateLimit.checkType();
        try {
            long limit = positiveLong(config, "limit", methodRateLimit.limit());
            long refreshInterval = positiveLong(config, "refreshInterval", methodRateLimit.refreshInterval());

            // In cluster mode the key carries a hash tag so related keys land in one slot.
            blackKey = RateLimiterUtil.getBlackKey(joinPoint, checkTypeEnum, propResource.getRedisIsCluster());
            // hasKey returns a Boolean; compare explicitly rather than auto-unboxing.
            if (Boolean.TRUE.equals(redisTemplate.hasKey(blackKey))) {
                throw new LimitRateException(LimitRateErrorEnum.BLACK_REQUESTS);
            }

            key = RateLimiterUtil.getRateKey(joinPoint, checkTypeEnum, propResource.getRedisIsCluster());
            log.debug("限流key:{}", key);
            // refreshInterval is already in seconds (see MethodRateLimit#refreshInterval), so it is
            // passed through unchanged; the previous "* 60" turned a 60 s window into an hour.
            rateLimiterAlgorithm.consume(key, limit, refreshInterval);
        } catch (LimitRateException e) {
            if (LimitRateErrorEnum.TOO_MANY_REQUESTS.getCode().equals(e.getCode()) && blackKey != null) {
                long blackExpire = positiveLong(config, "blackExpire", methodRateLimit.blackExpire());
                blacklist(checkTypeEnum, key, blackKey, blackExpire, user, joinPoint);
                throw e;
            }
            // equals(), not ==: the codes are boxed Integers.
            if (LimitRateErrorEnum.BLACK_REQUESTS.getCode().equals(e.getCode())) {
                log.info("{} 在黑名单上", blackKey);
                throw e;
            }
            // Key could not be derived (no user / ip / custom attribute): keep failing open, but say so.
            log.warn("无法确定限流标识,本次请求未限流: {}", e.getMessage());
        } catch (Exception e) {
            log.error("限流错误", e);
        }
    }

    /**
     * Builds the audit entry written when a caller is blacklisted.
     *
     * @param checkTypeEnum limit dimension
     * @param checkKey      Redis key of the rate counter (stored as {@code rateKey})
     * @param blackKey      Redis key of the blacklist entry
     * @param user          current user, may be {@code null}
     * @param joinPoint     the intercepted call
     * @return the populated {@code LogVo}, or {@code null} if it could not be built
     */
    public Object createBlackLog(CheckTypeEnum checkTypeEnum, String checkKey, String blackKey, UserVo user, JoinPoint joinPoint) {
        try {
            Date now = new Date();
            Map<String, Object> blackMap = new HashMap<>();
            blackMap.put("createTime", now);
            blackMap.put("checkType", checkTypeEnum.name());
            blackMap.put("rateKey", checkKey);
            blackMap.put("blackKey", blackKey);
            blackMap.put("checkKey", RateLimiterUtil.getCheckKey(joinPoint, checkTypeEnum));
            if (user != null) {
                blackMap.put("organizationCode", user.getOrganizationCode());
            }

            LogVo logVo = new LogVo();
            logVo.setSystemName("cloud-search-black");
            logVo.setSuccess(true);
            logVo.setCode(Result.DEFAULT_SUCCESS_CODE);
            logVo.setTraceId(StringUtils.getId());
            ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            // No request bound to this thread: still emit the entry, just without an IP.
            if (requestAttributes != null) {
                logVo.setIp(IpUtils.getRemortIP(requestAttributes.getRequest()));
            }
            logVo.setExtendedParam(blackMap);
            logVo.setStartTime(now);
            logVo.setEndTime(now);
            return logVo;
        } catch (Exception e) {
            log.error("生成日志失败", e);
            return null;
        }
    }

    /**
     * Supplies the override values for limit, refreshInterval and blackExpire.
     *
     * <p>The values here are hard-coded; any positive entry takes precedence over the
     * corresponding {@link MethodRateLimit} attribute.
     *
     * @return a mutable map with the keys {@code refreshInterval}, {@code limit}, {@code blackExpire}
     */
    public Map<String, Object> getRateLimitConfig() {
        Map<String, Object> map = new HashMap<>();
        map.put("refreshInterval", 60L);
        map.put("limit", 3L);
        map.put("blackExpire", 10L);
        return map;
    }

    /** Writes the blacklist entry and its audit log; a failure here must not mask the rate-limit rejection. */
    private void blacklist(CheckTypeEnum checkTypeEnum, String rateKey, String blackKey, long blackExpire, UserVo user, JoinPoint joinPoint) {
        try {
            log.info("{} 加入黑名单", blackKey);
            redisTemplate.opsForValue().set(blackKey, 1, blackExpire, TimeUnit.MINUTES);
            Object blackLog = createBlackLog(checkTypeEnum, rateKey, blackKey, user, joinPoint);
            if (blackLog != null) {
                BLACK_LOG.info(JSON.toJSONStringWithDateFormat(blackLog, "yyyy-MM-dd HH:mm:ss"));
            }
        } catch (RuntimeException ex) {
            log.error("{} 写入黑名单失败", blackKey, ex);
        }
    }

    /** Returns the config entry as a long when it is a positive number, otherwise {@code fallback}. */
    private static long positiveLong(Map<String, Object> config, String name, long fallback) {
        Object raw = config == null ? null : config.get(name);
        if (raw instanceof Number) {
            long value = ((Number) raw).longValue();
            if (value > 0) {
                return value;
            }
        }
        return fallback;
    }
}
根据不同策略获取 redis key
/**
 * Builds the Redis keys used by the rate limiter for each {@link CheckTypeEnum} dimension.
 */
@Slf4j
public class RateLimiterUtil {

    /**
     * Builds the Redis key of the blacklist entry for the current call.
     *
     * @param joinPoint     the intercepted call
     * @param checkTypeEnum limit dimension
     * @param isCluster     whether Redis runs in cluster mode (the key then carries a hash tag)
     * @return the blacklist key
     * @throws LimitRateException if the identifying part of the key cannot be derived
     */
    public static String getBlackKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster) {
        return buildKey(joinPoint, checkTypeEnum, isCluster, ":black:");
    }

    /**
     * Builds the Redis key of the rate counter for the current call.
     *
     * @param joinPoint     the intercepted call
     * @param checkTypeEnum limit dimension
     * @param isCluster     whether Redis runs in cluster mode (the key then carries a hash tag)
     * @return the counter key
     * @throws LimitRateException if the identifying part of the key cannot be derived
     */
    public static String getRateKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster) {
        return buildKey(joinPoint, checkTypeEnum, isCluster, ":rate:key:");
    }

    /** Shared layout: [hash tag][app.id][segment][dimension]:[check key]. */
    private static String buildKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum, boolean isCluster, String segment) {
        StringBuilder key = new StringBuilder();
        if (isCluster) {
            key.append(LimitConstants.HASH_TAG_PRFIX).append(LimitConstants.HASH_TAG).append(LimitConstants.HASH_TAG_SUFFIX);
        } else {
            key.append(LimitConstants.HASH_TAG);
        }
        String appId = SpringContextHolder.getBean(Environment.class).getProperty("app.id");
        key.append(appId).append(segment);
        // Locale.ROOT keeps the key stable regardless of the JVM default locale (e.g. Turkish "IP" -> "ıp").
        key.append(checkTypeEnum.name().toLowerCase(java.util.Locale.ROOT)).append(":");
        key.append(getCheckKey(joinPoint, checkTypeEnum));
        return key.toString();
    }

    /**
     * Derives the part of the key that identifies "who" is being limited.
     *
     * @param joinPoint     the intercepted call
     * @param checkTypeEnum limit dimension
     * @return method signature, user identity, client IP or custom value, depending on the dimension
     * @throws LimitRateException    if the user, IP or custom attribute is missing
     * @throws IllegalStateException if a request-based dimension is used outside an HTTP request
     */
    public static String getCheckKey(JoinPoint joinPoint, CheckTypeEnum checkTypeEnum) {
        StringBuilder key = new StringBuilder();

        // ALL: method name + parameter types + target class; needs no HTTP request.
        if (CheckTypeEnum.ALL == checkTypeEnum) {
            MethodSignature signature = (MethodSignature) joinPoint.getSignature();
            key.append(signature.getMethod().getName());
            for (Class<?> parameterType : signature.getParameterTypes()) {
                key.append(parameterType.getName());
            }
            key.append(joinPoint.getTarget().getClass());
            return key.toString();
        }

        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (requestAttributes == null) {
            throw new IllegalStateException("No current HTTP request is bound to this thread");
        }
        HttpServletRequest request = requestAttributes.getRequest();

        // USER: servlet principal first, otherwise the composite identity of the current user.
        if (CheckTypeEnum.USER == checkTypeEnum) {
            UserVo userVo = HttpContext.getCurrentUser();
            if (request.getUserPrincipal() != null) {
                key.append(request.getUserPrincipal().getName());
            } else if (userVo != null) {
                if (StringUtils.isNotBlank(userVo.getOrganizationCode())) {
                    key.append(userVo.getOrganizationCode()).append(":");
                }
                if (StringUtils.isNotBlank(userVo.getSceneCode())) {
                    key.append(userVo.getSceneCode()).append(":");
                }
                if (StringUtils.isNotBlank(userVo.getIdentityCode())) {
                    key.append(userVo.getIdentityCode()).append(":");
                }
                key.append(userVo.getId());
            } else {
                throw new LimitRateException(LimitRateErrorEnum.USER_NOT_DOUND);
            }
        }

        // IP: client address.
        if (CheckTypeEnum.IP == checkTypeEnum) {
            // NOTE(review): the original gated on IpUtils.getRemortIP but keyed on getIpAddr; both are
            // kept so existing keys and error conditions do not change. Consider settling on one.
            String ip = getIpAddr(request);
            if (IpUtils.getRemortIP(request) == null || ip == null || ip.isEmpty()) {
                throw new LimitRateException(LimitRateErrorEnum.IP_ERROR);
            }
            key.append(ip);
        }

        // CUSTOM: value placed on the request by the application.
        if (CheckTypeEnum.CUSTOM == checkTypeEnum) {
            Object custom = request.getAttribute(LimitConstants.CUSTOM);
            if (custom == null) {
                throw new LimitRateException(LimitRateErrorEnum.CUSTOM_NOT_DOUND);
            }
            key.append(custom.toString());
        }
        return key.toString();
    }

    /**
     * Resolves the client IP of the request.
     *
     * <p>Checks {@code x-forwarded-for}, {@code Proxy-Client-IP} and {@code WL-Proxy-Client-IP}
     * in that order, then falls back to the remote address (replaced by the local host address
     * when it is loopback). When the value is a comma-separated list, the first entry is returned.
     *
     * <p>SECURITY: these headers are supplied by the client unless a trusted proxy overwrites
     * them, so an IP-based limit keyed on this value can be evaded by forging the header.
     *
     * @param request current HTTP request
     * @return the resolved IP, or {@code null} if none could be determined
     */
    public static String getIpAddr(HttpServletRequest request) {
        String ipAddress = request.getHeader("x-forwarded-for");
        if (isUnknown(ipAddress)) {
            ipAddress = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknown(ipAddress)) {
            ipAddress = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknown(ipAddress)) {
            ipAddress = request.getRemoteAddr();
            // Constant-first equals: getRemoteAddr() may return null.
            if ("127.0.0.1".equals(ipAddress) || "0:0:0:0:0:0:0:1".equals(ipAddress)) {
                // Loopback: use the address configured on the local host instead.
                try {
                    ipAddress = InetAddress.getLocalHost().getHostAddress();
                } catch (UnknownHostException e) {
                    log.error("获取IP地址失败", e);
                }
            }
        }
        if (ipAddress != null) {
            // Always take the first entry; the old "length > 15" guard missed short lists
            // such as "1.1.1.1,2.2.2.2" (exactly 15 characters).
            int comma = ipAddress.indexOf(',');
            if (comma > 0) {
                ipAddress = ipAddress.substring(0, comma);
            }
            ipAddress = ipAddress.trim();
        }
        return ipAddress;
    }

    /** True when a header value is absent, empty or the literal "unknown". */
    private static boolean isUnknown(String value) {
        return value == null || value.length() == 0 || "unknown".equalsIgnoreCase(value);
    }
}
自定义异常和异常枚举类
自定义限流异常
/**
 * Unchecked exception raised by the rate limiter; carries a numeric code and a message.
 */
@Data
public class LimitRateException extends RuntimeException {

    private String msg;
    private Integer code;

    /**
     * Creates the exception from a predefined error.
     *
     * @param error source of the message and code
     */
    public LimitRateException(LimitRateErrorEnum error) {
        this(error.getMsg(), error.getCode());
    }

    /**
     * Creates the exception from an explicit message and code.
     *
     * @param msg  message, also used as the exception message
     * @param code error code
     */
    public LimitRateException(String msg, Integer code) {
        super(msg);
        this.code = code;
        this.msg = msg;
    }
}
限流异常通用错误信息 - 枚举类
/**
 * Error codes and messages used by {@code LimitRateException}.
 */
public enum LimitRateErrorEnum {

    /** The call limit for the current window has been reached. */
    TOO_MANY_REQUESTS(50001, "personal-resource-rateLimit say: You have made too many requests,please try again later!!!"),

    /** No user information was available for a per-user limit. */
    USER_NOT_DOUND(50002, "personal-resource-rateLimit say: not found user info ,please check request.getUserPrincipal().getName()!!!"),

    /** No custom attribute was available for a custom limit. */
    CUSTOM_NOT_DOUND(50003, "personal-resource-rateLimit say: not found custom info ,please check request.getAttribute('syj-rateLimit-custom')!!!"),

    /** The caller is on the blacklist. */
    BLACK_REQUESTS(50004, "personal-resource-rateLimit say: You are in black list"),

    /** The client IP could not be determined. */
    IP_ERROR(50005, "personal-resource-rateLimit say: ip error");

    private final Integer code;
    private final String msg;

    LimitRateErrorEnum(Integer code, String msg) {
        this.code = code;
        this.msg = msg;
    }

    /** @return the numeric error code */
    public Integer getCode() {
        return code;
    }

    /** @return the error message */
    public String getMsg() {
        return msg;
    }
}
限速算法
先定义一个接口
/**
 * Strategy that decides whether one more call identified by a key is allowed.
 */
public interface RateLimiterAlgorithm {

    /**
     * Consumes one call for the given key.
     *
     * @param key             identifier of the counter
     * @param limit           maximum number of calls
     * @param refreshInterval length of the limiting interval
     */
    void consume(String key, long limit, long refreshInterval);
}
定义一个抽象类
/**
 * Base class for rate limiter back ends. Both operations do nothing by default;
 * a subclass overrides the one it supports.
 */
public abstract class RateLimiter {

    /**
     * Counter-style consume. No-op in this base class.
     *
     * @param key              identifier of the counter
     * @param limit            maximum number of calls
     * @param lrefreshInterval length of the limiting interval
     */
    public void counterConsume(String key, long limit, long lrefreshInterval) {
        // intentionally empty
    }

    /**
     * Token-bucket-style consume. No-op in this base class.
     *
     * @param key                     identifier of the bucket
     * @param limit                   maximum number of calls
     * @param lrefreshInterval        length of the limiting interval
     * @param tokenBucketStepNum      token bucket step number
     * @param tokenBucketTimeInterval token bucket time interval
     */
    public void tokenConsume(String key, long limit, long lrefreshInterval, long tokenBucketStepNum, long tokenBucketTimeInterval) {
        // intentionally empty
    }
}
接口的实现类
@Service
@DependsOn("rateLimiter")
@RequiredArgsConstructor
@ConditionalOnProperty(prefix = LimitConstants.PREFIX, name = "algorithm", havingValue = "counter", matchIfMissing = true)
public class CounterAlgorithmImpl implements RateLimiterAlgorithm {
@NonNull
private RateLimiter rateLimiter;
public void consume(String key, long limit, long lrefreshInterval){
rateLimiter.counterConsume(key,limit,lrefreshInterval);
}
}
抽象类的子类
@Slf4j
public class RedisRateLimiterCounterImpl extends RateLimiter {
@Autowired
private RedisTemplate redisTemplate;
// lua 脚本
private DefaultRedisScript<Long> redisScript;
// 通过构造器传入 lua
public RedisRateLimiterCounterImpl(DefaultRedisScript redisScript){
this.redisScript = redisScript;
}
public void counterConsume(String key, long limit, long lrefreshInterval) throws LimitRateException {
List<Object> keyList = new ArrayList<>();
keyList.add(key);
log.debug("限流传参:{}", JSON.toJSONString(keyList));
// 参数分别为:redisScript | 参数编解码器 | 返回值编解码器 | key集合 | Object... args (在 lua 脚本中会以数组接收)
String result = redisTemplate.execute(redisScript, new StringRedisSerializer(), new StringRedisSerializer(), keyList, limit+"", lrefreshInterval+"").toString();
if(LimitConstants.REDIS_ERROR.equals(result)){
throw new LimitRateException(LimitRateErrorEnum.TOO_MANY_REQUESTS);
}
}
}
配置类,在容器启动时加载
/**
 * Spring configuration that registers the Redis template, its value serializer
 * and the Lua-script-backed {@link RateLimiter}.
 */
@Slf4j
@Configuration
public class EnableRateLimitConfiguration {

    /**
     * Redis template with string keys and the given value serializer; only registered
     * when no bean named {@code redisTemplate} exists.
     */
    @Bean
    @ConditionalOnMissingBean(name = "redisTemplate")
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory, RedisSerializer fastJson2JsonRedisSerializer) {
        RedisTemplate<Object, Object> redis = new RedisTemplate<>();
        redis.setConnectionFactory(connectionFactory);
        redis.setKeySerializer(new StringRedisSerializer());
        redis.setValueSerializer(fastJson2JsonRedisSerializer);
        redis.afterPropertiesSet();
        return redis;
    }

    /**
     * Fastjson-based value serializer; only registered when no bean named
     * {@code fastJson2JsonRedisSerializer} exists.
     */
    @Bean
    @ConditionalOnMissingBean(name = "fastJson2JsonRedisSerializer")
    public RedisSerializer fastJson2JsonRedisSerializer() {
        return new FastJsonRedisSerializer<>(Object.class);
    }

    /**
     * The {@code rateLimiter} bean: a counter limiter driven by the Lua script loaded
     * from {@code script/redis-ratelimiter-counter.lua} on the classpath.
     */
    @Bean(name = "rateLimiter")
    public RateLimiter tokenRateLimiter() {
        ClassPathResource scriptFile = new ClassPathResource("script/redis-ratelimiter-counter.lua");
        DefaultRedisScript<Long> counterScript = new DefaultRedisScript<>();
        counterScript.setResultType(Long.class);
        counterScript.setScriptSource(new ResourceScriptSource(scriptFile));
        return new RedisRateLimiterCounterImpl(counterScript);
    }
}
Lua 脚本
redis-ratelimiter-counter.lua(放在 classpath 的 `script/` 目录下,与配置类中 `ClassPathResource` 的路径一致)
-- Counter-based rate limit check.
--   KEYS[1]  counter key
--   ARGV[1]  maximum number of calls
--   ARGV[2]  expiry applied to the counter, in seconds (Redis EXPIRE)
-- Returns -1 when the counter has already reached the maximum, otherwise 1.
local counterKey = KEYS[1]
local maxCalls = tonumber(ARGV[1])
local expireSeconds = tonumber(ARGV[2])

-- Reject without incrementing when an existing counter is already at the maximum.
if redis.call('EXISTS', counterKey) == 1 then
    local current = tonumber(redis.call('GET', counterKey))
    if current >= maxCalls then
        return -1
    end
end

-- Count this call.
redis.call('INCR', counterKey)

-- A negative TTL means the key has no expiry yet; set one.
if redis.call('TTL', counterKey) < 0 then
    redis.call('EXPIRE', counterKey, expireSeconds)
end

return 1