From aa5133e170472aee88cde6d0be6608c7035cfb5a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 22 Sep 2020 11:57:50 -0400 Subject: [PATCH] Add user consent page Closes gh-42 --- .../OAuth2AuthorizationServerSecurity.java | 18 +- .../OAuth2AuthorizationServerConfigurer.java | 10 +- .../InMemoryOAuth2AuthorizationService.java | 12 +- .../OAuth2AuthorizationAttributeNames.java | 10 + .../OAuth2AuthorizationService.java | 7 + ...thorizationCodeAuthenticationProvider.java | 4 +- .../authorization/config/ClientSettings.java | 24 + .../OAuth2AuthorizationEndpointFilter.java | 631 ++++++++++++++---- ...MemoryOAuth2AuthorizationServiceTests.java | 48 +- .../TestOAuth2Authorizations.java | 4 +- ...zationCodeAuthenticationProviderTests.java | 10 + .../config/ClientSettingsTests.java | 9 +- ...Auth2AuthorizationEndpointFilterTests.java | 386 ++++++++++- .../config/AuthorizationServerConfig.java | 2 + .../sample/web/AuthorizationController.java | 21 + .../src/main/resources/templates/index.html | 4 + 16 files changed, 1019 insertions(+), 181 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerSecurity.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerSecurity.java index 7c68ed0..af5e45a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerSecurity.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerSecurity.java @@ -17,11 +17,8 @@ package org.springframework.security.config.annotation.web.configuration; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; -import org.springframework.http.HttpMethod; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization.OAuth2AuthorizationServerConfigurer; -import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; -import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -41,22 +38,17 @@ public class OAuth2AuthorizationServerSecurity extends WebSecurityConfigurerAdap protected void configure(HttpSecurity http) throws Exception { OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = new OAuth2AuthorizationServerConfigurer<>(); + RequestMatcher[] endpointMatchers = authorizationServerConfigurer + .getEndpointMatchers().toArray(new RequestMatcher[0]); http - .requestMatcher(new OrRequestMatcher(authorizationServerConfigurer.getEndpointMatchers())) + .requestMatcher(new OrRequestMatcher(endpointMatchers)) .authorizeRequests(authorizeRequests -> - authorizeRequests - .anyRequest().authenticated() + authorizeRequests.anyRequest().authenticated() ) .formLogin(withDefaults()) - .csrf(csrf -> csrf.ignoringRequestMatchers(tokenEndpointMatcher())) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointMatchers)) .apply(authorizationServerConfigurer); } // @formatter:on - - private static RequestMatcher tokenEndpointMatcher() { - return new AntPathRequestMatcher( - OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI, - HttpMethod.POST.name()); - } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index 778f31a..4104c80 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -40,6 +40,7 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -63,8 +64,13 @@ import java.util.Map; public final class OAuth2AuthorizationServerConfigurer> extends AbstractHttpConfigurer, B> { - private final RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher( - OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI, HttpMethod.GET.name()); + private final RequestMatcher authorizationEndpointMatcher = new OrRequestMatcher( + new AntPathRequestMatcher( + OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI, + HttpMethod.GET.name()), + new AntPathRequestMatcher( + OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI, + HttpMethod.POST.name())); private final RequestMatcher tokenEndpointMatcher = new AntPathRequestMatcher( OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI, HttpMethod.POST.name()); private final RequestMatcher jwkSetEndpointMatcher = new AntPathRequestMatcher( diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 87e8033..3d24bb3 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -43,6 +43,14 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza this.authorizations.put(authorizationId, authorization); } + @Override + public void remove(OAuth2Authorization authorization) { + Assert.notNull(authorization, "authorization cannot be null"); + OAuth2AuthorizationId authorizationId = new OAuth2AuthorizationId( + authorization.getRegisteredClientId(), authorization.getPrincipalName()); + this.authorizations.remove(authorizationId, authorization); + } + @Override public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) { Assert.hasText(token, "token cannot be empty"); @@ -53,7 +61,9 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza } private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) { - if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { + if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) { + return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)); + } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { return authorization.getAccessToken() != null && diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java index d691a6e..06440b0 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java @@ -30,6 +30,11 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; */ public interface OAuth2AuthorizationAttributeNames { + /** + * The name of the attribute used for correlating the user consent request/response. + */ + String STATE = OAuth2Authorization.class.getName().concat(".STATE"); + /** * The name of the attribute used for the {@link OAuth2ParameterNames#CODE} parameter. */ @@ -40,6 +45,11 @@ public interface OAuth2AuthorizationAttributeNames { */ String AUTHORIZATION_REQUEST = OAuth2Authorization.class.getName().concat(".AUTHORIZATION_REQUEST"); + /** + * The name of the attribute used for the authorized scope(s). + */ + String AUTHORIZED_SCOPES = OAuth2Authorization.class.getName().concat(".AUTHORIZED_SCOPES"); + /** * The name of the attribute used for the attributes/claims of the {@link OAuth2AccessToken}. */ diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java index 0151e36..1cc2b0f 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java @@ -34,6 +34,13 @@ public interface OAuth2AuthorizationService { */ void save(OAuth2Authorization authorization); + /** + * Removes the {@link OAuth2Authorization}. + * + * @param authorization the {@link OAuth2Authorization} + */ + void remove(OAuth2Authorization authorization); + /** * Returns the {@link OAuth2Authorization} containing the provided {@code token}, * or {@code null} if not found. diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index b0f8a29..c7a2053 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -44,6 +44,7 @@ import java.net.URL; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.Set; /** * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant. @@ -123,6 +124,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() .issuer(issuer) @@ -131,7 +133,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica .issuedAt(issuedAt) .expiresAt(expiresAt) .notBefore(issuedAt) - .claim(OAuth2ParameterNames.SCOPE, authorizationRequest.getScopes()) + .claim(OAuth2ParameterNames.SCOPE, authorizedScopes) .build(); Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ClientSettings.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ClientSettings.java index b22c909..323933b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ClientSettings.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ClientSettings.java @@ -28,6 +28,7 @@ import java.util.Map; public class ClientSettings extends Settings { private static final String CLIENT_SETTING_BASE = "spring.security.oauth2.authorization-server.client."; public static final String REQUIRE_PROOF_KEY = CLIENT_SETTING_BASE.concat("require-proof-key"); + public static final String REQUIRE_USER_CONSENT = CLIENT_SETTING_BASE.concat("require-user-consent"); /** * Constructs a {@code ClientSettings}. @@ -67,9 +68,32 @@ public class ClientSettings extends Settings { return this; } + /** + * Returns {@code true} if the user's consent is required when the client requests access. + * The default is {@code false}. + * + * @return {@code true} if the user's consent is required when the client requests access, {@code false} otherwise + */ + public boolean requireUserConsent() { + return setting(REQUIRE_USER_CONSENT); + } + + /** + * Set to {@code true} if the user's consent is required when the client requests access. + * This applies to all interactive flows (e.g. {@code authorization_code} and {@code device_code}). + * + * @param requireUserConsent {@code true} if the user's consent is required when the client requests access, {@code false} otherwise + * @return the {@link ClientSettings} + */ + public ClientSettings requireUserConsent(boolean requireUserConsent) { + setting(REQUIRE_USER_CONSENT, requireUserConsent); + return this; + } + protected static Map defaultSettings() { Map settings = new HashMap<>(); settings.put(REQUIRE_PROOF_KEY, false); + settings.put(REQUIRE_USER_CONSENT, false); return settings; } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 7c94fcc..06555ea 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization.web; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.web.DefaultRedirectStrategy; @@ -39,6 +41,7 @@ import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; @@ -49,10 +52,12 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; /** @@ -68,6 +73,7 @@ import java.util.Set; * @see OAuth2Authorization * @see Section 4.1 Authorization Code Grant * @see Section 4.1.1 Authorization Request + * @see Section 4.1.2 Authorization Response */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { /** @@ -79,8 +85,10 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { private final RegisteredClientRepository registeredClientRepository; private final OAuth2AuthorizationService authorizationService; - private final RequestMatcher authorizationEndpointMatcher; - private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + private final RequestMatcher authorizationRequestMatcher; + private final RequestMatcher userConsentMatcher; + private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); /** @@ -108,102 +116,42 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty"); this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; - this.authorizationEndpointMatcher = new AntPathRequestMatcher( + this.authorizationRequestMatcher = new AntPathRequestMatcher( authorizationEndpointUri, HttpMethod.GET.name()); + this.userConsentMatcher = new AntPathRequestMatcher( + authorizationEndpointUri, HttpMethod.POST.name()); } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (!this.authorizationEndpointMatcher.matches(request)) { + if (this.authorizationRequestMatcher.matches(request)) { + processAuthorizationRequest(request, response, filterChain); + } else if (this.userConsentMatcher.matches(request)) { + processUserConsent(request, response); + } else { filterChain.doFilter(request, response); - return; } + } - // --------------- - // Validate the request to ensure that all required parameters are present and valid - // --------------- + private void processAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE); + OAuth2AuthorizationRequestContext authorizationRequestContext = + new OAuth2AuthorizationRequestContext( + request.getRequestURL().toString(), + OAuth2EndpointUtils.getParameters(request)); - // client_id (REQUIRED) - String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); - if (!StringUtils.hasText(clientId) || - parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); - sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect - return; - } - RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); - if (registeredClient == null) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); - sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect - return; - } else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { - OAuth2Error error = createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID); - sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect - return; - } + validateAuthorizationRequest(authorizationRequestContext); - // redirect_uri (OPTIONAL) - String redirectUriParameter = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI); - if (StringUtils.hasText(redirectUriParameter)) { - if (!registeredClient.getRedirectUris().contains(redirectUriParameter) || - parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); - sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect - return; + if (authorizationRequestContext.hasError()) { + if (authorizationRequestContext.isRedirectOnError()) { + sendErrorResponse(request, response, authorizationRequestContext.resolveRedirectUri(), + authorizationRequestContext.getError(), authorizationRequestContext.getState()); + } else { + sendErrorResponse(response, authorizationRequestContext.getError()); } - } else if (registeredClient.getRedirectUris().size() != 1) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); - sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect - return; - } - - String redirectUri = StringUtils.hasText(redirectUriParameter) ? - redirectUriParameter : registeredClient.getRedirectUris().iterator().next(); - - // response_type (REQUIRED) - String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE); - if (!StringUtils.hasText(responseType) || - parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE); - sendErrorResponse(request, response, error, stateParameter, redirectUri); - return; - } else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) { - OAuth2Error error = createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE); - sendErrorResponse(request, response, error, stateParameter, redirectUri); - return; - } - - // code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE) - String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE); - if (StringUtils.hasText(codeChallenge)) { - if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI); - sendErrorResponse(request, response, error, stateParameter, redirectUri); - return; - } - - String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD); - if (StringUtils.hasText(codeChallengeMethod) && - parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() != 1) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI); - sendErrorResponse(request, response, error, stateParameter, redirectUri); - return; - } - - if (StringUtils.hasText(codeChallengeMethod) && - (!"S256".equals(codeChallengeMethod) && !"plain".equals(codeChallengeMethod))) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI); - sendErrorResponse(request, response, error, stateParameter, redirectUri); - return; - } - } else if (registeredClient.getClientSettings().requireProofKey()) { - OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI); - sendErrorResponse(request, response, error, stateParameter, redirectUri); return; } @@ -219,48 +167,241 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { return; } - String code = this.codeGenerator.generateKey(); - OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request); - - OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) + RegisteredClient registeredClient = authorizationRequestContext.getRegisteredClient(); + OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest(); + OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(principal.getName()) - .attribute(OAuth2AuthorizationAttributeNames.CODE, code) - .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) - .build(); + .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest); + if (registeredClient.getClientSettings().requireUserConsent()) { + String state = this.stateGenerator.generateKey(); + OAuth2Authorization authorization = builder + .attribute(OAuth2AuthorizationAttributeNames.STATE, state) + .build(); + this.authorizationService.save(authorization); + + // TODO Need to remove 'in-flight' authorization if consent step is not completed (e.g. approved or cancelled) + + UserConsentPage.displayConsent(request, response, registeredClient, authorization); + } else { + String code = this.codeGenerator.generateKey(); + OAuth2Authorization authorization = builder + .attribute(OAuth2AuthorizationAttributeNames.CODE, code) + .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()) + .build(); + this.authorizationService.save(authorization); + +// TODO security checks for code parameter +// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks. +// A maximum authorization code lifetime of 10 minutes is RECOMMENDED. +// The client MUST NOT use the authorization code more than once. +// If an authorization code is used more than once, the authorization server MUST deny the request +// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code. +// The authorization code is bound to the client identifier and redirection URI. + + sendAuthorizationResponse(request, response, + authorizationRequestContext.resolveRedirectUri(), code, authorizationRequest.getState()); + } + } + + private void processUserConsent(HttpServletRequest request, HttpServletResponse response) + throws IOException { + + UserConsentRequestContext userConsentRequestContext = + new UserConsentRequestContext( + request.getRequestURL().toString(), + OAuth2EndpointUtils.getParameters(request)); + + validateUserConsentRequest(userConsentRequestContext); + + if (userConsentRequestContext.hasError()) { + if (userConsentRequestContext.isRedirectOnError()) { + sendErrorResponse(request, response, userConsentRequestContext.resolveRedirectUri(), + userConsentRequestContext.getError(), userConsentRequestContext.getState()); + } else { + sendErrorResponse(response, userConsentRequestContext.getError()); + } + return; + } + + if (!UserConsentPage.isConsentApproved(request)) { + this.authorizationService.remove(userConsentRequestContext.getAuthorization()); + OAuth2Error error = createError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID); + sendErrorResponse(request, response, userConsentRequestContext.resolveRedirectUri(), + error, userConsentRequestContext.getAuthorizationRequest().getState()); + return; + } + + String code = this.codeGenerator.generateKey(); + OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) + .attributes(attrs -> { + attrs.remove(OAuth2AuthorizationAttributeNames.STATE); + attrs.put(OAuth2AuthorizationAttributeNames.CODE, code); + attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes()); + }) + .build(); this.authorizationService.save(authorization); -// TODO security checks for code parameter -// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks. -// A maximum authorization code lifetime of 10 minutes is RECOMMENDED. -// The client MUST NOT use the authorization code more than once. -// If an authorization code is used more than once, the authorization server MUST deny the request -// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code. -// The authorization code is bound to the client identifier and redirection URI. + sendAuthorizationResponse(request, response, userConsentRequestContext.resolveRedirectUri(), + code, userConsentRequestContext.getAuthorizationRequest().getState()); + } - sendAuthorizationResponse(request, response, authorizationRequest, code, redirectUri); + private void validateAuthorizationRequest(OAuth2AuthorizationRequestContext authorizationRequestContext) { + // --------------- + // Validate the request to ensure all required parameters are present and valid + // --------------- + + // client_id (REQUIRED) + if (!StringUtils.hasText(authorizationRequestContext.getClientId()) || + authorizationRequestContext.getParameters().get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID)); + return; + } + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId( + authorizationRequestContext.getClientId()); + if (registeredClient == null) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID)); + return; + } else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID)); + return; + } + authorizationRequestContext.setRegisteredClient(registeredClient); + + // redirect_uri (OPTIONAL) + if (StringUtils.hasText(authorizationRequestContext.getRedirectUri())) { + if (!registeredClient.getRedirectUris().contains(authorizationRequestContext.getRedirectUri()) || + authorizationRequestContext.getParameters().get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI)); + return; + } + } else if (registeredClient.getRedirectUris().size() != 1) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI)); + return; + } + authorizationRequestContext.setRedirectOnError(true); + + // response_type (REQUIRED) + if (!StringUtils.hasText(authorizationRequestContext.getResponseType()) || + authorizationRequestContext.getParameters().get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE)); + return; + } else if (!authorizationRequestContext.getResponseType().equals(OAuth2AuthorizationResponseType.CODE.getValue())) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE)); + return; + } + + // scope (OPTIONAL) + Set requestedScopes = authorizationRequestContext.getScopes(); + Set allowedScopes = registeredClient.getScopes(); + if (!requestedScopes.isEmpty() && !allowedScopes.containsAll(requestedScopes)) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE)); + return; + } + + // code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE) + String codeChallenge = authorizationRequestContext.getParameters().getFirst(PkceParameterNames.CODE_CHALLENGE); + if (StringUtils.hasText(codeChallenge)) { + if (authorizationRequestContext.getParameters().get(PkceParameterNames.CODE_CHALLENGE).size() != 1) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI)); + return; + } + + String codeChallengeMethod = authorizationRequestContext.getParameters().getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD); + if (StringUtils.hasText(codeChallengeMethod)) { + if (authorizationRequestContext.getParameters().get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() != 1 || + (!"S256".equals(codeChallengeMethod) && !"plain".equals(codeChallengeMethod))) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI)); + return; + } + } + } else if (registeredClient.getClientSettings().requireProofKey()) { + authorizationRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI)); + return; + } + } + + private void validateUserConsentRequest(UserConsentRequestContext userConsentRequestContext) { + // --------------- + // Validate the request to ensure all required parameters are present and valid + // --------------- + + // state (REQUIRED) + if (!StringUtils.hasText(userConsentRequestContext.getState()) || + userConsentRequestContext.getParameters().get(OAuth2ParameterNames.STATE).size() != 1) { + userConsentRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE)); + return; + } + OAuth2Authorization authorization = this.authorizationService.findByToken( + userConsentRequestContext.getState(), new TokenType(OAuth2AuthorizationAttributeNames.STATE)); + if (authorization == null) { + userConsentRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE)); + return; + } + userConsentRequestContext.setAuthorization(authorization); + + // The 'in-flight' authorization must be associated to the current principal + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) { + userConsentRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE)); + return; + } + + // client_id (REQUIRED) + if (!StringUtils.hasText(userConsentRequestContext.getClientId()) || + userConsentRequestContext.getParameters().get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + userConsentRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID)); + return; + } + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId( + userConsentRequestContext.getClientId()); + if (registeredClient == null || !registeredClient.getId().equals(authorization.getRegisteredClientId())) { + userConsentRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID)); + return; + } + userConsentRequestContext.setRegisteredClient(registeredClient); + userConsentRequestContext.setRedirectOnError(true); + + // scope (OPTIONAL) + Set requestedScopes = userConsentRequestContext.getAuthorizationRequest().getScopes(); + Set authorizedScopes = userConsentRequestContext.getScopes(); + if (!authorizedScopes.isEmpty() && !requestedScopes.containsAll(authorizedScopes)) { + userConsentRequestContext.setError( + createError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE)); + return; + } } private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response, - OAuth2AuthorizationRequest authorizationRequest, String code, String redirectUri) throws IOException { + String redirectUri, String code, String state) throws IOException { UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(redirectUri) .queryParam(OAuth2ParameterNames.CODE, code); - if (StringUtils.hasText(authorizationRequest.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + if (StringUtils.hasText(state)) { + uriBuilder.queryParam(OAuth2ParameterNames.STATE, state); } this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); } private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - OAuth2Error error, String state, String redirectUri) throws IOException { - - if (redirectUri == null) { - // TODO Send default html error response - response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString()); - return; - } + String redirectUri, OAuth2Error error, String state) throws IOException { UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(redirectUri) @@ -277,6 +418,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); } + private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + // TODO Send default html error response + response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString()); + } + private static OAuth2Error createError(String errorCode, String parameterName) { return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1"); } @@ -291,29 +437,254 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { principal.isAuthenticated(); } - private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) { - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + private static class OAuth2AuthorizationRequestContext extends AbstractRequestContext { + private final String responseType; + private final String redirectUri; - Set scopes = Collections.emptySet(); - if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) { - String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); - scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); + private OAuth2AuthorizationRequestContext( + String authorizationUri, MultiValueMap parameters) { + super(authorizationUri, parameters, + parameters.getFirst(OAuth2ParameterNames.CLIENT_ID), + parameters.getFirst(OAuth2ParameterNames.STATE), + extractScopes(parameters)); + this.responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE); + this.redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI); } - return OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(request.getRequestURL().toString()) - .clientId(parameters.getFirst(OAuth2ParameterNames.CLIENT_ID)) - .redirectUri(parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)) - .scopes(scopes) - .state(parameters.getFirst(OAuth2ParameterNames.STATE)) - .additionalParameters(additionalParameters -> - parameters.entrySet().stream() - .filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) && - !e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) && - !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) && - !e.getKey().equals(OAuth2ParameterNames.SCOPE) && - !e.getKey().equals(OAuth2ParameterNames.STATE)) - .forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0)))) - .build(); + private static Set extractScopes(MultiValueMap parameters) { + String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); + return StringUtils.hasText(scope) ? + new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))) : + Collections.emptySet(); + } + + private String getResponseType() { + return this.responseType; + } + + private String getRedirectUri() { + return this.redirectUri; + } + + protected String resolveRedirectUri() { + return StringUtils.hasText(getRedirectUri()) ? + getRedirectUri() : + getRegisteredClient().getRedirectUris().iterator().next(); + } + + private OAuth2AuthorizationRequest buildAuthorizationRequest() { + return OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(getAuthorizationUri()) + .clientId(getClientId()) + .redirectUri(getRedirectUri()) + .scopes(getScopes()) + .state(getState()) + .additionalParameters(additionalParameters -> + getParameters().entrySet().stream() + .filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) && + !e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) && + !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) && + !e.getKey().equals(OAuth2ParameterNames.SCOPE) && + !e.getKey().equals(OAuth2ParameterNames.STATE)) + .forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0)))) + .build(); + } + } + + private static class UserConsentRequestContext extends AbstractRequestContext { + private OAuth2Authorization authorization; + + private UserConsentRequestContext( + String authorizationUri, MultiValueMap parameters) { + super(authorizationUri, parameters, + parameters.getFirst(OAuth2ParameterNames.CLIENT_ID), + parameters.getFirst(OAuth2ParameterNames.STATE), + extractScopes(parameters)); + } + + private static Set extractScopes(MultiValueMap parameters) { + List scope = parameters.get(OAuth2ParameterNames.SCOPE); + return !CollectionUtils.isEmpty(scope) ? new HashSet<>(scope) : Collections.emptySet(); + } + + private OAuth2Authorization getAuthorization() { + return this.authorization; + } + + private void setAuthorization(OAuth2Authorization authorization) { + this.authorization = authorization; + } + + protected String resolveRedirectUri() { + OAuth2AuthorizationRequest authorizationRequest = getAuthorizationRequest(); + return StringUtils.hasText(authorizationRequest.getRedirectUri()) ? + authorizationRequest.getRedirectUri() : + getRegisteredClient().getRedirectUris().iterator().next(); + } + + private OAuth2AuthorizationRequest getAuthorizationRequest() { + return getAuthorization().getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + } + } + + private abstract static class AbstractRequestContext { + private final String authorizationUri; + private final MultiValueMap parameters; + private final String clientId; + private final String state; + private final Set scopes; + private RegisteredClient registeredClient; + private OAuth2Error error; + private boolean redirectOnError; + + protected AbstractRequestContext(String authorizationUri, MultiValueMap parameters, + String clientId, String state, Set scopes) { + this.authorizationUri = authorizationUri; + this.parameters = parameters; + this.clientId = clientId; + this.state = state; + this.scopes = scopes; + } + + protected String getAuthorizationUri() { + return this.authorizationUri; + } + + protected MultiValueMap getParameters() { + return this.parameters; + } + + protected String getClientId() { + return this.clientId; + } + + protected String getState() { + return this.state; + } + + protected Set getScopes() { + return this.scopes; + } + + protected RegisteredClient getRegisteredClient() { + return this.registeredClient; + } + + protected void setRegisteredClient(RegisteredClient registeredClient) { + this.registeredClient = registeredClient; + } + + protected OAuth2Error getError() { + return this.error; + } + + protected void setError(OAuth2Error error) { + this.error = error; + } + + protected boolean hasError() { + return getError() != null; + } + + protected boolean isRedirectOnError() { + return this.redirectOnError; + } + + protected void setRedirectOnError(boolean redirectOnError) { + this.redirectOnError = redirectOnError; + } + + protected abstract String resolveRedirectUri(); + } + + private static class UserConsentPage { + private static final MediaType TEXT_HTML_UTF8 = new MediaType("text", "html", StandardCharsets.UTF_8); + private static final String CONSENT_ACTION_PARAMETER_NAME = "consent_action"; + private static final String CONSENT_ACTION_APPROVE = "approve"; + private static final String CONSENT_ACTION_CANCEL = "cancel"; + + private static void displayConsent(HttpServletRequest request, HttpServletResponse response, + RegisteredClient registeredClient, OAuth2Authorization authorization) throws IOException { + + String consentPage = generateConsentPage(request, registeredClient, authorization); + response.setContentType(TEXT_HTML_UTF8.toString()); + response.setContentLength(consentPage.getBytes(StandardCharsets.UTF_8).length); + response.getWriter().write(consentPage); + } + + private static boolean isConsentApproved(HttpServletRequest request) { + return CONSENT_ACTION_APPROVE.equalsIgnoreCase(request.getParameter(CONSENT_ACTION_PARAMETER_NAME)); + } + + private static boolean isConsentCancelled(HttpServletRequest request) { + return CONSENT_ACTION_CANCEL.equalsIgnoreCase(request.getParameter(CONSENT_ACTION_PARAMETER_NAME)); + } + + private static String generateConsentPage(HttpServletRequest request, + RegisteredClient registeredClient, OAuth2Authorization authorization) { + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + String state = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.STATE); + + StringBuilder builder = new StringBuilder(); + + builder.append(""); + builder.append(""); + builder.append(""); + builder.append(" "); + builder.append(" "); + builder.append(" "); + builder.append(" Consent required"); + builder.append(""); + builder.append(""); + builder.append("
"); + builder.append("
"); + builder.append("

Consent required

"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("

" + registeredClient.getClientId() + " wants to access your account " + authorization.getPrincipalName() + "

"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("

The following permissions are requested by the above app.
Please review these and consent if you approve.

"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append(" "); + builder.append(" "); + + for (String scope : authorizationRequest.getScopes()) { + builder.append("
"); + builder.append(" "); + builder.append(" "); + builder.append("
"); + } + + builder.append("
"); + builder.append(" "); + builder.append("
"); + builder.append("
"); + builder.append(" "); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append("

Your consent to provide access is required.
If you do not approve, click Cancel, in which case no information will be shared with the app.

"); + builder.append("
"); + builder.append("
"); + builder.append("
"); + builder.append(""); + builder.append(""); + + return builder.toString(); + } } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index 6dabb77..9a05ba8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -30,6 +30,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * Tests for {@link InMemoryOAuth2AuthorizationService}. * * @author Krisztian Toth + * @author Joe Grandja */ public class InMemoryOAuth2AuthorizationServiceTests { private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); @@ -63,14 +64,53 @@ public class InMemoryOAuth2AuthorizationServiceTests { } @Test - public void findByTokenAndTokenTypeWhenTokenNullThenThrowIllegalArgumentException() { + public void removeWhenAuthorizationNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizationService.remove(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorization cannot be null"); + } + + @Test + public void removeWhenAuthorizationProvidedThenRemoved() { + OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .principalName(PRINCIPAL_NAME) + .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .build(); + + this.authorizationService.save(expectedAuthorization); + OAuth2Authorization authorization = this.authorizationService.findByToken( + AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); + assertThat(authorization).isEqualTo(expectedAuthorization); + + this.authorizationService.remove(expectedAuthorization); + authorization = this.authorizationService.findByToken( + AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); + assertThat(authorization).isNull(); + } + + @Test + public void findByTokenWhenTokenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizationService.findByToken(null, TokenType.AUTHORIZATION_CODE)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("token cannot be empty"); } @Test - public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() { + public void findByTokenWhenTokenTypeStateThenFound() { + String state = "state"; + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .principalName(PRINCIPAL_NAME) + .attribute(OAuth2AuthorizationAttributeNames.STATE, state) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + state, new TokenType(OAuth2AuthorizationAttributeNames.STATE)); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) @@ -83,7 +123,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { } @Test - public void findByTokenAndTokenTypeWhenTokenTypeAccessTokenThenFound() { + public void findByTokenWhenTokenTypeAccessTokenThenFound() { OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) @@ -99,7 +139,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { } @Test - public void findByTokenAndTokenTypeWhenTokenDoesNotExistThenNull() { + public void findByTokenWhenTokenDoesNotExistThenNull() { OAuth2Authorization result = this.authorizationService.findByToken( "access-token", TokenType.ACCESS_TOKEN); assertThat(result).isNull(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 70a0305..1b623cb 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -46,6 +46,7 @@ public class TestOAuth2Authorizations { .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) .redirectUri(registeredClient.getRedirectUris().iterator().next()) + .scopes(registeredClient.getScopes()) .additionalParameters(authorizationRequestAdditionalParameters) .state("state") .build(); @@ -53,6 +54,7 @@ public class TestOAuth2Authorizations { .principalName("principal") .accessToken(accessToken) .attribute(OAuth2AuthorizationAttributeNames.CODE, "code") - .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest); + .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) + .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 4c6fce9..28f1ebe 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -22,9 +22,11 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.jose.JoseHeaderNames; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; @@ -38,6 +40,7 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -187,6 +190,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + ArgumentCaptor jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class); + verify(this.jwtEncoder).encode(any(), jwtClaimsSetCaptor.capture()); + JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue(); + + Set scopes = jwtClaimsSet.getClaim(OAuth2ParameterNames.SCOPE); + assertThat(scopes).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES)); + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ClientSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ClientSettingsTests.java index f61cad0..f4f1474 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ClientSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ClientSettingsTests.java @@ -30,8 +30,9 @@ public class ClientSettingsTests { @Test public void constructorWhenDefaultThenDefaultsAreSet() { ClientSettings clientSettings = new ClientSettings(); - assertThat(clientSettings.settings()).hasSize(1); + assertThat(clientSettings.settings()).hasSize(2); assertThat(clientSettings.requireProofKey()).isFalse(); + assertThat(clientSettings.requireUserConsent()).isFalse(); } @Test @@ -46,4 +47,10 @@ public class ClientSettingsTests { ClientSettings clientSettings = new ClientSettings().requireProofKey(true); assertThat(clientSettings.requireProofKey()).isTrue(); } + + @Test + public void requireUserConsentWhenTrueThenSet() { + ClientSettings clientSettings = new ClientSettings().requireUserConsent(true); + assertThat(clientSettings.requireUserConsent()).isTrue(); + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 71dfd8d..dfa726a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -20,6 +20,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -34,6 +35,8 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -43,6 +46,7 @@ import org.springframework.util.StringUtils; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.nio.charset.StandardCharsets; import java.util.Set; import java.util.function.Consumer; @@ -122,19 +126,6 @@ public class OAuth2AuthorizationEndpointFilterTests { verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } - @Test - public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception { - String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; - MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); - request.setServletPath(requestUri); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - @Test public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception { doFilterWhenAuthorizationRequestInvalidParameterThenError( @@ -222,7 +213,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, @@ -236,7 +227,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, @@ -250,7 +241,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, @@ -258,6 +249,23 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); } + @Test + public void doFilterWhenAuthorizationRequestInvalidScopeThenInvalidScopeError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( + registeredClient, + OAuth2ParameterNames.SCOPE, + OAuth2ErrorCodes.INVALID_SCOPE, + DEFAULT_ERROR_URI, + request -> { + String scope = request.getParameter(OAuth2ParameterNames.SCOPE); + request.setParameter(OAuth2ParameterNames.SCOPE, scope + " invalid-scope"); + }); + } + @Test public void doFilterWhenPkceRequiredAndMissingCodeChallengeThenInvalidRequestError() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient() @@ -266,7 +274,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE, OAuth2ErrorCodes.INVALID_REQUEST, @@ -285,7 +293,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE, OAuth2ErrorCodes.INVALID_REQUEST, @@ -302,7 +310,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE, OAuth2ErrorCodes.INVALID_REQUEST, @@ -321,7 +329,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE_METHOD, OAuth2ErrorCodes.INVALID_REQUEST, @@ -338,7 +346,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE_METHOD, OAuth2ErrorCodes.INVALID_REQUEST, @@ -357,7 +365,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE_METHOD, OAuth2ErrorCodes.INVALID_REQUEST, @@ -374,7 +382,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + doFilterWhenAuthorizationRequestInvalidParameterThenRedirect( registeredClient, PkceParameterNames.CODE_CHALLENGE_METHOD, OAuth2ErrorCodes.INVALID_REQUEST, @@ -432,6 +440,10 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); assertThat(authorizationRequest).isNotNull(); + + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); + assertThat(authorizedScopes).isEqualTo(authorizationRequest.getScopes()); + assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize"); assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); @@ -467,9 +479,19 @@ public class OAuth2AuthorizationEndpointFilterTests { verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization authorization = authorizationCaptor.getValue(); - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); - assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE); + assertThat(code).isNotNull(); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + assertThat(authorizationRequest).isNotNull(); + + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); + assertThat(authorizedScopes).isEqualTo(authorizationRequest.getScopes()); + + assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); assertThat(authorizationRequest.getAdditionalParameters()) .size() .isEqualTo(2) @@ -478,6 +500,271 @@ public class OAuth2AuthorizationEndpointFilterTests { .containsEntry(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); } + @Test + public void doFilterWhenUserConsentRequiredAndAuthorizationRequestValidThenUserConsentResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(new ClientSettings().requireUserConsent(true)) + .build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + assertThat(response.getContentType().equals(new MediaType("text", "html", StandardCharsets.UTF_8).toString())); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization authorization = authorizationCaptor.getValue(); + assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + + String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE); + assertThat(state).isNotNull(); + + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); + assertThat(authorizedScopes).isNull(); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + assertThat(authorizationRequest).isNotNull(); + assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize"); + assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); + assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next()); + assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes()); + assertThat(authorizationRequest.getState()).isEqualTo("state"); + assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); + } + + @Test + public void doFilterWhenUserConsentRequestMissingStateThenInvalidRequestError() throws Exception { + doFilterWhenUserConsentRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.STATE)); + } + + @Test + public void doFilterWhenUserConsentRequestMultipleStateThenInvalidRequestError() throws Exception { + doFilterWhenUserConsentRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.STATE, "state-2")); + } + + @Test + public void doFilterWhenUserConsentRequestInvalidStateThenInvalidRequestError() throws Exception { + doFilterWhenUserConsentRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.setParameter(OAuth2ParameterNames.STATE, "invalid")); + } + + @Test + public void doFilterWhenUserConsentRequestNotAuthenticatedThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + this.authentication.setAuthenticated(false); + + doFilterWhenUserConsentRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> {}); + } + + @Test + public void doFilterWhenUserConsentRequestInvalidPrincipalThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + this.authentication = new TestingAuthenticationToken("other-principal", "password"); + this.authentication.setAuthenticated(true); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(this.authentication); + SecurityContextHolder.setContext(securityContext); + + doFilterWhenUserConsentRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> {}); + } + + @Test + public void doFilterWhenUserConsentRequestMissingClientIdThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + doFilterWhenUserConsentRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID)); + } + + @Test + public void doFilterWhenUserConsentRequestMultipleClientIdThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + doFilterWhenUserConsentRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2")); + } + + @Test + public void doFilterWhenUserConsentRequestInvalidClientIdThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + doFilterWhenUserConsentRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid")); + } + + @Test + public void doFilterWhenUserConsentRequestDoesNotMatchClientThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + RegisteredClient otherRegisteredClient = TestRegisteredClients.registeredClient2().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(otherRegisteredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + doFilterWhenUserConsentRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> {}); + } + + @Test + public void doFilterWhenUserConsentRequestInvalidScopeThenInvalidScopeError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + doFilterWhenUserConsentRequestInvalidParameterThenRedirect( + registeredClient, + OAuth2ParameterNames.SCOPE, + OAuth2ErrorCodes.INVALID_SCOPE, + DEFAULT_ERROR_URI, + request -> { + request.addParameter(OAuth2ParameterNames.SCOPE, "invalid-scope"); + }); + } + + @Test + public void doFilterWhenUserConsentRequestNotApprovedThenAccessDeniedError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + doFilterWhenUserConsentRequestInvalidParameterThenRedirect( + registeredClient, + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.ACCESS_DENIED, + DEFAULT_ERROR_URI, + request -> request.removeParameter("consent_action")); + + verify(this.authorizationService).remove(eq(authorization)); + } + + @Test + public void doFilterWhenUserConsentRequestApprovedThenAuthorizationResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.authentication.getName()) + .build(); + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + .thenReturn(authorization); + + MockHttpServletRequest request = createUserConsentRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(updatedAuthorization.getAccessToken()).isNotNull(); + assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); + assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)).isNotNull(); + assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)) + .isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)); + assertThat(updatedAuthorization.>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES)) + .isEqualTo(registeredClient.getScopes()); + } + private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, String parameterName, String errorCode) throws Exception { doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {}); @@ -486,7 +773,36 @@ public class OAuth2AuthorizationEndpointFilterTests { private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, String parameterName, String errorCode, Consumer requestConsumer) throws Exception { - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + doFilterWhenRequestInvalidParameterThenError(createAuthorizationRequest(registeredClient), + parameterName, errorCode, requestConsumer); + } + + private void doFilterWhenAuthorizationRequestInvalidParameterThenRedirect(RegisteredClient registeredClient, + String parameterName, String errorCode, String errorUri, + Consumer requestConsumer) throws Exception { + + doFilterWhenRequestInvalidParameterThenRedirect(createAuthorizationRequest(registeredClient), + parameterName, errorCode, errorUri, requestConsumer); + } + + private void doFilterWhenUserConsentRequestInvalidParameterThenError(RegisteredClient registeredClient, + String parameterName, String errorCode, Consumer requestConsumer) throws Exception { + + doFilterWhenRequestInvalidParameterThenError(createUserConsentRequest(registeredClient), + parameterName, errorCode, requestConsumer); + } + + private void doFilterWhenUserConsentRequestInvalidParameterThenRedirect(RegisteredClient registeredClient, + String parameterName, String errorCode, String errorUri, + Consumer requestConsumer) throws Exception { + + doFilterWhenRequestInvalidParameterThenRedirect(createUserConsentRequest(registeredClient), + parameterName, errorCode, errorUri, requestConsumer); + } + + private void doFilterWhenRequestInvalidParameterThenError(MockHttpServletRequest request, + String parameterName, String errorCode, Consumer requestConsumer) throws Exception { + requestConsumer.accept(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -499,11 +815,10 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName); } - private void doFilterWhenAuthorizationRequestInvalidParameterThenRedirected(RegisteredClient registeredClient, + private void doFilterWhenRequestInvalidParameterThenRedirect(MockHttpServletRequest request, String parameterName, String errorCode, String errorUri, Consumer requestConsumer) throws Exception { - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); requestConsumer.accept(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -541,4 +856,19 @@ public class OAuth2AuthorizationEndpointFilterTests { request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); } + + private static MockHttpServletRequest createUserConsentRequest(RegisteredClient registeredClient) { + String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + for (String scope : registeredClient.getScopes()) { + request.addParameter(OAuth2ParameterNames.SCOPE, scope); + } + request.addParameter("consent_action", "approve"); + + return request; + } } diff --git a/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java b/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java index fde42af..226e9ce 100644 --- a/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java +++ b/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java @@ -29,6 +29,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import java.util.UUID; @@ -53,6 +54,7 @@ public class AuthorizationServerConfig { .redirectUri("http://localhost:8080/authorized") .scope("message.read") .scope("message.write") + .clientSettings(new ClientSettings().requireUserConsent(true)) .build(); return new InMemoryRegisteredClientRepository(registeredClient); } diff --git a/samples/boot/oauth2-integration/client/src/main/java/sample/web/AuthorizationController.java b/samples/boot/oauth2-integration/client/src/main/java/sample/web/AuthorizationController.java index 3cccbc3..7a5d39b 100644 --- a/samples/boot/oauth2-integration/client/src/main/java/sample/web/AuthorizationController.java +++ b/samples/boot/oauth2-integration/client/src/main/java/sample/web/AuthorizationController.java @@ -18,11 +18,16 @@ package sample.web; import org.springframework.beans.factory.annotation.Value; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.stereotype.Controller; import org.springframework.ui.Model; +import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.reactive.function.client.WebClient; +import javax.servlet.http.HttpServletRequest; + import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; @@ -58,6 +63,22 @@ public class AuthorizationController { return "index"; } + // '/authorized' is the registered 'redirect_uri' for authorization_code + @GetMapping(value = "/authorized", params = OAuth2ParameterNames.ERROR) + public String authorizationFailed(Model model, HttpServletRequest request) { + String errorCode = request.getParameter(OAuth2ParameterNames.ERROR); + if (StringUtils.hasText(errorCode)) { + model.addAttribute("error", + new OAuth2Error( + errorCode, + request.getParameter(OAuth2ParameterNames.ERROR_DESCRIPTION), + request.getParameter(OAuth2ParameterNames.ERROR_URI)) + ); + } + + return "index"; + } + @GetMapping(value = "/authorize", params = "grant_type=client_credentials") public String clientCredentialsGrant(Model model) { diff --git a/samples/boot/oauth2-integration/client/src/main/resources/templates/index.html b/samples/boot/oauth2-integration/client/src/main/resources/templates/index.html index 6c8520c..edf9d32 100644 --- a/samples/boot/oauth2-integration/client/src/main/resources/templates/index.html +++ b/samples/boot/oauth2-integration/client/src/main/resources/templates/index.html @@ -19,6 +19,10 @@
+

Authorize the client using grant_type: