From ab090445b33155aeced14dc92fecca51375ca4ea Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Mon, 22 Jun 2020 21:35:01 +0200 Subject: [PATCH] Implement Proof Key for Code Exchange (PKCE) RFC 7636 See https://tools.ietf.org/html/rfc7636 Closes gh-45 --- .../OAuth2AccessTokenAuthenticationToken.java | 4 +- ...thorizationCodeAuthenticationProvider.java | 88 ++++- ...2AuthorizationCodeAuthenticationToken.java | 42 ++- .../client/RegisteredClient.java | 1 - .../OAuth2AuthorizationEndpointFilter.java | 37 +- .../web/OAuth2TokenEndpointFilter.java | 39 +- .../OAuth2AuthorizationCodeGrantTests.java | 2 + .../server/authorization/OAuth2PkceTests.java | 178 +++++++++ .../TestOAuth2Authorizations.java | 7 + ...zationCodeAuthenticationProviderTests.java | 337 +++++++++++++++++- ...orizationCodeAuthenticationTokenTests.java | 47 ++- ...Auth2AuthorizationEndpointFilterTests.java | 232 ++++++++++++ .../web/OAuth2TokenEndpointFilterTests.java | 49 +++ 13 files changed, 1008 insertions(+), 55 deletions(-) create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java index dffe329..e102627 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java @@ -38,8 +38,8 @@ import java.util.Collections; */ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID; - private final RegisteredClient registeredClient; - private final Authentication clientPrincipal; + private RegisteredClient registeredClient; + private Authentication clientPrincipal; private final OAuth2AccessToken accessToken; /** diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 6646b17..6bb7085 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.jose.JoseHeader; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -33,15 +34,20 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.net.MalformedURLException; import java.net.URI; import java.net.URL; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Base64; import java.util.Collections; /** @@ -85,29 +91,30 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica (OAuth2AuthorizationCodeAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientPrincipal = null; + RegisteredClient registeredClient = null; if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) { clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal(); - } - if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) { + registeredClient = clientPrincipal.getRegisteredClient(); + } else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) { + // When the principal is a string, it is the clientId, REQUIRED for public clients + String clientId = authorizationCodeAuthentication.getClientId(); + registeredClient = this.registeredClientRepository.findByClientId(clientId); + if (registeredClient == null) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + } else { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } - // TODO Authenticate public client - // A client MAY use the "client_id" request parameter to identify itself - // when sending requests to the token endpoint. - // In the "authorization_code" "grant_type" request to the token endpoint, - // an unauthenticated client MUST send its "client_id" to prevent itself - // from inadvertently accepting a code intended for a client with a different "client_id". - // This protects the client from substitution of the authentication code. + if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } OAuth2Authorization authorization = this.authorizationService.findByToken( authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE); if (authorization == null) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); - } OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); @@ -116,6 +123,35 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } + if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + + String codeChallenge; + Object codeChallengeParameter = authorizationRequest + .getAdditionalParameters() + .get(PkceParameterNames.CODE_CHALLENGE); + + if (codeChallengeParameter != null) { + codeChallenge = (String) codeChallengeParameter; + + String codeChallengeMethod = (String) authorizationRequest + .getAdditionalParameters() + .get(PkceParameterNames.CODE_CHALLENGE_METHOD); + + String codeVerifier = (String) authorizationCodeAuthentication + .getAdditionalParameters() + .get(PkceParameterNames.CODE_VERIFIER); + + if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + } else if (registeredClient.getClientSettings().requireProofKey()){ + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); // TODO Allow configuration for issuer claim @@ -130,7 +166,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() .issuer(issuer) .subject(authorization.getPrincipalName()) - .audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId())) + .audience(Collections.singletonList(registeredClient.getClientId())) .issuedAt(issuedAt) .expiresAt(expiresAt) .notBefore(issuedAt) @@ -148,8 +184,30 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica .build(); this.authorizationService.save(authorization); - return new OAuth2AccessTokenAuthenticationToken( - clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken); + return clientPrincipal != null ? + new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) : + new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken); + } + + private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) { + if (codeVerifier == null) { + return false; + } else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) { + return codeVerifier.equals(codeChallenge); + } else if ("S256".equals(codeChallengeMethod)) { + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII)); + String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + return codeChallenge.equals(encodedVerifier); + } catch (NoSuchAlgorithmException e) { + // It is unlikely that SHA-256 is not available on the server. If it is not available, + // there will likely be bigger issues as well. We default to SERVER_ERROR. + } + } + + // Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR)); } @Override diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java index fab21e8..e2ab8d8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -22,6 +22,7 @@ import org.springframework.security.core.SpringSecurityCoreVersion2; import org.springframework.util.Assert; import java.util.Collections; +import java.util.Map; /** * An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant. @@ -35,10 +36,11 @@ import java.util.Collections; */ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID; - private String code; + private final String code; private Authentication clientPrincipal; - private String clientId; - private String redirectUri; + private final String clientId; + private final String redirectUri; + private final Map additionalParameters; /** * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters. @@ -46,15 +48,24 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti * @param code the authorization code * @param clientPrincipal the authenticated client principal * @param redirectUri the redirect uri + * @param additionalParameters the additional parameters */ public OAuth2AuthorizationCodeAuthenticationToken(String code, - Authentication clientPrincipal, @Nullable String redirectUri) { + Authentication clientPrincipal, @Nullable String redirectUri, + Map additionalParameters) { super(Collections.emptyList()); Assert.hasText(code, "code cannot be empty"); Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); this.code = code; this.clientPrincipal = clientPrincipal; this.redirectUri = redirectUri; + this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap()); + + if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) { + this.clientId = (String) this.clientPrincipal.getPrincipal(); + } else { + this.clientId = null; + } } /** @@ -63,15 +74,18 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti * @param code the authorization code * @param clientId the client identifier * @param redirectUri the redirect uri + * @param additionalParameters the additional parameters */ public OAuth2AuthorizationCodeAuthenticationToken(String code, - String clientId, @Nullable String redirectUri) { + String clientId, @Nullable String redirectUri, + Map additionalParameters) { super(Collections.emptyList()); Assert.hasText(code, "code cannot be empty"); Assert.hasText(clientId, "clientId cannot be empty"); this.code = code; this.clientId = clientId; this.redirectUri = redirectUri; + this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap()); } @Override @@ -101,4 +115,22 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti public @Nullable String getRedirectUri() { return this.redirectUri; } + + /** + * Returns the additional parameters + * + * @return the additional parameters + */ + public Map getAdditionalParameters() { + return this.additionalParameters; + } + + /** + * Returns the client id + * + * @return the client id + */ + public @Nullable String getClientId() { + return this.clientId; + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java index a03930f..a94da6e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java @@ -367,7 +367,6 @@ public class RegisteredClient implements Serializable { Assert.hasText(this.clientId, "clientId cannot be empty"); Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty"); if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { - Assert.hasText(this.clientSecret, "clientSecret cannot be empty"); Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty"); } if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 57e8f12..ab6692c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -78,6 +79,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { private final RequestMatcher authorizationEndpointMatcher; private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1"; /** * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters. @@ -174,6 +176,34 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { return; } + // code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE) + String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE); + if (StringUtils.hasText(codeChallenge)) { + if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI); + sendErrorResponse(request, response, error, stateParameter, redirectUri); + return; + } + + if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null && + parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI); + sendErrorResponse(request, response, error, stateParameter, redirectUri); + return; + } + + String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD); + if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI); + sendErrorResponse(request, response, error, stateParameter, redirectUri); + return; + } + } else if (registeredClient.getClientSettings().requireProofKey()) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI); + sendErrorResponse(request, response, error, stateParameter, redirectUri); + return; + } + // --------------- // The request is valid - ensure the resource owner is authenticated // --------------- @@ -245,8 +275,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { } private static OAuth2Error createError(String errorCode, String parameterName) { - return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, - "https://tools.ietf.org/html/rfc6749#section-4.1.2.1"); + return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1"); + } + + private static OAuth2Error createError(String errorCode, String parameterName, String errorUri) { + return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri); } private static boolean isPrincipalAuthenticated(Authentication principal) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 3680d66..84d9bca 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -30,11 +30,13 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -54,6 +56,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, @@ -198,14 +201,22 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); // client_id (REQUIRED) - String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); - Authentication clientPrincipal = null; - if (StringUtils.hasText(clientId)) { - if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + String clientId = null; + if (clientPrincipal == null || + !OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) { + clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); + if (!StringUtils.hasText(clientId) || + parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); } - } else { - clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + + // code_verifier (REQUIRED for public clients) + String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER); + if (!StringUtils.hasText(codeVerifier) || + parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER); + } } // code (REQUIRED) @@ -223,9 +234,19 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); } - return clientPrincipal != null ? - new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) : - new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri); + Map additionalParameters = parameters + .entrySet() + .stream() + .filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) && + !e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) && + !e.getKey().equals(OAuth2ParameterNames.CODE) && + !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI)) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0))); + + + return clientId != null ? + new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri, additionalParameters) : + new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 5a655f1..c3e3753 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -31,6 +31,7 @@ import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -169,6 +170,7 @@ public class OAuth2AuthorizationCodeGrantTests { parameters.set(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); parameters.set(OAuth2ParameterNames.STATE, "state"); + parameters.set(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); return parameters; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java new file mode 100644 index 0000000..21da3b0 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.security.config.annotation.web.WebSecurityConfigurer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.builders.WebSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; +import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; +import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +public class OAuth2PkceTests { + private static RegisteredClientRepository registeredClientRepository; + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Autowired + private MockMvc mvc; + + @BeforeClass + public static void init() { + registeredClientRepository = mock(RegisteredClientRepository.class); + } + + @Before + public void setup() { + reset(registeredClientRepository); + } + + @Test + public void requestWhenTokenRequestNotAuthenticatedAndPkceParamatersProvidedThenRedirectToClient() throws Exception { + // See RFC 7636: Appendix B. Example for the S256 code_challenge_method + // https://tools.ietf.org/html/rfc7636#appendix-B + final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + ClientSettings settings = new ClientSettings(); + RegisteredClient registeredClient = TestRegisteredClients + .registeredClient() + .clientSettings(settings.requireProofKey(true)) + .build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .params(getAuthorizationRequestParameters(registeredClient)) + .param("code_challenge", S256_CODE_CHALLENGE) + .param("code_challenge_method", "S256") + .with(user("user"))) + .andExpect(status().is3xxRedirection()) + .andReturn(); + + assertThat(mvcResult.getResponse().getRedirectedUrl()) + .doesNotContain("error=") + .contains("code="); + + String authorizationCode = UriUtils.decode(UriComponentsBuilder.fromHttpUrl(mvcResult.getResponse().getRedirectedUrl()) + .build() + .getQueryParams() + .getFirst("code"), "utf-8"); + + this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorizationCode)) + .param("code_verifier", S256_CODE_VERIFIER)) + .andExpect(status().is2xxSuccessful()) + .andExpect(jsonPath("$.access_token").isNotEmpty()); + } + + private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); + parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + parameters.set(OAuth2ParameterNames.STATE, "state"); + return parameters; + } + + private static MultiValueMap getTokenRequestParameters(RegisteredClient registeredClient, + String authorizationCode) { + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + parameters.set(OAuth2ParameterNames.CODE, authorizationCode); + parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); + parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + return parameters; + } + + @EnableWebSecurity + static class AuthorizationServerConfiguration { + + @Bean + RegisteredClientRepository registeredClientRepository() { + return registeredClientRepository; + } + + @Bean + OAuth2AuthorizationService authorizationService() { + return new InMemoryOAuth2AuthorizationService(); + } + + @Bean + KeyManager keyManager() { + return new StaticKeyGeneratingKeyManager(); + } + + @Bean + WebSecurityConfigurer defaultOAuth2AuthorizationServerSecurity() { + return new WebSecurityConfigurerAdapter() { + @Override + public void configure(HttpSecurity http) throws Exception { + http + .authorizeRequests() + .anyRequest() + .permitAll() + .and() + .csrf() + .disable() + .apply(new OAuth2AuthorizationServerConfigurer<>()); + } + }; + } + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 00260bf..dd66ab3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -21,6 +21,8 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import java.time.Instant; +import java.util.Collections; +import java.util.Map; /** * @author Joe Grandja @@ -32,12 +34,17 @@ public class TestOAuth2Authorizations { } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient) { + return authorization(registeredClient, Collections.emptyMap()); + } + + public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map additionalParameters) { OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) .redirectUri(registeredClient.getRedirectUris().iterator().next()) + .additionalParameters(additionalParameters) .state("state") .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 52c5592..f38950b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -22,6 +22,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.jose.JoseHeaderNames; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; @@ -35,6 +36,11 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -53,7 +59,19 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class OAuth2AuthorizationCodeAuthenticationProviderTests { + private final String PLAIN_CODE_CHALLENGE = "pkce-key"; + private final String PLAIN_CODE_VERIFIER = PLAIN_CODE_CHALLENGE; + + // See RFC 7636: Appendix B. Example for the S256 code_challenge_method + // https://tools.ietf.org/html/rfc7636#appendix-B + private final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + private final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + + private final String AUTHORIZATION_CODE = "code"; + private RegisteredClient registeredClient; + private RegisteredClient otherRegisteredClient; + private RegisteredClient registeredClientRequiresProofKey; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; @@ -62,7 +80,17 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Before public void setUp() { this.registeredClient = TestRegisteredClients.registeredClient().build(); - this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient); + this.otherRegisteredClient = TestRegisteredClients.registeredClient2().build(); + this.registeredClientRequiresProofKey = TestRegisteredClients.registeredClient() + .id("registration-3") + .clientId("client-3") + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); + this.registeredClientRepository = new InMemoryRegisteredClientRepository( + this.registeredClient, + this.otherRegisteredClient, + this.registeredClientRequiresProofKey + ); this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( @@ -100,7 +128,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -113,7 +141,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -125,7 +153,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -142,7 +170,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( TestRegisteredClients.registeredClient2().build()); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -160,7 +188,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid"); + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid", null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -178,16 +206,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri()); + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null); - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); - Jwt jwt = Jwt.withTokenValue("token") - .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .build(); - when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -201,4 +222,290 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(updatedAuthorization.getAccessToken()).isNotNull(); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); } + + @Test + public void authenticateWhenRequireProofKeyAndMissingPkceCodeChallengeInAuthorizationRequestThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClientRequiresProofKey).build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + registeredClientRequiresProofKey.getClientId(), + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) + ); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenRequireProofKeyAndUnsupportedCodeChallengeMethodInAuthorizationRequestThenThrowOAuth2AuthenticationException() { + Map pkceParameters = new HashMap<>(); + pkceParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); + // This should never happen: the Authorization endpoint should not allow it + pkceParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method"); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClientRequiresProofKey, pkceParameters) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + registeredClientRequiresProofKey.getClientId(), + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) + ); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + } + + @Test + public void authenticateWhenPublicClientAndClientIdNotMatchingThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + otherRegisteredClient.getClientId(), + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) + ); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenPublicClientAndUnknownClientIdThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + "invalid-client-id", + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_CHALLENGE) + ); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenPublicClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + authorizationRequest.getClientId(), + authorizationRequest.getRedirectUri(), + null + ); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenPrivateClientAndRequireProofKeyAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + clientPrincipal, + authorizationRequest.getRedirectUri(), + null + ); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenPublicClientAndPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier"); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenPublicClientAndS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersS256()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier"); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenPublicClientAndPlainMethodAndValidCodeVerifierThenReturnAccessToken() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); + assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + assertThat(updatedAuthorization.getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); + } + + @Test + public void authenticateWhenPublicClientAndNoMethodThenDefaultToPlainAndReturnAccessToken() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, Collections.singletonMap(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE)) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); + assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + assertThat(updatedAuthorization.getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); + } + + + @Test + public void authenticateWhenPublicClientAndS256MethodAndValidCodeVerifierThenReturnAccessToken() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, getPkceAuthorizationParametersS256()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(S256_CODE_VERIFIER); + + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); + assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + assertThat(updatedAuthorization.getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); + } + + private Map getPkceAuthorizationParametersPlain() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); + return additionalParameters; + } + + private Map getPkceAuthorizationParametersS256() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); + return additionalParameters; + } + + private OAuth2AuthorizationCodeAuthenticationToken makeAuthorizationCodeAuthenticationToken(String codeVerifier) { + return new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + registeredClient.getClientId(), + registeredClient.getRedirectUris().iterator().next(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, codeVerifier) + ); + } + + private Jwt createJwt() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); + return Jwt.withTokenValue("token") + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .build(); + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java index e2977a3..a72840c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java @@ -19,6 +19,9 @@ import org.junit.Test; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import java.util.Collections; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -29,28 +32,30 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; */ public class OAuth2AuthorizationCodeAuthenticationTokenTests { private String code = "code"; + private String clientPrincipalClientId = "clientPrincipal.clientId"; private OAuth2ClientAuthenticationToken clientPrincipal = - new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); + new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().clientId(clientPrincipalClientId).build()); private String clientId = "clientId"; private String redirectUri = "redirectUri"; + private Map additonalParams = Collections.singletonMap("some_key", "some_value"); @Test public void constructorWhenCodeNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("code cannot be empty"); } @Test public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientPrincipal cannot be null"); } @Test public void constructorWhenClientIdNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientId cannot be empty"); } @@ -58,20 +63,50 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { @Test public void constructorWhenClientPrincipalProvidedThenCreated() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientPrincipal, this.redirectUri); + this.code, this.clientPrincipal, this.redirectUri, this.additonalParams); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCode()).isEqualTo(this.code); assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams); } @Test public void constructorWhenClientIdProvidedThenCreated() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientId, this.redirectUri); + this.code, this.clientId, this.redirectUri, this.additonalParams); assertThat(authentication.getPrincipal()).isEqualTo(this.clientId); assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCode()).isEqualTo(this.code); assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams); + } + + @Test + public void getAdditionalParamsIsImmutableMap() { + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.code, this.clientId, this.redirectUri, this.additonalParams); + assertThatThrownBy(() -> authentication.getAdditionalParameters().put("another_key", 1)) + .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> authentication.getAdditionalParameters().remove("some_key")) + .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> authentication.getAdditionalParameters().clear()) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + public void getClientIdFromClientId() { + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.code, this.clientId, this.redirectUri, this.additonalParams); + + assertThat(authentication.getClientId()).isEqualTo(this.clientId); + } + + @Test + public void getClientIdFromOAuth2ClientAuthenticationTokenPrincipal() { + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.code, this.clientPrincipal, this.redirectUri, this.additonalParams); + + assertThat(authentication.getClientId()).isEqualTo(this.clientPrincipalClientId); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 3f4a2a6..3c5c03e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -30,12 +30,14 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import org.springframework.util.StringUtils; import javax.servlet.FilterChain; @@ -280,6 +282,176 @@ public class OAuth2AuthorizationEndpointFilterTests { "state=state"); } + @Test + public void doFilterWhenProofKeyRequiredAndMissingPkceCodeChallengeThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.removeParameter(PkceParameterNames.CODE_CHALLENGE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + } + + @Test + public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + } + + @Test + public void doFilterWhenProofKeyNotRequiredClientAndMultiplePkceCodeChallengeThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientDoNotRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + + } + + @Test + public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeMethodThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + } + + @Test + public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAnMultiplePkceCodeChallengeMethodThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientDoNotRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + } + + @Test + public void doFilterWhenProofKeyRequiredAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + } + + @Test + public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception { + RegisteredClient registeredClient = createClientDoNotRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + + "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + + "state=state"); + + } + @Test public void doFilterWhenAuthorizationRequestValidNotAuthenticatedThenContinueChainToCommenceAuthentication() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -337,6 +509,40 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); } + @Test + public void doFilterWhenProofKeyRequiredAndAuthorizationRequestValidThenAuthorizationResponse() throws Exception { + RegisteredClient registeredClient = createClientRequireProofKey(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request = addPkceParameters(request); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization authorization = authorizationCaptor.getValue(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); + + assertThat(authorizationRequest.getAdditionalParameters()) + .size() + .isEqualTo(2) + .returnToMap() + .containsEntry(PkceParameterNames.CODE_CHALLENGE, "code-challenge") + .containsEntry(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + } + private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, String parameterName, String errorCode) throws Exception { doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {}); @@ -374,4 +580,30 @@ public class OAuth2AuthorizationEndpointFilterTests { return request; } + + private static MockHttpServletRequest addPkceParameters(MockHttpServletRequest request) { + request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); + request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + + return request; + } + + private RegisteredClient createClientRequireProofKey() { + ClientSettings clientSettings = new ClientSettings(); + clientSettings.requireProofKey(true); + + return TestRegisteredClients.registeredClient() + .clientSettings(clientSettings) + .build(); + } + + private RegisteredClient createClientDoNotRequireProofKey() { + ClientSettings clientSettings = new ClientSettings(); + clientSettings.requireProofKey(false); + + return TestRegisteredClients.registeredClient() + .clientSettings(clientSettings) + .build(); + } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 0ecd147..67b8bef 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -178,6 +179,12 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.removeParameter(OAuth2ParameterNames.CODE); @@ -188,6 +195,12 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.addParameter(OAuth2ParameterNames.CODE, "code-2"); @@ -198,6 +211,12 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); @@ -206,6 +225,34 @@ public class OAuth2TokenEndpointFilterTests { OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, request); } + @Test + public void doFilterWhenTokenRequestNotAuthenticatedAndMissingCodeVerifierThenInvalidRequestError() throws Exception { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com"); + + doFilterWhenTokenRequestInvalidParameterThenError( + PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request); + } + + @Test + public void doFilterWhenTokenRequestNotAuthenticatedAndMultipleCodeVerifierThenInvalidRequestError() throws Exception { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(PkceParameterNames.CODE_VERIFIER, "one-verifier"); + request.addParameter(PkceParameterNames.CODE_VERIFIER, "two-verifiers"); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com"); + + doFilterWhenTokenRequestInvalidParameterThenError( + PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request); + } + @Test public void doFilterWhenAuthorizationCodeTokenRequestValidThenAccessTokenResponse() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -359,6 +406,8 @@ public class OAuth2TokenEndpointFilterTests { request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); + // The client does not need to send the client ID param, but we are resilient in case they do + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); return request; }