导读
根据项目需求,需要添加跨域访问判断,只允许指定的来源访问接口。下面给出两种方式:方式一通过 CORS 响应头限制允许访问的 Origin;方式二通过校验 Referer 请求头防止跨站请求伪造(CSRF)。
使用
方式一(推荐)
配置CsrFilters过滤器
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.annotation.Order;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
/**
 * CORS filter that grants cross-origin access only to an explicit allow-list of origins.
 *
 * <p>When the request's {@code Origin} header exactly matches one of the configured origins, that
 * origin is echoed back in {@code Access-Control-Allow-Origin} together with the allowed methods,
 * headers, max-age and credentials flag. Any other origin (or a missing {@code Origin} header)
 * gets no {@code Access-Control-Allow-*} headers, so the browser will not expose the response to
 * the calling page. The request always continues down the filter chain.
 *
 * <p>Ordering: Spring treats lower {@code @Order} values as higher priority, so
 * {@code Integer.MAX_VALUE - 2} places this filter late in the chain, not early.
 * NOTE(review): {@code @Order} may not be honored for filters registered through
 * {@code @WebFilter} scanning — confirm; register via a {@code FilterRegistrationBean} with
 * {@code setOrder(...)} if the position matters.
 */
@Order(Integer.MAX_VALUE - 2) // lower values run earlier; see class Javadoc for a caveat
@WebFilter(urlPatterns = "/*", filterName = "domainFilter") // intercept every request path
public class CsrFilters implements Filter {

    /**
     * Comma-separated list of allowed origins, e.g. {@code http://localhost,http://127.0.0.1}.
     * Read from {@code allow-origin}; when that key is absent it falls back to
     * {@code req_filter_urls} (the key also used to configure the referer-based CSRF filter).
     */
    @Value("${allow-origin:${req_filter_urls}}")
    private String domain;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to do: the origin list is injected through @Value.
    }

    /**
     * Adds CORS response headers when the request comes from an allowed origin, then always
     * forwards the request down the chain.
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        HttpServletResponse response = (HttpServletResponse) res;
        String originHeader = ((HttpServletRequest) req).getHeader("Origin");

        // The Access-Control-Allow-Origin value differs per requesting origin, so any shared cache
        // must key on Origin; otherwise a cached response could be served to the wrong origin.
        response.addHeader("Vary", "Origin");

        if (originHeader != null && parseOrigins(domain).contains(originHeader)) {
            // Reflect the specific origin: "*" is not allowed together with credentials.
            response.setHeader("Access-Control-Allow-Origin", originHeader);
            response.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
            response.setHeader("Access-Control-Max-Age", "3600");
            response.setHeader("Access-Control-Allow-Headers", "content-type, x-requested-with");
            response.setHeader("Access-Control-Allow-Credentials", "true");
        }
        chain.doFilter(req, res);
    }

    @Override
    public void destroy() {
    }

    /**
     * Splits a comma-separated origin list into a set, trimming whitespace around each entry and
     * skipping blank entries. A {@code null} input yields an empty set.
     */
    private static Set<String> parseOrigins(String csv) {
        Set<String> origins = new HashSet<>();
        if (csv == null) {
            return origins;
        }
        for (String entry : csv.split(",")) {
            String trimmed = entry.trim();
            if (!trimmed.isEmpty()) {
                origins.add(trimmed);
            }
        }
        return origins;
    }
}
配置文件application.yml
#允许跨域访问的来源,多个地址用英文逗号分隔(CsrFilters 通过 ${allow-origin} 读取此配置)
allow-origin: http://localhost,http://127.0.0.1
方式二
过滤器CSRFFilter
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
 * Filter that guards against cross-site request forgery by validating the {@code Referer} header.
 *
 * <p>Init parameters:
 * <ul>
 *   <li>{@code referer} – comma-separated list of allowed referer prefixes, e.g.
 *       {@code http://localhost,http://127.0.0.1}. Entries are trimmed; blank entries are ignored.
 *       A referer matches an entry when it equals the entry, or starts with it and the next
 *       character is {@code /}, {@code :}, {@code ?} or {@code #} (or the entry itself ends with
 *       {@code /}). This prevents {@code http://localhost.evil.com} from matching
 *       {@code http://localhost}.</li>
 *   <li>{@code excludeUrlPattern} – optional comma-separated list of exact request paths (relative
 *       to the context path) that bypass the check. {@code /login.html} is always excluded.</li>
 * </ul>
 *
 * <p>Requests with a missing or non-matching {@code Referer} are rejected with HTTP 400. If no
 * allowed referers are configured, every non-excluded request is rejected.
 */
public class CSRFFilter implements Filter {

    /** Login page path; always allowed through without a referer check. */
    private static final String HOME_URL = "/login.html";

    /** Allowed referer prefixes, parsed once from the {@code referer} init parameter. */
    private List<String> allowedReferers = new ArrayList<>();

    /** Paths (relative to the context path) that skip the referer check. */
    private List<String> excludedPaths = new ArrayList<>();

    @Override
    public void destroy() {
        allowedReferers = new ArrayList<>();
        excludedPaths = new ArrayList<>();
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse resp = (HttpServletResponse) response;

        // Excluded pages (e.g. the login page) pass straight through. Returning here ensures the
        // chain is invoked only once and no error is written after the request was handled.
        if (excludedPaths.contains(pathWithinApplication(req))) {
            chain.doFilter(request, response);
            return;
        }

        String referer = req.getHeader("Referer");
        if (referer != null && isAllowedReferer(referer.trim())) {
            chain.doFilter(request, response);
            return;
        }

        System.out.println("检测到您发送的请求可能为跨站伪造请求:" + HttpServletResponse.SC_BAD_REQUEST);
        resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        allowedReferers = splitCsv(filterConfig.getInitParameter("referer"));
        List<String> excluded = splitCsv(filterConfig.getInitParameter("excludeUrlPattern"));
        if (!excluded.contains(HOME_URL)) {
            excluded.add(HOME_URL);
        }
        excludedPaths = excluded;
    }

    /** Returns the request URI with the context path removed, so it can be compared to HOME_URL. */
    private static String pathWithinApplication(HttpServletRequest req) {
        String uri = req.getRequestURI();
        String contextPath = req.getContextPath();
        if (contextPath != null && !contextPath.isEmpty() && uri.startsWith(contextPath)) {
            return uri.substring(contextPath.length());
        }
        return uri;
    }

    /** Returns true when the referer matches at least one configured allowed prefix. */
    private boolean isAllowedReferer(String referer) {
        for (String allowed : allowedReferers) {
            if (matchesPrefix(referer, allowed)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prefix match that only accepts a boundary after the allowed value, so that
     * {@code http://localhost} does not match {@code http://localhost.attacker.com}.
     */
    private static boolean matchesPrefix(String referer, String allowed) {
        if (allowed.isEmpty() || !referer.startsWith(allowed)) {
            return false;
        }
        if (referer.length() == allowed.length() || allowed.endsWith("/")) {
            return true;
        }
        char next = referer.charAt(allowed.length());
        // ':' keeps the original behaviour of allowing any port on a host-only entry.
        return next == '/' || next == ':' || next == '?' || next == '#';
    }

    /** Splits a comma-separated value into trimmed, non-blank entries; null yields an empty list. */
    private static List<String> splitCsv(String csv) {
        List<String> values = new ArrayList<>();
        if (csv == null) {
            return values;
        }
        for (String part : csv.split(",")) {
            String trimmed = part.trim();
            if (!trimmed.isEmpty()) {
                values.add(trimmed);
            }
        }
        return values;
    }
}
过滤器注册配置类 FilterConfig
import com.demo.security.filter.CSRFFilter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.DispatcherType;
/**
 * Spring configuration that registers {@link CSRFFilter} for every request path.
 *
 * <p>Note: this class shares its simple name with {@code javax.servlet.FilterConfig}; avoid
 * importing both into the same source file.
 */
@Configuration
public class FilterConfig {

    /** Comma-separated allowed referer prefixes, taken from the {@code req_filter_urls} property. */
    @Value("#{'${req_filter_urls}'}")
    private String domains;

    /**
     * Builds the registration for the referer-checking CSRF filter.
     *
     * @return an enabled registration named {@code csrfFilter}, mapped to {@code /*} for
     *     REQUEST dispatches, ordered at {@code Integer.MAX_VALUE - 2}, and carrying the
     *     {@code referer} and {@code excludeUrlPattern} init parameters
     */
    @Bean
    public FilterRegistrationBean csrfFilterRegistration() {
        // Init parameters read by the filter's init(FilterConfig).
        Map<String, String> params = new HashMap<>();
        params.put("referer", domains);
        params.put("excludeUrlPattern", "/login.html");

        FilterRegistrationBean bean = new FilterRegistrationBean();
        bean.setFilter(new CSRFFilter());
        bean.setName("csrfFilter");
        bean.addUrlPatterns("/*");
        bean.setDispatcherTypes(DispatcherType.REQUEST);
        bean.setInitParameters(params);
        bean.setOrder(Integer.MAX_VALUE - 2);
        bean.setEnabled(true);
        return bean;
    }
}
配置文件application.yml
#防止跨站请求安全地址
req_filter_urls: http://localhost,http://127.0.0.1
检查是否支持跨域
在浏览器中打开任意一个非 HTTPS 的网站(例如有道词典),按 F12 打开开发者工具,在 Console 中输入以下代码(选择非 HTTPS 网站,是为了避免 HTTPS 页面请求 http 接口时可能被浏览器的混合内容策略拦截):
(function loadXMLDoc() {
    var xmlhttp = new XMLHttpRequest();
    // Called on every state change; readyState 4 means the request has finished (or was blocked).
    xmlhttp.onreadystatechange = function () {
        if (xmlhttp.readyState === 4) {
            // A status of 0 typically indicates the browser blocked the response (e.g. CORS
            // rejection) or a network error; check the console for a CORS message.
            console.log("status:", xmlhttp.status, "response:", xmlhttp.responseText);
        }
    };
    xmlhttp.open("GET", "http://localhost/api/getList?beginTime=1990-01-01", true);
    // Send cookies. For the response to be readable, the server must reply with
    // Access-Control-Allow-Credentials: true and an explicit (non-"*") Access-Control-Allow-Origin.
    xmlhttp.withCredentials = true;
    xmlhttp.send();
})()
执行结果如下(若当前网站的来源不在允许列表中,控制台会报出 CORS 错误,页面脚本无法读取响应内容;若在允许列表中,则能正常打印出响应):
END
参考:链接
搞定~