背景
整个项目设置了统一的返回数据格式,需要修改资源服务器获取认证中心的用户信息后的解析。因为资源服务器和zuul网关是整合为一个服务的,向下游的服务转发时需要携带用户id。所以也需要修改线程保存的用户信息。
spring cloud 认证大致思路
当用户携带token访问zuul网关时会经过zuul过滤器,因为需要认证而转发到认证服务器,认证服务使用Ribbon组件访问认证中心获取用户信息(调用类为RemoteTokenServices),使用类 AccessTokenConverter(默认实现类 DefaultAccessTokenConverter)解析token,具体解析封装在类 UserAuthenticationConverter(默认实现为 DefaultUserAuthenticationConverter)中。认证通过后,将用户信息(user_name)放入 SecurityContextHolder.setContext(SecurityContext context) 所保存的 SecurityContext 的 Authentication 实例中,以供程序获取当前用户信息。
修改思路
重写 RemoteTokenServices 、 AccessTokenConverter 和 UserAuthenticationConverter的实现类。将用户id放入SecurityContext中。
MyRemoteTokenServices.java
package com.boya.web.config.oauth;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AccessTokenConverter;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.boya.config.response.HttpResponseBase;
//@Primary
//@Service
/**
 * {@link ResourceServerTokenServices} that validates an access token by POSTing it to a remote
 * check-token endpoint and converting the reply into an {@link OAuth2Authentication}.
 *
 * <p>The endpoint's reply is parsed as an {@link HttpResponseBase} envelope; the token attributes
 * are read from the envelope's {@code results} object and handed to the configured
 * {@link AccessTokenConverter} (a {@link MyAccessTokenConverter} by default).
 */
public class MyRemoteTokenServices implements ResourceServerTokenServices {

    protected final Log logger = LogFactory.getLog(getClass());

    private RestOperations restTemplate;

    @Value("${security.oauth2.resource.token-info-uri}")
    private String checkTokenEndpointUrl;

    @Value("${security.oauth2.client.client-id}")
    private String clientId;

    @Value("${security.oauth2.client.client-secret}")
    private String clientSecret;

    /** Name of the form field that carries the token in the check-token request. */
    private String tokenName = "token";

    // Use the custom AccessTokenConverter so the user id ends up as the principal.
    private AccessTokenConverter tokenConverter = new MyAccessTokenConverter();

    /**
     * Creates the service with a default {@link RestTemplate} whose error handler lets HTTP 400
     * replies through, so that their body can be inspected in {@link #loadAuthentication(String)}.
     */
    public MyRemoteTokenServices() {
        restTemplate = new RestTemplate();
        ((RestTemplate) restTemplate).setErrorHandler(new DefaultResponseErrorHandler() {
            @Override
            // Ignore 400
            public void handleError(ClientHttpResponse response) throws IOException {
                if (response.getRawStatusCode() != 400) {
                    super.handleError(response);
                }
            }
        });
    }

    public void setRestTemplate(RestOperations restTemplate) {
        this.restTemplate = restTemplate;
    }

    public void setCheckTokenEndpointUrl(String checkTokenEndpointUrl) {
        this.checkTokenEndpointUrl = checkTokenEndpointUrl;
    }

    public void setClientId(String clientId) {
        this.clientId = clientId;
    }

    public void setClientSecret(String clientSecret) {
        this.clientSecret = clientSecret;
    }

    public void setAccessTokenConverter(AccessTokenConverter accessTokenConverter) {
        this.tokenConverter = accessTokenConverter;
    }

    public void setTokenName(String tokenName) {
        this.tokenName = tokenName;
    }

    /**
     * Checks the token against the remote endpoint and builds the authentication from the reply.
     *
     * @param accessToken the access token value to validate
     * @return the authentication extracted from the {@code results} object of the reply
     * @throws InvalidTokenException if the reply is empty, carries the code {@code invalid_token},
     *         has no {@code results} object, contains an {@code error} entry, or its
     *         {@code active} attribute is not {@code true}
     */
    @Override
    public OAuth2Authentication loadAuthentication(String accessToken)
            throws AuthenticationException, InvalidTokenException {
        MultiValueMap<String, String> formData = new LinkedMultiValueMap<String, String>();
        formData.add(tokenName, accessToken);
        HttpHeaders headers = new HttpHeaders();
        headers.set("Authorization", getAuthorizationHeader(clientId, clientSecret));

        Map<String, Object> map = postForMap(checkTokenEndpointUrl, formData, headers);
        if (logger.isDebugEnabled()) {
            logger.debug("check_token returned result: " + JSON.toJSONString(map));
        }
        if (map == null) {
            // An empty body cannot prove the token is valid.
            throw new InvalidTokenException(accessToken);
        }

        HttpResponseBase httpResponse = JSON.parseObject(JSON.toJSONString(map), HttpResponseBase.class);
        if (httpResponse == null) {
            throw new InvalidTokenException(accessToken);
        }
        Object results = httpResponse.getResults();
        // Anything other than a JSON object (including null) is treated as "no token data".
        JSONObject rsJson = results instanceof JSONObject ? (JSONObject) results : null;

        if ("invalid_token".equalsIgnoreCase(httpResponse.getCode())
                || (rsJson != null && rsJson.containsKey("error"))) {
            if (logger.isDebugEnabled()) {
                logger.debug("check_token returned error: "
                        + (rsJson != null ? rsJson.get("error") : httpResponse.getCode()));
            }
            throw new InvalidTokenException(accessToken);
        }

        if (rsJson == null || !Boolean.TRUE.equals(rsJson.get("active"))) {
            if (logger.isDebugEnabled()) {
                logger.debug("check_token returned active attribute: "
                        + (rsJson != null ? rsJson.get("active") : null));
            }
            throw new InvalidTokenException(accessToken);
        }

        return tokenConverter.extractAuthentication(rsJson);
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public OAuth2AccessToken readAccessToken(String accessToken) {
        throw new UnsupportedOperationException("Not supported: read access token");
    }

    /** Builds the HTTP Basic {@code Authorization} header value from the client credentials. */
    private String getAuthorizationHeader(String clientId, String clientSecret) {
        if (clientId == null || clientSecret == null) {
            logger.warn(
                    "Null Client ID or Client Secret detected. Endpoint that requires authentication will reject request with 401 error.");
        }
        String creds = String.format("%s:%s", clientId, clientSecret);
        // Using the Charset constant avoids the checked UnsupportedEncodingException entirely.
        return "Basic " + java.util.Base64.getEncoder()
                .encodeToString(creds.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    /**
     * POSTs the form to {@code path} and returns the response body as a map.
     * Defaults the content type to {@code application/x-www-form-urlencoded} when none is set.
     */
    private Map<String, Object> postForMap(String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
        if (headers.getContentType() == null) {
            headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        @SuppressWarnings("rawtypes")
        Map map = restTemplate.exchange(path, HttpMethod.POST,
                new HttpEntity<MultiValueMap<String, String>>(formData, headers), Map.class).getBody();
        @SuppressWarnings("unchecked")
        Map<String, Object> result = map;
        return result;
    }
}
MyAccessTokenConverter.java
package com.boya.web.config.oauth;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.AccessTokenConverter;
import org.springframework.security.oauth2.provider.token.UserAuthenticationConverter;
/**
* Default implementation of {@link AccessTokenConverter}.
*
* @author Dave Syer
* @author Vedran Pavic
*/
public class MyAccessTokenConverter implements AccessTokenConverter {
//使用自定义UserAuthenticationConverter
private UserAuthenticationConverter userTokenConverter = new MyUserAuthenticationConverter();
private boolean includeGrantType;
private String scopeAttribute = SCOPE;
private String clientIdAttribute = CLIENT_ID;
/**
* Converter for the part of the data in the token representing a user.
*
* @param userTokenConverter the userTokenConverter to set
*/
public void setUserTokenConverter(UserAuthenticationConverter userTokenConverter) {
this.userTokenConverter = userTokenConverter;
}
/**
* Flag to indicate the the grant type should be included in the converted token.
*
* @param includeGrantType the flag value (default false)
*/
public void setIncludeGrantType(boolean includeGrantType) {
this.includeGrantType = includeGrantType;
}
/**
* Set scope attribute name to be used in the converted token. Defaults to
* {@link AccessTokenConverter#SCOPE}.
*
* @param scopeAttribute the scope attribute name to use
*/
public void setScopeAttribute(String scopeAttribute) {
this.scopeAttribute = scopeAttribute;
}
/**
* Set client id attribute name to be used in the converted token. Defaults to
* {@link AccessTokenConverter#CLIENT_ID}.
*
* @param clientIdAttribute the client id attribute name to use
*/
public void setClientIdAttribute(String clientIdAttribute) {
this.clientIdAttribute = clientIdAttribute;
}
public Map<String, ?> convertAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
Map<String, Object> response = new HashMap<String, Object>();
OAuth2Request clientToken = authentication.getOAuth2Request();
if (!authentication.isClientOnly()) {
response.putAll(userTokenConverter.convertUserAuthentication(authentication.getUserAuthentication()));
} else {
if (clientToken.getAuthorities()!=null && !clientToken.getAuthorities().isEmpty()) {
response.put(UserAuthenticationConverter.AUTHORITIES,
AuthorityUtils.authorityListToSet(clientToken.getAuthorities()));
}
}
if (token.getScope()!=null) {
response.put(scopeAttribute, token.getScope());
}
if (token.getAdditionalInformation().containsKey(JTI)) {
response.put(JTI, token.getAdditionalInformation().get(JTI));
}
if (token.getExpiration() != null) {
response.put(EXP, token.getExpiration().getTime() / 1000);
}
if (includeGrantType && authentication.getOAuth2Request().getGrantType()!=null) {
response.put(GRANT_TYPE, authentication.getOAuth2Request().getGrantType());
}
response.putAll(token.getAdditionalInformation());
response.put(clientIdAttribute, clientToken.getClientId());
if (clientToken.getResourceIds() != null && !clientToken.getResourceIds().isEmpty()) {
response.put(AUD, clientToken.getResourceIds());
}
return response;
}
public OAuth2AccessToken extractAccessToken(String value, Map<String, ?> map) {
DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(value);
Map<String, Object> info = new HashMap<String, Object>(map);
info.remove(EXP);
info.remove(AUD);
info.remove(clientIdAttribute);
info.remove(scopeAttribute);
if (map.containsKey(EXP)) {
token.setExpiration(new Date((Long) map.get(EXP) * 1000L));
}
if (map.containsKey(JTI)) {
info.put(JTI, map.get(JTI));
}
token.setScope(extractScope(map));
token.setAdditionalInformation(info);
return token;
}
public OAuth2Authentication extractAuthentication(Map<String, ?> map) {
Map<String, String> parameters = new HashMap<String, String>();
Set<String> scope = extractScope(map);
Authentication user = userTokenConverter.extractAuthentication(map);
String clientId = (String) map.get(clientIdAttribute);
parameters.put(clientIdAttribute, clientId);
if (includeGrantType && map.containsKey(GRANT_TYPE)) {
parameters.put(GRANT_TYPE, (String) map.get(GRANT_TYPE));
}
Set<String> resourceIds = new LinkedHashSet<String>(map.containsKey(AUD) ? getAudience(map)
: Collections.<String>emptySet());
Collection<? extends GrantedAuthority> authorities = null;
if (user==null && map.containsKey(AUTHORITIES)) {
@SuppressWarnings("unchecked")
String[] roles = ((Collection<String>)map.get(AUTHORITIES)).toArray(new String[0]);
authorities = AuthorityUtils.createAuthorityList(roles);
}
OAuth2Request request = new OAuth2Request(parameters, clientId, authorities, true, scope, resourceIds, null, null,
null);
return new OAuth2Authentication(request, user);
}
private Collection<String> getAudience(Map<String, ?> map) {
Object auds = map.get(AUD);
if (auds instanceof Collection) {
@SuppressWarnings("unchecked")
Collection<String> result = (Collection<String>) auds;
return result;
}
return Collections.singleton((String)auds);
}
private Set<String> extractScope(Map<String, ?> map) {
Set<String> scope = Collections.emptySet();
if (map.containsKey(scopeAttribute)) {
Object scopeObj = map.get(scopeAttribute);
if (String.class.isInstance(scopeObj)) {
scope = new LinkedHashSet<String>(Arrays.asList(String.class.cast(scopeObj).split(" ")));
} else if (Collection.class.isAssignableFrom(scopeObj.getClass())) {
@SuppressWarnings("unchecked")
Collection<String> scopeColl = (Collection<String>) scopeObj;
scope = new LinkedHashSet<String>(scopeColl); // Preserve ordering
}
}
return scope;
}
}
MyUserAuthenticationConverter.java
package com.boya.web.config.oauth;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.oauth2.provider.token.UserAuthenticationConverter;
import org.springframework.util.StringUtils;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* Default implementation of {@link UserAuthenticationConverter}. Converts to and from an Authentication using only its
* name and authorities.
*
* @author Dave Syer
*
*/
public class MyUserAuthenticationConverter implements UserAuthenticationConverter {
private Collection<? extends GrantedAuthority> defaultAuthorities;
private UserDetailsService userDetailsService;
//新增获取参数
private String USERID = "userId";
/**
* Optional {@link UserDetailsService} to use when extracting an {@link Authentication} from the incoming map.
*
* @param userDetailsService the userDetailsService to set
*/
public void setUserDetailsService(UserDetailsService userDetailsService) {
this.userDetailsService = userDetailsService;
}
/**
* Default value for authorities if an Authentication is being created and the input has no data for authorities.
* Note that unless this property is set, the default Authentication created by {@link #extractAuthentication(Map)}
* will be unauthenticated.
*
* @param defaultAuthorities the defaultAuthorities to set. Default null.
*/
public void setDefaultAuthorities(String[] defaultAuthorities) {
this.defaultAuthorities = AuthorityUtils.commaSeparatedStringToAuthorityList(StringUtils
.arrayToCommaDelimitedString(defaultAuthorities));
}
public Map<String, ?> convertUserAuthentication(Authentication authentication) {
Map<String, Object> response = new LinkedHashMap<String, Object>();
response.put(USERNAME, authentication.getName());
if (authentication.getAuthorities() != null && !authentication.getAuthorities().isEmpty()) {
response.put(AUTHORITIES, AuthorityUtils.authorityListToSet(authentication.getAuthorities()));
}
return response;
}
public Authentication extractAuthentication(Map<String, ?> map) {
//新增先获取userId
if (map.containsKey(USERID)) {
Object principal = map.get(USERID);
Collection<? extends GrantedAuthority> authorities = getAuthorities(map);
if (userDetailsService != null) {
UserDetails user = userDetailsService.loadUserByUsername((String) map.get(USERID));
authorities = user.getAuthorities();
principal = user;
}
return new UsernamePasswordAuthenticationToken(principal, "N/A", authorities);
}
if (map.containsKey(USERNAME)) {
Object principal = map.get(USERNAME);
Collection<? extends GrantedAuthority> authorities = getAuthorities(map);
if (userDetailsService != null) {
UserDetails user = userDetailsService.loadUserByUsername((String) map.get(USERNAME));
authorities = user.getAuthorities();
principal = user;
}
return new UsernamePasswordAuthenticationToken(principal, "N/A", authorities);
}
return null;
}
private Collection<? extends GrantedAuthority> getAuthorities(Map<String, ?> map) {
if (!map.containsKey(AUTHORITIES)) {
return defaultAuthorities;
}
Object authorities = map.get(AUTHORITIES);
if (authorities instanceof String) {
return AuthorityUtils.commaSeparatedStringToAuthorityList((String) authorities);
}
if (authorities instanceof Collection) {
return AuthorityUtils.commaSeparatedStringToAuthorityList(StringUtils
.collectionToCommaDelimitedString((Collection<?>) authorities));
}
throw new IllegalArgumentException("Authorities must be either a String or a Collection");
}
}
zuul网关下游服务调用时修改
ModifyHeaderFilter.java
package com.boya.web.config;
import com.alibaba.fastjson.JSON;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import javax.servlet.http.HttpServletRequest;
import java.util.*;
@Component
@Slf4j
public class ModifyHeaderFilter extends ZuulFilter {
//定义filter的类型,有pre、route、post、error四种
@Override
public String filterType() {
return "route";
}
//定义filter的顺序,数字越小表示顺序越高,越先执行
@Override
public int filterOrder() {
return 6;
}
//表示是否需要执行该filter,true表示执行,false表示不执行
@Override
public boolean shouldFilter() {
return true;
}
public Object run() {
System.out.println("------------------------------------------------OAuth2AccessToken---");
SecurityContext context = SecurityContextHolder.getContext();
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
String currentPrincipalName = authentication.getName();
System.out.println("当前用户-------"+currentPrincipalName);
RequestContext ctx = RequestContext.getCurrentContext();
HttpServletRequest request = ctx.getRequest();
log.info(String.format("%s >>> %s", request.getMethod(), request.getRequestURL().toString()));
Enumeration<String> headerNames = request.getHeaderNames();
Map<String, Object> headMap = new HashMap<>();
String pa = null;
while ((pa = headerNames.nextElement()) != null) {
headMap.put(pa, request.getHeader(pa));
}
log.info(String.format("header 参数 >>> %s", JSON.toJSONString(headMap)));
Map<String, String[]> requestMap = request.getParameterMap();
log.info(String.format("request 参数 >>> %s", JSON.toJSONString(requestMap)));
Map<String, List<String>> requestQueryParams = ctx.getRequestQueryParams();
if (requestQueryParams==null) {
requestQueryParams=new HashMap<>();
}
//将要新增的参数添加进去,被调用的微服务可以直接 去取,就想普通的一样,框架会直接注入进去
ArrayList<String> arrayList = new ArrayList<>();
arrayList.add(currentPrincipalName);
requestQueryParams.put("userId", arrayList);
// 这个put("requestQueryParams", qp); 是在源码中 会在转发的时候去取这个key里面的参数值.
ctx.setRequestQueryParams(requestQueryParams);
return null;
}
}