Polish gh-77

This commit is contained in:
Joe Grandja 2020-05-19 16:55:50 -04:00
parent 54e219a397
commit fbc98d511c
9 changed files with 625 additions and 636 deletions

View File

@ -15,6 +15,7 @@
*/ */
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.util.List; import java.util.List;
@ -65,7 +66,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) { private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) {
if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue())); return token.equals(authorization.getAttributes().get(OAuth2ParameterNames.class.getName().concat(".CODE")));
} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
return authorization.getAccessToken() != null && return authorization.getAccessToken() != null &&
authorization.getAccessToken().getTokenValue().equals(token); authorization.getAccessToken().getTokenValue().equals(token);

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -196,7 +197,7 @@ public class OAuth2Authorization implements Serializable {
*/ */
public OAuth2Authorization build() { public OAuth2Authorization build() {
Assert.hasText(this.principalName, "principalName cannot be empty"); Assert.hasText(this.principalName, "principalName cannot be empty");
Assert.notNull(this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()), "authorization code cannot be null"); Assert.notNull(this.attributes.get(OAuth2ParameterNames.class.getName().concat(".CODE")), "authorization code cannot be null");
OAuth2Authorization authorization = new OAuth2Authorization(); OAuth2Authorization authorization = new OAuth2Authorization();
authorization.registeredClientId = this.registeredClientId; authorization.registeredClientId = this.registeredClientId;

View File

@ -15,29 +15,21 @@
*/ */
package org.springframework.security.oauth2.server.authorization.web; package org.springframework.security.oauth2.server.authorization.web;
import java.io.IOException; import org.springframework.http.HttpMethod;
import java.util.stream.Stream;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.DefaultRedirectStrategy;
@ -45,201 +37,257 @@ import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/** /**
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
* which handles the processing of the OAuth 2.0 Authorization Request.
*
* @author Joe Grandja * @author Joe Grandja
* @author Paurav Munshi * @author Paurav Munshi
* @since 0.0.1 * @since 0.0.1
* @see RegisteredClientRepository
* @see OAuth2AuthorizationService
* @see OAuth2Authorization
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Request</a>
*/ */
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
/**
* The default endpoint {@code URI} for authorization requests.
*/
public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; private final RegisteredClientRepository registeredClientRepository;
private final OAuth2AuthorizationService authorizationService;
private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); private final RequestMatcher authorizationEndpointMatcher;
private RegisteredClientRepository registeredClientRepository; private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private OAuth2AuthorizationService authorizationService; private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
/**
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
*
* @param registeredClientRepository the repository of registered clients
* @param authorizationService the authorization service
*/
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository, public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
OAuth2AuthorizationService authorizationService) { OAuth2AuthorizationService authorizationService) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null."); this(registeredClientRepository, authorizationService, DEFAULT_AUTHORIZATION_ENDPOINT_URI);
Assert.notNull(authorizationService, "authorizationService cannot be null."); }
/**
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
*
* @param registeredClientRepository the repository of registered clients
* @param authorizationService the authorization service
* @param authorizationEndpointUri the endpoint {@code URI} for authorization requests
*/
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
OAuth2AuthorizationService authorizationService, String authorizationEndpointUri) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
Assert.notNull(authorizationService, "authorizationService cannot be null");
Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty");
this.registeredClientRepository = registeredClientRepository; this.registeredClientRepository = registeredClientRepository;
this.authorizationService = authorizationService; this.authorizationService = authorizationService;
} this.authorizationEndpointMatcher = new AntPathRequestMatcher(
authorizationEndpointUri, HttpMethod.GET.name());
public final void setAuthorizationRequestConverter(
Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter) {
Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
this.authorizationRequestConverter = authorizationRequestConverter;
}
public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
this.codeGenerator = codeGenerator;
}
public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
this.authorizationRedirectStrategy = authorizationRedirectStrategy;
}
public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
this.authorizationEndpointMatcher = authorizationEndpointMatcher;
} }
@Override @Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
if (pathMatch && responseTypeMatch) {
return false;
}else {
return true;
}
}
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
RegisteredClient client = null; if (!this.authorizationEndpointMatcher.matches(request) || !isPrincipalAuthenticated()) {
OAuth2AuthorizationRequest authorizationRequest = null; filterChain.doFilter(request, response);
OAuth2Authorization authorization = null; return;
}
try { // TODO
checkUserAuthenticated(); // The authorization server validates the request to ensure that all
Authentication auth = SecurityContextHolder.getContext().getAuthentication(); // required parameters are present and valid. If the request is valid,
client = fetchRegisteredClient(request); // the authorization server authenticates the resource owner and obtains
// an authorization decision (by asking the resource owner or by
// establishing approval via other means).
authorizationRequest = this.authorizationRequestConverter.convert(request); MultiValueMap<String, String> parameters = getParameters(request);
validateAuthorizationRequest(authorizationRequest, client); String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
// client_id (REQUIRED)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId) ||
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
return;
}
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
return;
} else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
OAuth2Error error = createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID);
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
return;
}
// redirect_uri (OPTIONAL)
String redirectUriParameter = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
if (StringUtils.hasText(redirectUriParameter)) {
if (!registeredClient.getRedirectUris().contains(redirectUriParameter) ||
parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
return;
}
} else if (registeredClient.getRedirectUris().size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
return;
}
String redirectUri = StringUtils.hasText(redirectUriParameter) ?
redirectUriParameter : registeredClient.getRedirectUris().iterator().next();
// response_type (REQUIRED)
String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
if (!StringUtils.hasText(responseType) ||
parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
} else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) {
OAuth2Error error = createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
String code = this.codeGenerator.generateKey(); String code = this.codeGenerator.generateKey();
authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code); OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request);
this.authorizationService.save(authorization);
String redirectUri = getRedirectUri(authorizationRequest, client); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code); .principalName(principal.getName())
} .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), code)
catch(OAuth2AuthorizationException authorizationException) { .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
OAuth2Error authorizationError = authorizationException.getError();
if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
sendErrorInResponse(response, authorizationError);
}
else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
String redirectUri = getRedirectUri(authorizationRequest, client);
sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
}
else {
throw new ServletException(authorizationException);
}
}
}
private void checkUserAuthenticated() {
Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
if (currentAuth==null || !currentAuth.isAuthenticated()) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
}
}
private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
if (StringUtils.isEmpty(clientId)) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}
RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
if (client==null) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
}
boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
.anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
if (!isAuthorizationGrantAllowed) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
}
return client;
}
private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
OAuth2AuthorizationRequest authorizationRequest, String code) {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
.principalName(auth.getPrincipal().toString())
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
.attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
.build(); .build();
return authorization; this.authorizationService.save(authorization);
// TODO security checks for code parameter
// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
// A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
// The client MUST NOT use the authorization code more than once.
// If an authorization code is used more than once, the authorization server MUST deny the request
// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
// The authorization code is bound to the client identifier and redirection URI.
sendAuthorizationResponse(request, response, authorizationRequest, code, redirectUri);
} }
private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
OAuth2AuthorizationRequest authorizationRequest, String code, String redirectUri) throws IOException {
private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { UriComponentsBuilder uriBuilder = UriComponentsBuilder
String redirectUri = authorizationRequest.getRedirectUri(); .fromUriString(redirectUri)
if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}
if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}
}
private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
? authorizationRequest.getRedirectUri()
: client.getRedirectUris().stream().findFirst().get();
}
private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
.queryParam(OAuth2ParameterNames.CODE, code); .queryParam(OAuth2ParameterNames.CODE, code);
if (!StringUtils.isEmpty(authorizationRequest.getState())) { if (StringUtils.hasText(authorizationRequest.getState())) {
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
}
this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
} }
String finalRedirectUri = redirectUriBuilder.toUriString(); private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri); OAuth2Error error, String state, String redirectUri) throws IOException {
if (redirectUri == null) {
// TODO Send default html error response
response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString());
return;
} }
private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { UriComponentsBuilder uriBuilder = UriComponentsBuilder
int errorStatus = -1; .fromUriString(redirectUri)
String errorCode = authorizationError.getErrorCode(); .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) { if (StringUtils.hasText(error.getDescription())) {
errorStatus=HttpStatus.FORBIDDEN.value(); uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
} }
else { if (StringUtils.hasText(error.getUri())) {
errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
} }
response.sendError(errorStatus, authorizationError.getErrorCode()); if (StringUtils.hasText(state)) {
uriBuilder.queryParam(OAuth2ParameterNames.STATE, state);
}
this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
} }
private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, private static boolean isPrincipalAuthenticated() {
OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError, return isPrincipalAuthenticated(SecurityContextHolder.getContext().getAuthentication());
String redirectUri) throws IOException {
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
.queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
} }
String finalRedirectURI = redirectUriBuilder.toUriString(); private static boolean isPrincipalAuthenticated(Authentication principal) {
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); return principal != null &&
!AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&
principal.isAuthenticated();
}
private static OAuth2Error createError(String errorCode, String parameterName) {
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
}
private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
MultiValueMap<String, String> parameters = getParameters(request);
Set<String> scopes = Collections.emptySet();
if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
}
return OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(request.getRequestURL().toString())
.clientId(parameters.getFirst(OAuth2ParameterNames.CLIENT_ID))
.redirectUri(parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI))
.scopes(scopes)
.state(parameters.getFirst(OAuth2ParameterNames.STATE))
.additionalParameters(additionalParameters ->
parameters.entrySet().stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) &&
!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) &&
!e.getKey().equals(OAuth2ParameterNames.SCOPE) &&
!e.getKey().equals(OAuth2ParameterNames.STATE))
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
.build();
}
private static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
} }
} }

View File

@ -1,55 +0,0 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.http.HttpServletRequest;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.StringUtils;
/**
* @author Paurav Munshi
* @since 0.0.1
* @see Converter
*/
public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> {
@Override
public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
Set<String> scopes = !StringUtils.isEmpty(scope)
? new LinkedHashSet<String>(Arrays.asList(scope.split(" ")))
: Collections.emptySet();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
.scopes(scopes)
.state(request.getParameter(OAuth2ParameterNames.STATE))
.authorizationUri(request.getServletPath())
.build();
return authorizationRequest;
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@ -61,7 +62,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void saveWhenAuthorizationProvidedThenSaved() { public void saveWhenAuthorizationProvidedThenSaved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.build(); .build();
this.authorizationService.save(expectedAuthorization); this.authorizationService.save(expectedAuthorization);
@ -88,7 +89,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() { public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.build(); .build();
this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
@ -103,7 +104,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
"access-token", Instant.now().minusSeconds(60), Instant.now()); "access-token", Instant.now().minusSeconds(60), Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.accessToken(accessToken) .accessToken(accessToken)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);

View File

@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@ -84,13 +85,13 @@ public class OAuth2AuthorizationTests {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.accessToken(ACCESS_TOKEN) .accessToken(ACCESS_TOKEN)
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.build(); .build();
assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN); assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
assertThat(authorization.getAttributes()).containsExactly( assertThat(authorization.getAttributes()).containsExactly(
entry(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)); entry(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE));
} }
} }

View File

@ -46,40 +46,4 @@ public class TestRegisteredClients {
.scope("profile") .scope("profile")
.scope("email"); .scope("email");
} }
public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
return RegisteredClient.withId("valid_client_id")
.clientId("valid_client")
.clientSecret("valid_secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.redirectUri("http://localhost:8080/test-application/callback")
.scope("openid")
.scope("profile")
.scope("email");
}
public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
return RegisteredClient.withId("valid_client_multi_uri_id")
.clientId("valid_client_multi_uri")
.clientSecret("valid_secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.redirectUri("http://localhost:8080/test-application/callback")
.redirectUri("http://localhost:8080/another-test-application/callback")
.scope("openid")
.scope("profile")
.scope("email");
}
public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
return RegisteredClient.withId("valid_cc_client_id")
.clientId("valid_cc_client")
.clientSecret("valid_secret")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.scope("openid")
.scope("profile")
.scope("email");
}
} }

View File

@ -1,371 +0,0 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
/**
* Tests for {@link OAuth2AuthorizationEndpointFilter}.
*
* @author Paurav Munshi
* @since 0.0.1
*/
public class OAuth2AuthorizationEndpointFilterTest {
private static final String VALID_CLIENT = "valid_client";
private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
private static final String VALID_CC_CLIENT = "valid_cc_client";
private OAuth2AuthorizationEndpointFilter filter;
private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
private Authentication authentication = mock(Authentication.class);
@Before
public void setUp() {
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
this.filter.setCodeGenerator(this.codeGenerator);
SecurityContextHolder.getContext().setAuthentication(this.authentication);
}
@Test
public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
when(this.authentication.getPrincipal()).thenReturn("test-user");
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
verify(this.authorizationService).save(any(OAuth2Authorization.class));
verify(this.codeGenerator).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
}
@Test
public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
when(this.authentication.getPrincipal()).thenReturn("test-user");
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
verify(this.authorizationService).save(any(OAuth2Authorization.class));
verify(this.codeGenerator).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
}
@Test
public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication, times(1)).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication, times(1)).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication, times(1)).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
}
@Test
public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getContentAsString()).isEmpty();
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
assertThat(response.getContentAsString()).isEmpty();
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
}
@Test
public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
when(authentication.isAuthenticated()).thenReturn(false);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
assertThat(response.getContentAsString()).isEmpty();
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
}
@Test
public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setServletPath("/custom/authorize");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
spyFilter.doFilter(request, response, filterChain);
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
}
@Test
public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
spyFilter.doFilter(request, response, filterChain);
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
spyFilter.doFilter(request, response, filterChain);
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
private MockHttpServletRequest getValidMockHttpServletRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
request.setParameter(OAuth2ParameterNames.STATE, "teststate");
request.setServletPath("/oauth2/authorize");
return request;
}
}

View File

@ -0,0 +1,399 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.util.StringUtils;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OAuth2AuthorizationEndpointFilter}.
*
* @author Paurav Munshi
* @author Joe Grandja
* @since 0.0.1
*/
public class OAuth2AuthorizationEndpointFilterTests {
private RegisteredClientRepository registeredClientRepository;
private OAuth2AuthorizationService authorizationService;
private OAuth2AuthorizationEndpointFilter filter;
private TestingAuthenticationToken authentication;
@Before
public void setUp() {
this.registeredClientRepository = mock(RegisteredClientRepository.class);
this.authorizationService = mock(OAuth2AuthorizationService.class);
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
this.authentication = new TestingAuthenticationToken("principalName", "password");
this.authentication.setAuthenticated(true);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(this.authentication);
SecurityContextHolder.setContext(securityContext);
}
@After
public void cleanup() {
SecurityContextHolder.clearContext();
}
@Test
public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("registeredClientRepository cannot be null");
}
@Test
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationService cannot be null");
}
@Test
public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationEndpointUri cannot be empty");
}
@Test
public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception {
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception {
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.authentication.setAuthenticated(false);
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
}
@Test
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
}
@Test
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
}
@Test
public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeThenUnauthorizedClientError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.authorizationGrantTypes(Set::clear)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
}
@Test
public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
}
@Test
public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
}
@Test
public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegisteredThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUri("https://example2.com").build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.removeParameter(OAuth2ParameterNames.REDIRECT_URI);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
}
@Test
public void doFilterWhenAuthorizationRequestMissingResponseTypeThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20response_type&" +
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
"state=state");
}
@Test
public void doFilterWhenAuthorizationRequestMultipleResponseTypeThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20response_type&" +
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
"state=state");
}
@Test
public void doFilterWhenAuthorizationRequestInvalidResponseTypeThenUnsupportedResponseTypeError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=unsupported_response_type&" +
"error_description=OAuth%202.0%20Parameter:%20response_type&" +
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
"state=state");
}
@Test
public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
String code = authorization.getAttribute(OAuth2ParameterNames.class.getName().concat(".CODE"));
assertThat(code).isNotNull();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize");
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next());
assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes());
assertThat(authorizationRequest.getState()).isEqualTo("state");
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
}
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
request.addParameter(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
request.addParameter(OAuth2ParameterNames.STATE, "state");
return request;
}
}