导读


由于项目需要，这里要设计一个 IP 白名单来增强接口安全性。项目基于 Spring Boot，且接口已经开发完成；为了不改动接口代码，采用拦截器处理。思路是：拦截指定的请求，获取该请求的客户端 IP，再到配置文件中查找该 IP；如果存在且被标记为白名单，就允许访问，否则拒绝访问。

使用


创建配置文件ip_address.yml

address:
  ipFilters:
    - ip: 127.0.0.1
      whitelists: true
    - ip: 192.168.255.255
      whitelists: true

创建读取 YML 配置文件属性的配置类 PropConfig

  1. import org.springframework.beans.factory.config.YamlPropertiesFactoryBean;
  2. import org.springframework.context.annotation.Bean;
  3. import org.springframework.context.annotation.Configuration;
  4. import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
  5. import org.springframework.core.io.ClassPathResource;
  6. import java.util.Objects;
  7. /**
  8. * 读取配置文件
  9. */
  10. @Configuration
  11. public class PropConfig {
  12. @Bean
  13. public static PropertySourcesPlaceholderConfigurer properties() {
  14. PropertySourcesPlaceholderConfigurer configurer = new PropertySourcesPlaceholderConfigurer();
  15. YamlPropertiesFactoryBean yaml = new YamlPropertiesFactoryBean();
  16. yaml.setResources(new ClassPathResource("ip_address.yml"));
  17. configurer.setProperties(Objects.requireNonNull(yaml.getObject()));
  18. return configurer;
  19. }
  20. }

创建实体类 IpFilter 和 IpAddress 类

  • 实体类 IpFilter

import java.io.Serializable;

/**
 * One IP filter entry, corresponding to an item of the {@code ipFilters} list in the
 * YAML configuration (an {@code ip} plus a {@code whitelists} flag).
 */
public class IpFilter implements Serializable {

    private static final long serialVersionUID = 8802493743077425037L;

    /** The IP address this entry applies to. */
    private String ip;

    /**
     * Whitelist flag: {@code true} marks a whitelist entry, {@code false} a blacklist entry.
     * Boxed, so it is {@code null} when the value was never set.
     */
    private Boolean whitelists;

    /** @return the IP address of this entry, or {@code null} if unset */
    public String getIp() {
        return this.ip;
    }

    /** @return the whitelist flag, or {@code null} if unset */
    public Boolean getWhitelists() {
        return this.whitelists;
    }

    /** @param ip the IP address this entry applies to */
    public void setIp(String ip) {
        this.ip = ip;
    }

    /** @param whitelists {@code true} for a whitelist entry, {@code false} for a blacklist entry */
    public void setWhitelists(Boolean whitelists) {
        this.whitelists = whitelists;
    }
}

- **IpAddress 类**
  3. import org.springframework.boot.context.properties.ConfigurationProperties;
  4. import org.springframework.context.annotation.Configuration;
  5. import java.io.Serializable;
  6. import java.util.List;
  7. /**
  8. * IpAddress 存放集合属性
  9. */
  10. @Configuration
  11. @ConfigurationProperties("address")
  12. @Data
  13. public class IpAddress implements Serializable {
  14. private static final long serialVersionUID = -1686798098991604714L;
  15. private List<IpFilter> ipFilters; //ipFilters与yml文件的集合属性对应
  16. }

获取工具类IPUtils

  1. import javax.servlet.http.HttpServletRequest;
  2. /**
  3. * 获取用户的IP
  4. */
  5. public class IPUtils {
  6. /**
  7. * 本地IP localhost
  8. */
  9. private static final String NATIVEIP = "0:0:0:0:0:0:0:1";
  10. /**
  11. * 获取用户真实IP地址,不使用request.getRemoteAddr()的原因是有可能用户使用了代理软件方式避免真实IP地址,
  12. * 可是,如果通过了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP值
  13. *
  14. * @return ip
  15. */
  16. public static String getRealIP(HttpServletRequest request) {
  17. String ip = request.getHeader("x-forwarded-for");
  18. if (ip != null && ip.length() != 0 && !"unknown".equalsIgnoreCase(ip)) {
  19. // 多次反向代理后会有多个ip值,第一个ip才是真实ip
  20. if (ip.indexOf(",") != -1) {
  21. ip = ip.split(",")[0];
  22. }
  23. }
  24. if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
  25. ip = request.getHeader("Proxy-Client-IP");
  26. System.out.println("Proxy-Client-IP ip: " + ip);
  27. }
  28. if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
  29. ip = request.getHeader("WL-Proxy-Client-IP");
  30. System.out.println("WL-Proxy-Client-IP ip: " + ip);
  31. }
  32. if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
  33. ip = request.getHeader("HTTP_CLIENT_IP");
  34. System.out.println("HTTP_CLIENT_IP ip: " + ip);
  35. }
  36. if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
  37. ip = request.getHeader("HTTP_X_FORWARDED_FOR");
  38. System.out.println("HTTP_X_FORWARDED_FOR ip: " + ip);
  39. }
  40. if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
  41. ip = request.getHeader("X-Real-IP");
  42. System.out.println("X-Real-IP ip: " + ip);
  43. }
  44. if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
  45. ip = request.getRemoteAddr();
  46. System.out.println("getRemoteAddr ip: " + ip);
  47. }
  48. if (ip.equals(NATIVEIP)) {
  49. ip = "127.0.0.1";
  50. System.out.println("get native ip" + ip);
  51. }
  52. return ip;
  53. }
  54. }

拦截器IPInterceptor

  1. import com.demo.common.entity.IpAddress;
  2. import com.demo.common.entity.IpFilter;
  3. import com.demo.common.utils.IPUtils;
  4. import org.apache.commons.lang3.StringUtils;
  5. import org.slf4j.Logger;
  6. import org.slf4j.LoggerFactory;
  7. import org.springframework.beans.factory.annotation.Autowired;
  8. import org.springframework.stereotype.Component;
  9. import org.springframework.web.servlet.HandlerInterceptor;
  10. import org.springframework.web.servlet.ModelAndView;
  11. import javax.servlet.http.HttpServletRequest;
  12. import javax.servlet.http.HttpServletResponse;
  13. import java.util.List;
  14. /**
  15. * IP 拦截器
  16. *
  17. * @author hyanchao
  18. * @create 2020/1/14 14:23
  19. */
  20. @Component
  21. public class IPInterceptor implements HandlerInterceptor {
  22. private static Logger LOG = LoggerFactory.getLogger(IPInterceptor.class);
  23. @Autowired
  24. private IpAddress address;
  25. @Override
  26. public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
  27. //过滤ip,若用户在白名单内,则放行
  28. String ipAddress = IPUtils.getRealIP(request);
  29. LOG.info("USER IP ADDRESS IS =>" + ipAddress);
  30. if (!StringUtils.isNotBlank(ipAddress)) {
  31. response.getWriter().append("<h1 style=\"text-align:center;\">Not allowed!</h1>");
  32. return false;
  33. }
  34. //读取ip地址白名单集合,如果存在,并且为白名单,则放行,不然就拦截
  35. List<IpFilter> ipFilters = address.getIpFilters();
  36. if (ipFilters != null) {
  37. for (IpFilter ips : ipFilters) {
  38. if (ipAddress.equals(ips.getIp()) && ips.getWhitelists()) {
  39. return true;
  40. }
  41. }
  42. }
  43. response.getWriter().append("<h1 style=\"text-align:center;\">Not allowed!</h1>");
  44. return false;
  45. }
  46. @Override
  47. public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
  48. }
  49. @Override
  50. public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
  51. }
  52. }

全局配置拦截

  1. import com.demo.common.interceptor.IPInterceptor;
  2. import org.springframework.beans.factory.annotation.Autowired;
  3. import org.springframework.context.annotation.Configuration;
  4. import org.springframework.web.servlet.config.annotation.CorsRegistry;
  5. import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
  6. import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;
  7. @Configuration
  8. public class GlobalCorsConfig extends WebMvcConfigurerAdapter {
  9. @Autowired
  10. private IPInterceptor ipInterceptor;
  11. @Override
  12. public void addInterceptors(InterceptorRegistry registry) {
  13. // 添加拦截器,配置拦截地址,这里拦截以api请求的接口比如http://localhost/api/getUser
  14. registry.addInterceptor(ipInterceptor).addPathPatterns("/api/**");
  15. }
  16. }

END


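这里还没有做限流处理，下次考虑尝试加上限流。另外需要注意：X-Forwarded-For 等请求头可以被客户端随意伪造（例如直接发送 `X-Forwarded-For: 127.0.0.1`），只有当应用只能经由会覆盖这些请求头的反向代理访问时，基于这些请求头的 IP 白名单才可靠。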