Polish gh-77
This commit is contained in:
parent
54e219a397
commit
fbc98d511c
@ -15,6 +15,7 @@
|
||||
*/
|
||||
package org.springframework.security.oauth2.server.authorization;
|
||||
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.List;
|
||||
@ -65,7 +66,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
|
||||
|
||||
private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) {
|
||||
if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
|
||||
return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue()));
|
||||
return token.equals(authorization.getAttributes().get(OAuth2ParameterNames.class.getName().concat(".CODE")));
|
||||
} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
|
||||
return authorization.getAccessToken() != null &&
|
||||
authorization.getAccessToken().getTokenValue().equals(token);
|
||||
|
@ -16,6 +16,7 @@
|
||||
package org.springframework.security.oauth2.server.authorization;
|
||||
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
@ -196,7 +197,7 @@ public class OAuth2Authorization implements Serializable {
|
||||
*/
|
||||
public OAuth2Authorization build() {
|
||||
Assert.hasText(this.principalName, "principalName cannot be empty");
|
||||
Assert.notNull(this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()), "authorization code cannot be null");
|
||||
Assert.notNull(this.attributes.get(OAuth2ParameterNames.class.getName().concat(".CODE")), "authorization code cannot be null");
|
||||
|
||||
OAuth2Authorization authorization = new OAuth2Authorization();
|
||||
authorization.registeredClientId = this.registeredClientId;
|
||||
|
@ -15,29 +15,21 @@
|
||||
*/
|
||||
package org.springframework.security.oauth2.server.authorization.web;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
|
||||
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||
import org.springframework.security.oauth2.server.authorization.TokenType;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||
import org.springframework.security.web.DefaultRedirectStrategy;
|
||||
@ -45,201 +37,257 @@ import org.springframework.security.web.RedirectStrategy;
|
||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
|
||||
* which handles the processing of the OAuth 2.0 Authorization Request.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @author Paurav Munshi
|
||||
* @since 0.0.1
|
||||
* @see RegisteredClientRepository
|
||||
* @see OAuth2AuthorizationService
|
||||
* @see OAuth2Authorization
|
||||
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
|
||||
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Request</a>
|
||||
*/
|
||||
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
|
||||
/**
|
||||
* The default endpoint {@code URI} for authorization requests.
|
||||
*/
|
||||
public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
|
||||
|
||||
private static final String DEFAULT_ENDPOINT = "/oauth2/authorize";
|
||||
|
||||
private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter = new OAuth2AuthorizationRequestConverter();
|
||||
private RegisteredClientRepository registeredClientRepository;
|
||||
private OAuth2AuthorizationService authorizationService;
|
||||
private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
|
||||
private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
|
||||
private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
|
||||
private final RegisteredClientRepository registeredClientRepository;
|
||||
private final OAuth2AuthorizationService authorizationService;
|
||||
private final RequestMatcher authorizationEndpointMatcher;
|
||||
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
|
||||
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
|
||||
*
|
||||
* @param registeredClientRepository the repository of registered clients
|
||||
* @param authorizationService the authorization service
|
||||
*/
|
||||
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
|
||||
OAuth2AuthorizationService authorizationService) {
|
||||
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null.");
|
||||
Assert.notNull(authorizationService, "authorizationService cannot be null.");
|
||||
this(registeredClientRepository, authorizationService, DEFAULT_AUTHORIZATION_ENDPOINT_URI);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
|
||||
*
|
||||
* @param registeredClientRepository the repository of registered clients
|
||||
* @param authorizationService the authorization service
|
||||
* @param authorizationEndpointUri the endpoint {@code URI} for authorization requests
|
||||
*/
|
||||
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
|
||||
OAuth2AuthorizationService authorizationService, String authorizationEndpointUri) {
|
||||
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
|
||||
Assert.notNull(authorizationService, "authorizationService cannot be null");
|
||||
Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty");
|
||||
this.registeredClientRepository = registeredClientRepository;
|
||||
this.authorizationService = authorizationService;
|
||||
}
|
||||
|
||||
public final void setAuthorizationRequestConverter(
|
||||
Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter) {
|
||||
Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
|
||||
this.authorizationRequestConverter = authorizationRequestConverter;
|
||||
}
|
||||
|
||||
public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
|
||||
Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
|
||||
this.codeGenerator = codeGenerator;
|
||||
}
|
||||
|
||||
public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
|
||||
Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
|
||||
this.authorizationRedirectStrategy = authorizationRedirectStrategy;
|
||||
}
|
||||
|
||||
public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
|
||||
Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
|
||||
this.authorizationEndpointMatcher = authorizationEndpointMatcher;
|
||||
this.authorizationEndpointMatcher = new AntPathRequestMatcher(
|
||||
authorizationEndpointUri, HttpMethod.GET.name());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
||||
boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
|
||||
String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
|
||||
boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
|
||||
if (pathMatch && responseTypeMatch) {
|
||||
return false;
|
||||
}else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doFilterInternal(HttpServletRequest request,
|
||||
HttpServletResponse response, FilterChain filterChain)
|
||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||
throws ServletException, IOException {
|
||||
|
||||
RegisteredClient client = null;
|
||||
OAuth2AuthorizationRequest authorizationRequest = null;
|
||||
OAuth2Authorization authorization = null;
|
||||
if (!this.authorizationEndpointMatcher.matches(request) || !isPrincipalAuthenticated()) {
|
||||
filterChain.doFilter(request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
checkUserAuthenticated();
|
||||
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
|
||||
client = fetchRegisteredClient(request);
|
||||
// TODO
|
||||
// The authorization server validates the request to ensure that all
|
||||
// required parameters are present and valid. If the request is valid,
|
||||
// the authorization server authenticates the resource owner and obtains
|
||||
// an authorization decision (by asking the resource owner or by
|
||||
// establishing approval via other means).
|
||||
|
||||
authorizationRequest = this.authorizationRequestConverter.convert(request);
|
||||
validateAuthorizationRequest(authorizationRequest, client);
|
||||
MultiValueMap<String, String> parameters = getParameters(request);
|
||||
String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
|
||||
|
||||
// client_id (REQUIRED)
|
||||
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
|
||||
if (!StringUtils.hasText(clientId) ||
|
||||
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
|
||||
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
|
||||
return;
|
||||
}
|
||||
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
|
||||
if (registeredClient == null) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
|
||||
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
|
||||
return;
|
||||
} else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID);
|
||||
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
|
||||
return;
|
||||
}
|
||||
|
||||
// redirect_uri (OPTIONAL)
|
||||
String redirectUriParameter = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
|
||||
if (StringUtils.hasText(redirectUriParameter)) {
|
||||
if (!registeredClient.getRedirectUris().contains(redirectUriParameter) ||
|
||||
parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
|
||||
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
|
||||
return;
|
||||
}
|
||||
} else if (registeredClient.getRedirectUris().size() != 1) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
|
||||
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
|
||||
return;
|
||||
}
|
||||
|
||||
String redirectUri = StringUtils.hasText(redirectUriParameter) ?
|
||||
redirectUriParameter : registeredClient.getRedirectUris().iterator().next();
|
||||
|
||||
// response_type (REQUIRED)
|
||||
String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
|
||||
if (!StringUtils.hasText(responseType) ||
|
||||
parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);
|
||||
sendErrorResponse(request, response, error, stateParameter, redirectUri);
|
||||
return;
|
||||
} else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) {
|
||||
OAuth2Error error = createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE);
|
||||
sendErrorResponse(request, response, error, stateParameter, redirectUri);
|
||||
return;
|
||||
}
|
||||
|
||||
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
|
||||
String code = this.codeGenerator.generateKey();
|
||||
authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code);
|
||||
this.authorizationService.save(authorization);
|
||||
OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request);
|
||||
|
||||
String redirectUri = getRedirectUri(authorizationRequest, client);
|
||||
sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code);
|
||||
}
|
||||
catch(OAuth2AuthorizationException authorizationException) {
|
||||
OAuth2Error authorizationError = authorizationException.getError();
|
||||
|
||||
if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
|
||||
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
|
||||
sendErrorInResponse(response, authorizationError);
|
||||
}
|
||||
else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
|
||||
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
|
||||
String redirectUri = getRedirectUri(authorizationRequest, client);
|
||||
sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
|
||||
}
|
||||
else {
|
||||
throw new ServletException(authorizationException);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void checkUserAuthenticated() {
|
||||
Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (currentAuth==null || !currentAuth.isAuthenticated()) {
|
||||
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
|
||||
}
|
||||
}
|
||||
|
||||
private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
|
||||
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
|
||||
if (StringUtils.isEmpty(clientId)) {
|
||||
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
||||
}
|
||||
|
||||
RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
|
||||
if (client==null) {
|
||||
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
|
||||
}
|
||||
|
||||
boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
|
||||
.anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
|
||||
if (!isAuthorizationGrantAllowed) {
|
||||
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
|
||||
}
|
||||
|
||||
return client;
|
||||
|
||||
}
|
||||
|
||||
private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
|
||||
OAuth2AuthorizationRequest authorizationRequest, String code) {
|
||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
|
||||
.principalName(auth.getPrincipal().toString())
|
||||
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
|
||||
.attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
|
||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
|
||||
.principalName(principal.getName())
|
||||
.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), code)
|
||||
.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
|
||||
.build();
|
||||
|
||||
return authorization;
|
||||
this.authorizationService.save(authorization);
|
||||
|
||||
// TODO security checks for code parameter
|
||||
// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
|
||||
// A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
|
||||
// The client MUST NOT use the authorization code more than once.
|
||||
// If an authorization code is used more than once, the authorization server MUST deny the request
|
||||
// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
|
||||
// The authorization code is bound to the client identifier and redirection URI.
|
||||
|
||||
sendAuthorizationResponse(request, response, authorizationRequest, code, redirectUri);
|
||||
}
|
||||
|
||||
private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
|
||||
OAuth2AuthorizationRequest authorizationRequest, String code, String redirectUri) throws IOException {
|
||||
|
||||
private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
|
||||
String redirectUri = authorizationRequest.getRedirectUri();
|
||||
if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
|
||||
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
||||
}
|
||||
if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
|
||||
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
|
||||
return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
|
||||
? authorizationRequest.getRedirectUri()
|
||||
: client.getRedirectUris().stream().findFirst().get();
|
||||
}
|
||||
|
||||
private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
|
||||
OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
|
||||
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
|
||||
UriComponentsBuilder uriBuilder = UriComponentsBuilder
|
||||
.fromUriString(redirectUri)
|
||||
.queryParam(OAuth2ParameterNames.CODE, code);
|
||||
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
|
||||
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
|
||||
if (StringUtils.hasText(authorizationRequest.getState())) {
|
||||
uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
|
||||
}
|
||||
this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
|
||||
}
|
||||
|
||||
String finalRedirectUri = redirectUriBuilder.toUriString();
|
||||
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri);
|
||||
private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
|
||||
OAuth2Error error, String state, String redirectUri) throws IOException {
|
||||
|
||||
if (redirectUri == null) {
|
||||
// TODO Send default html error response
|
||||
response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString());
|
||||
return;
|
||||
}
|
||||
|
||||
private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException {
|
||||
int errorStatus = -1;
|
||||
String errorCode = authorizationError.getErrorCode();
|
||||
if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
|
||||
errorStatus=HttpStatus.FORBIDDEN.value();
|
||||
UriComponentsBuilder uriBuilder = UriComponentsBuilder
|
||||
.fromUriString(redirectUri)
|
||||
.queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
|
||||
if (StringUtils.hasText(error.getDescription())) {
|
||||
uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
|
||||
}
|
||||
else {
|
||||
errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value();
|
||||
if (StringUtils.hasText(error.getUri())) {
|
||||
uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
|
||||
}
|
||||
response.sendError(errorStatus, authorizationError.getErrorCode());
|
||||
if (StringUtils.hasText(state)) {
|
||||
uriBuilder.queryParam(OAuth2ParameterNames.STATE, state);
|
||||
}
|
||||
this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
|
||||
}
|
||||
|
||||
private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response,
|
||||
OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError,
|
||||
String redirectUri) throws IOException {
|
||||
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
|
||||
.queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
|
||||
|
||||
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
|
||||
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
|
||||
private static boolean isPrincipalAuthenticated() {
|
||||
return isPrincipalAuthenticated(SecurityContextHolder.getContext().getAuthentication());
|
||||
}
|
||||
|
||||
String finalRedirectURI = redirectUriBuilder.toUriString();
|
||||
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI);
|
||||
private static boolean isPrincipalAuthenticated(Authentication principal) {
|
||||
return principal != null &&
|
||||
!AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&
|
||||
principal.isAuthenticated();
|
||||
}
|
||||
|
||||
private static OAuth2Error createError(String errorCode, String parameterName) {
|
||||
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
|
||||
"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
|
||||
}
|
||||
|
||||
private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
|
||||
MultiValueMap<String, String> parameters = getParameters(request);
|
||||
|
||||
Set<String> scopes = Collections.emptySet();
|
||||
if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
|
||||
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
|
||||
scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
|
||||
}
|
||||
|
||||
return OAuth2AuthorizationRequest.authorizationCode()
|
||||
.authorizationUri(request.getRequestURL().toString())
|
||||
.clientId(parameters.getFirst(OAuth2ParameterNames.CLIENT_ID))
|
||||
.redirectUri(parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI))
|
||||
.scopes(scopes)
|
||||
.state(parameters.getFirst(OAuth2ParameterNames.STATE))
|
||||
.additionalParameters(additionalParameters ->
|
||||
parameters.entrySet().stream()
|
||||
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) &&
|
||||
!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
|
||||
!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) &&
|
||||
!e.getKey().equals(OAuth2ParameterNames.SCOPE) &&
|
||||
!e.getKey().equals(OAuth2ParameterNames.STATE))
|
||||
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
|
||||
Map<String, String[]> parameterMap = request.getParameterMap();
|
||||
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
|
||||
parameterMap.forEach((key, values) -> {
|
||||
if (values.length > 0) {
|
||||
for (String value : values) {
|
||||
parameters.add(key, value);
|
||||
}
|
||||
}
|
||||
});
|
||||
return parameters;
|
||||
}
|
||||
}
|
||||
|
@ -1,55 +0,0 @@
|
||||
/*
|
||||
* Copyright 2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.oauth2.server.authorization.web;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* @author Paurav Munshi
|
||||
* @since 0.0.1
|
||||
* @see Converter
|
||||
*/
|
||||
public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> {
|
||||
|
||||
@Override
|
||||
public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
|
||||
String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
|
||||
Set<String> scopes = !StringUtils.isEmpty(scope)
|
||||
? new LinkedHashSet<String>(Arrays.asList(scope.split(" ")))
|
||||
: Collections.emptySet();
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
|
||||
.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
|
||||
.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
|
||||
.scopes(scopes)
|
||||
.state(request.getParameter(OAuth2ParameterNames.STATE))
|
||||
.authorizationUri(request.getServletPath())
|
||||
.build();
|
||||
|
||||
return authorizationRequest;
|
||||
}
|
||||
|
||||
}
|
@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
||||
|
||||
@ -61,7 +62,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
|
||||
public void saveWhenAuthorizationProvidedThenSaved() {
|
||||
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
||||
.principalName(PRINCIPAL_NAME)
|
||||
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
|
||||
.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
|
||||
.build();
|
||||
this.authorizationService.save(expectedAuthorization);
|
||||
|
||||
@ -88,7 +89,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
|
||||
public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() {
|
||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
||||
.principalName(PRINCIPAL_NAME)
|
||||
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
|
||||
.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
|
||||
.build();
|
||||
this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
|
||||
|
||||
@ -103,7 +104,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
|
||||
"access-token", Instant.now().minusSeconds(60), Instant.now());
|
||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
||||
.principalName(PRINCIPAL_NAME)
|
||||
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
|
||||
.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
|
||||
.accessToken(accessToken)
|
||||
.build();
|
||||
this.authorizationService.save(authorization);
|
||||
|
@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
||||
|
||||
@ -84,13 +85,13 @@ public class OAuth2AuthorizationTests {
|
||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
||||
.principalName(PRINCIPAL_NAME)
|
||||
.accessToken(ACCESS_TOKEN)
|
||||
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
|
||||
.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
|
||||
.build();
|
||||
|
||||
assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
|
||||
assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
|
||||
assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
|
||||
assertThat(authorization.getAttributes()).containsExactly(
|
||||
entry(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE));
|
||||
entry(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE));
|
||||
}
|
||||
}
|
||||
|
@ -46,40 +46,4 @@ public class TestRegisteredClients {
|
||||
.scope("profile")
|
||||
.scope("email");
|
||||
}
|
||||
|
||||
public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
|
||||
return RegisteredClient.withId("valid_client_id")
|
||||
.clientId("valid_client")
|
||||
.clientSecret("valid_secret")
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.redirectUri("http://localhost:8080/test-application/callback")
|
||||
.scope("openid")
|
||||
.scope("profile")
|
||||
.scope("email");
|
||||
}
|
||||
|
||||
public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
|
||||
return RegisteredClient.withId("valid_client_multi_uri_id")
|
||||
.clientId("valid_client_multi_uri")
|
||||
.clientSecret("valid_secret")
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.redirectUri("http://localhost:8080/test-application/callback")
|
||||
.redirectUri("http://localhost:8080/another-test-application/callback")
|
||||
.scope("openid")
|
||||
.scope("profile")
|
||||
.scope("email");
|
||||
}
|
||||
|
||||
public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
|
||||
return RegisteredClient.withId("valid_cc_client_id")
|
||||
.clientId("valid_cc_client")
|
||||
.clientSecret("valid_secret")
|
||||
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.scope("openid")
|
||||
.scope("profile")
|
||||
.scope("email");
|
||||
}
|
||||
}
|
||||
|
@ -1,371 +0,0 @@
|
||||
/*
|
||||
* Copyright 2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.oauth2.server.authorization.web;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
||||
|
||||
|
||||
/**
|
||||
* Tests for {@link OAuth2AuthorizationEndpointFilter}.
|
||||
*
|
||||
* @author Paurav Munshi
|
||||
* @since 0.0.1
|
||||
*/
|
||||
|
||||
public class OAuth2AuthorizationEndpointFilterTest {
|
||||
|
||||
private static final String VALID_CLIENT = "valid_client";
|
||||
private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
|
||||
private static final String VALID_CC_CLIENT = "valid_cc_client";
|
||||
|
||||
private OAuth2AuthorizationEndpointFilter filter;
|
||||
|
||||
private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
|
||||
private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
|
||||
private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
|
||||
private Authentication authentication = mock(Authentication.class);
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
|
||||
this.filter.setCodeGenerator(this.codeGenerator);
|
||||
|
||||
SecurityContextHolder.getContext().setAuthentication(this.authentication);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
|
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
|
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||
assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||
assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||
assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||
assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
|
||||
.isInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
|
||||
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
|
||||
when(this.authentication.getPrincipal()).thenReturn("test-user");
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication).isAuthenticated();
|
||||
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
|
||||
verify(this.authorizationService).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
|
||||
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
|
||||
when(this.authentication.getPrincipal()).thenReturn("test-user");
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication).isAuthenticated();
|
||||
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
|
||||
verify(this.authorizationService).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
|
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
|
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication, times(1)).isAuthenticated();
|
||||
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
|
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator, times(0)).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication, times(1)).isAuthenticated();
|
||||
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
|
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator, times(0)).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication, times(1)).isAuthenticated();
|
||||
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
|
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator, times(0)).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication).isAuthenticated();
|
||||
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
|
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator, times(0)).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||
assertThat(response.getContentAsString()).isEmpty();
|
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
|
||||
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication).isAuthenticated();
|
||||
verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
|
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator, times(0)).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
|
||||
assertThat(response.getContentAsString()).isEmpty();
|
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
when(authentication.isAuthenticated()).thenReturn(false);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authentication).isAuthenticated();
|
||||
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
|
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||
verify(this.codeGenerator, times(0)).generateKey();
|
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
|
||||
assertThat(response.getContentAsString()).isEmpty();
|
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setServletPath("/custom/authorize");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
|
||||
spyFilter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
|
||||
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
|
||||
spyFilter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
|
||||
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
|
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
|
||||
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
|
||||
spyFilter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
|
||||
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
|
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
private MockHttpServletRequest getValidMockHttpServletRequest() {
|
||||
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
|
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
|
||||
request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
|
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
|
||||
request.setParameter(OAuth2ParameterNames.STATE, "teststate");
|
||||
request.setServletPath("/oauth2/authorize");
|
||||
|
||||
return request;
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,399 @@
|
||||
/*
|
||||
* Copyright 2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.oauth2.server.authorization.web;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
|
||||
* Tests for {@link OAuth2AuthorizationEndpointFilter}.
|
||||
*
|
||||
* @author Paurav Munshi
|
||||
* @author Joe Grandja
|
||||
* @since 0.0.1
|
||||
*/
|
||||
public class OAuth2AuthorizationEndpointFilterTests {
|
||||
private RegisteredClientRepository registeredClientRepository;
|
||||
private OAuth2AuthorizationService authorizationService;
|
||||
private OAuth2AuthorizationEndpointFilter filter;
|
||||
private TestingAuthenticationToken authentication;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
this.registeredClientRepository = mock(RegisteredClientRepository.class);
|
||||
this.authorizationService = mock(OAuth2AuthorizationService.class);
|
||||
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
|
||||
this.authentication = new TestingAuthenticationToken("principalName", "password");
|
||||
this.authentication.setAuthenticated(true);
|
||||
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
||||
securityContext.setAuthentication(this.authentication);
|
||||
SecurityContextHolder.setContext(securityContext);
|
||||
}
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
SecurityContextHolder.clearContext();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("registeredClientRepository cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("authorizationService cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService, null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("authorizationEndpointUri cannot be empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception {
|
||||
String requestUri = "/path";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception {
|
||||
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception {
|
||||
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.authentication.setAuthenticated(false);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeThenUnauthorizedClientError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
|
||||
.authorizationGrantTypes(Set::clear)
|
||||
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
|
||||
.build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegisteredThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUri("https://example2.com").build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.removeParameter(OAuth2ParameterNames.REDIRECT_URI);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestMissingResponseTypeThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
|
||||
"error=invalid_request&" +
|
||||
"error_description=OAuth%202.0%20Parameter:%20response_type&" +
|
||||
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
|
||||
"state=state");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestMultipleResponseTypeThenInvalidRequestError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
|
||||
"error=invalid_request&" +
|
||||
"error_description=OAuth%202.0%20Parameter:%20response_type&" +
|
||||
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
|
||||
"state=state");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestInvalidResponseTypeThenUnsupportedResponseTypeError() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
|
||||
"error=unsupported_response_type&" +
|
||||
"error_description=OAuth%202.0%20Parameter:%20response_type&" +
|
||||
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
|
||||
"state=state");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
|
||||
.thenReturn(registeredClient);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
|
||||
|
||||
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
|
||||
|
||||
verify(this.authorizationService).save(authorizationCaptor.capture());
|
||||
|
||||
OAuth2Authorization authorization = authorizationCaptor.getValue();
|
||||
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
|
||||
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
|
||||
|
||||
String code = authorization.getAttribute(OAuth2ParameterNames.class.getName().concat(".CODE"));
|
||||
assertThat(code).isNotNull();
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
|
||||
assertThat(authorizationRequest).isNotNull();
|
||||
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize");
|
||||
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
|
||||
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
|
||||
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
|
||||
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next());
|
||||
assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes());
|
||||
assertThat(authorizationRequest.getState()).isEqualTo("state");
|
||||
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
|
||||
}
|
||||
|
||||
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
|
||||
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
|
||||
|
||||
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
|
||||
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
|
||||
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
|
||||
request.addParameter(OAuth2ParameterNames.SCOPE,
|
||||
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
return request;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user