diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java index e102627..dffe329 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java @@ -38,8 +38,8 @@ import java.util.Collections; */ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID; - private RegisteredClient registeredClient; - private Authentication clientPrincipal; + private final RegisteredClient registeredClient; + private final Authentication clientPrincipal; private final OAuth2AccessToken accessToken; /** diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 6bb7085..6e4f924 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -23,8 +23,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.jose.JoseHeader; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; @@ -39,12 +39,12 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; import java.net.MalformedURLException; import java.net.URI; import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Base64; @@ -54,6 +54,7 @@ import java.util.Collections; * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant. * * @author Joe Grandja + * @author Daniel Garnier-Moiroux * @since 0.0.1 * @see OAuth2AuthorizationCodeAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken @@ -91,12 +92,14 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica (OAuth2AuthorizationCodeAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientPrincipal = null; - RegisteredClient registeredClient = null; + RegisteredClient registeredClient; if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) { clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal(); + if (!clientPrincipal.isAuthenticated()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } registeredClient = clientPrincipal.getRegisteredClient(); } else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) { - // When the principal is a string, it is the clientId, REQUIRED for public clients String clientId = authorizationCodeAuthentication.getClientId(); registeredClient = this.registeredClientRepository.findByClientId(clientId); if (registeredClient == null) { @@ -106,10 +109,6 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } - if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); - } - OAuth2Authorization authorization = this.authorizationService.findByToken( authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE); if (authorization == null) { @@ -118,24 +117,21 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); - if (StringUtils.hasText(authorizationRequest.getRedirectUri()) && - !authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); - } if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } + if (StringUtils.hasText(authorizationRequest.getRedirectUri()) && + !authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } - String codeChallenge; - Object codeChallengeParameter = authorizationRequest + // Validate PKCE parameters + String codeChallenge = (String) authorizationRequest .getAdditionalParameters() .get(PkceParameterNames.CODE_CHALLENGE); - - if (codeChallengeParameter != null) { - codeChallenge = (String) codeChallengeParameter; - + if (StringUtils.hasText(codeChallenge)) { String codeChallengeMethod = (String) authorizationRequest .getAdditionalParameters() .get(PkceParameterNames.CODE_CHALLENGE_METHOD); @@ -147,11 +143,10 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - } else if (registeredClient.getClientSettings().requireProofKey()){ + } else if (registeredClient.getClientSettings().requireProofKey()) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); // TODO Allow configuration for issuer claim @@ -189,24 +184,22 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken); } - private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) { - if (codeVerifier == null) { + private static boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) { + if (!StringUtils.hasText(codeVerifier)) { return false; - } else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) { + } else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) { return codeVerifier.equals(codeChallenge); } else if ("S256".equals(codeChallengeMethod)) { try { MessageDigest md = MessageDigest.getInstance("SHA-256"); byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII)); String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest); - return codeChallenge.equals(encodedVerifier); - } catch (NoSuchAlgorithmException e) { + return encodedVerifier.equals(codeChallenge); + } catch (NoSuchAlgorithmException ex) { // It is unlikely that SHA-256 is not available on the server. If it is not available, // there will likely be bigger issues as well. We default to SERVER_ERROR. } } - - // Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR)); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java index e2ab8d8..f93ac28 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -29,6 +29,7 @@ import java.util.Map; * * @author Joe Grandja * @author Madhu Bhat + * @author Daniel Garnier-Moiroux * @since 0.0.1 * @see AbstractAuthenticationToken * @see OAuth2AuthorizationCodeAuthenticationProvider @@ -37,7 +38,7 @@ import java.util.Map; public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID; private final String code; - private Authentication clientPrincipal; + private final Authentication clientPrincipal; private final String clientId; private final String redirectUri; private final Map additionalParameters; @@ -50,22 +51,21 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti * @param redirectUri the redirect uri * @param additionalParameters the additional parameters */ - public OAuth2AuthorizationCodeAuthenticationToken(String code, - Authentication clientPrincipal, @Nullable String redirectUri, - Map additionalParameters) { + public OAuth2AuthorizationCodeAuthenticationToken(String code, Authentication clientPrincipal, + @Nullable String redirectUri, @Nullable Map additionalParameters) { super(Collections.emptyList()); Assert.hasText(code, "code cannot be empty"); Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); this.code = code; this.clientPrincipal = clientPrincipal; + this.clientId = OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass()) ? + (String) this.clientPrincipal.getPrincipal() : + null; this.redirectUri = redirectUri; - this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap()); - - if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) { - this.clientId = (String) this.clientPrincipal.getPrincipal(); - } else { - this.clientId = null; - } + this.additionalParameters = Collections.unmodifiableMap( + additionalParameters != null ? + additionalParameters : + Collections.emptyMap()); } /** @@ -76,16 +76,19 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti * @param redirectUri the redirect uri * @param additionalParameters the additional parameters */ - public OAuth2AuthorizationCodeAuthenticationToken(String code, - String clientId, @Nullable String redirectUri, - Map additionalParameters) { + public OAuth2AuthorizationCodeAuthenticationToken(String code, String clientId, + @Nullable String redirectUri, @Nullable Map additionalParameters) { super(Collections.emptyList()); Assert.hasText(code, "code cannot be empty"); Assert.hasText(clientId, "clientId cannot be empty"); this.code = code; + this.clientPrincipal = null; this.clientId = clientId; this.redirectUri = redirectUri; - this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap()); + this.additionalParameters = Collections.unmodifiableMap( + additionalParameters != null ? + additionalParameters : + Collections.emptyMap()); } @Override @@ -107,6 +110,15 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti return this.code; } + /** + * Returns the client identifier + * + * @return the client identifier + */ + public @Nullable String getClientId() { + return this.clientId; + } + /** * Returns the redirect uri. * @@ -124,13 +136,4 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti public Map getAdditionalParameters() { return this.additionalParameters; } - - /** - * Returns the client id - * - * @return the client id - */ - public @Nullable String getClientId() { - return this.clientId; - } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index ab6692c..7c94fcc 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -61,6 +61,7 @@ import java.util.Set; * * @author Joe Grandja * @author Paurav Munshi + * @author Daniel Garnier-Moiroux * @since 0.0.1 * @see RegisteredClientRepository * @see OAuth2AuthorizationService @@ -74,12 +75,13 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { */ public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize"; + private static final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1"; + private final RegisteredClientRepository registeredClientRepository; private final OAuth2AuthorizationService authorizationService; private final RequestMatcher authorizationEndpointMatcher; private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); - private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1"; /** * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters. @@ -185,15 +187,16 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { return; } - if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null && - parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) { + String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD); + if (StringUtils.hasText(codeChallengeMethod) && + parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() != 1) { OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI); sendErrorResponse(request, response, error, stateParameter, redirectUri); return; } - String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD); - if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) { + if (StringUtils.hasText(codeChallengeMethod) && + (!"S256".equals(codeChallengeMethod) && !"plain".equals(codeChallengeMethod))) { OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI); sendErrorResponse(request, response, error, stateParameter, redirectUri); return; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 84d9bca..ef14b48 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -80,6 +80,7 @@ import java.util.stream.Collectors; * * @author Joe Grandja * @author Madhu Bhat + * @author Daniel Garnier-Moiroux * @since 0.0.1 * @see AuthenticationManager * @see OAuth2AuthorizationService @@ -188,6 +189,12 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { throw new OAuth2AuthenticationException(error); } + private static boolean isClientAuthenticated(Authentication clientPrincipal) { + return clientPrincipal != null && + OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass()) && + clientPrincipal.isAuthenticated(); + } + private static class AuthorizationCodeAuthenticationConverter implements Converter { @Override @@ -200,25 +207,6 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // client_id (REQUIRED) - Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - String clientId = null; - if (clientPrincipal == null || - !OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) { - clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); - if (!StringUtils.hasText(clientId) || - parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); - } - - // code_verifier (REQUIRED for public clients) - String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER); - if (!StringUtils.hasText(codeVerifier) || - parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER); - } - } - // code (REQUIRED) String code = parameters.getFirst(OAuth2ParameterNames.CODE); if (!StringUtils.hasText(code) || @@ -234,6 +222,25 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); } + // client_id (REQUIRED) + // Required only if the client did not authenticate + String clientId = null; + Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + if (!isClientAuthenticated(clientPrincipal)) { + clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); + if (!StringUtils.hasText(clientId) || + parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); + } + + // code_verifier (REQUIRED for public clients) + String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER); + if (!StringUtils.hasText(codeVerifier) || + parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER); + } + } + Map additionalParameters = parameters .entrySet() .stream() diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index c3e3753..f35a8a3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -19,6 +19,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -40,6 +41,7 @@ import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; @@ -59,20 +61,29 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** * Integration tests for the OAuth 2.0 Authorization Code Grant. * * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class OAuth2AuthorizationCodeGrantTests { + // See RFC 7636: Appendix B. Example for the S256 code_challenge_method + // https://tools.ietf.org/html/rfc7636#appendix-B + private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static KeyManager keyManager; @@ -162,6 +173,52 @@ public class OAuth2AuthorizationCodeGrantTests { verify(authorizationService).save(any()); } + @Test + public void requestWhenPublicClientWithPkceThenReturnAccessToken() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSecret(null) + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .params(getAuthorizationRequestParameters(registeredClient)) + .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .with(user("user"))) + .andExpect(status().is3xxRedirection()) + .andReturn(); + assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); + + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization authorization = authorizationCaptor.getValue(); + + when(authorizationService.findByToken( + eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), + eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorization)) + .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) + .param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER) + .with(user("user"))) // TODO Remove after PKCE authentication is moved to OAuth2ClientAuthenticationProvider + .andExpect(status().isOk()) + .andExpect(jsonPath("$.access_token").isNotEmpty()); + + verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); + verify(authorizationService).findByToken( + eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), + eq(TokenType.AUTHORIZATION_CODE)); + verify(authorizationService, times(2)).save(any()); + } + private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); @@ -170,7 +227,6 @@ public class OAuth2AuthorizationCodeGrantTests { parameters.set(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); parameters.set(OAuth2ParameterNames.STATE, "state"); - parameters.set(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); return parameters; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java deleted file mode 100644 index 21da3b0..0000000 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; - -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.annotation.Bean; -import org.springframework.security.config.annotation.web.WebSecurityConfigurer; -import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.builders.WebSecurity; -import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; -import org.springframework.security.config.test.SpringTestRule; -import org.springframework.security.crypto.keys.KeyManager; -import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.config.ClientSettings; -import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; -import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; -import org.springframework.test.web.servlet.MockMvc; -import org.springframework.test.web.servlet.MvcResult; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.when; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; - -public class OAuth2PkceTests { - private static RegisteredClientRepository registeredClientRepository; - - @Rule - public final SpringTestRule spring = new SpringTestRule(); - - @Autowired - private MockMvc mvc; - - @BeforeClass - public static void init() { - registeredClientRepository = mock(RegisteredClientRepository.class); - } - - @Before - public void setup() { - reset(registeredClientRepository); - } - - @Test - public void requestWhenTokenRequestNotAuthenticatedAndPkceParamatersProvidedThenRedirectToClient() throws Exception { - // See RFC 7636: Appendix B. Example for the S256 code_challenge_method - // https://tools.ietf.org/html/rfc7636#appendix-B - final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; - final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; - - this.spring.register(AuthorizationServerConfiguration.class).autowire(); - - ClientSettings settings = new ClientSettings(); - RegisteredClient registeredClient = TestRegisteredClients - .registeredClient() - .clientSettings(settings.requireProofKey(true)) - .build(); - when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) - .param("code_challenge", S256_CODE_CHALLENGE) - .param("code_challenge_method", "S256") - .with(user("user"))) - .andExpect(status().is3xxRedirection()) - .andReturn(); - - assertThat(mvcResult.getResponse().getRedirectedUrl()) - .doesNotContain("error=") - .contains("code="); - - String authorizationCode = UriUtils.decode(UriComponentsBuilder.fromHttpUrl(mvcResult.getResponse().getRedirectedUrl()) - .build() - .getQueryParams() - .getFirst("code"), "utf-8"); - - this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) - .params(getTokenRequestParameters(registeredClient, authorizationCode)) - .param("code_verifier", S256_CODE_VERIFIER)) - .andExpect(status().is2xxSuccessful()) - .andExpect(jsonPath("$.access_token").isNotEmpty()); - } - - private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { - MultiValueMap parameters = new LinkedMultiValueMap<>(); - parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); - parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); - parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); - parameters.set(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - parameters.set(OAuth2ParameterNames.STATE, "state"); - return parameters; - } - - private static MultiValueMap getTokenRequestParameters(RegisteredClient registeredClient, - String authorizationCode) { - MultiValueMap parameters = new LinkedMultiValueMap<>(); - parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - parameters.set(OAuth2ParameterNames.CODE, authorizationCode); - parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); - parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); - return parameters; - } - - @EnableWebSecurity - static class AuthorizationServerConfiguration { - - @Bean - RegisteredClientRepository registeredClientRepository() { - return registeredClientRepository; - } - - @Bean - OAuth2AuthorizationService authorizationService() { - return new InMemoryOAuth2AuthorizationService(); - } - - @Bean - KeyManager keyManager() { - return new StaticKeyGeneratingKeyManager(); - } - - @Bean - WebSecurityConfigurer defaultOAuth2AuthorizationServerSecurity() { - return new WebSecurityConfigurerAdapter() { - @Override - public void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest() - .permitAll() - .and() - .csrf() - .disable() - .apply(new OAuth2AuthorizationServerConfigurer<>()); - } - }; - } - } -} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index dd66ab3..70a0305 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -26,6 +26,7 @@ import java.util.Map; /** * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class TestOAuth2Authorizations { @@ -37,14 +38,15 @@ public class TestOAuth2Authorizations { return authorization(registeredClient, Collections.emptyMap()); } - public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map additionalParameters) { + public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, + Map authorizationRequestAdditionalParameters) { OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) .redirectUri(registeredClient.getRedirectUris().iterator().next()) - .additionalParameters(additionalParameters) + .additionalParameters(authorizationRequestAdditionalParameters) .state("state") .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index f38950b..efddc09 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -38,13 +38,12 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.ClientSettings; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.time.Instant; -import java.time.temporal.ChronoUnit; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; @@ -57,17 +56,18 @@ import static org.mockito.Mockito.when; * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}. * * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class OAuth2AuthorizationCodeAuthenticationProviderTests { - private final String PLAIN_CODE_CHALLENGE = "pkce-key"; - private final String PLAIN_CODE_VERIFIER = PLAIN_CODE_CHALLENGE; + private static final String PLAIN_CODE_VERIFIER = "pkce-key"; + private static final String PLAIN_CODE_CHALLENGE = PLAIN_CODE_VERIFIER; // See RFC 7636: Appendix B. Example for the S256 code_challenge_method // https://tools.ietf.org/html/rfc7636#appendix-B - private final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; - private final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; - private final String AUTHORIZATION_CODE = "code"; + private static final String AUTHORIZATION_CODE = "code"; private RegisteredClient registeredClient; private RegisteredClient otherRegisteredClient; @@ -128,7 +128,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -141,7 +141,32 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenPublicClientAndInvalidClientIdThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(this.registeredClient, createPkceParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + "invalid-client-id", + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_CHALLENGE) + ); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -153,7 +178,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -164,13 +189,37 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenCodeIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByToken(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( TestRegisteredClients.registeredClient2().build()); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null); + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenPublicClientAndClientIdNotMatchThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(this.registeredClient, createPkceParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + this.otherRegisteredClient.getClientId(), + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) + ); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -181,14 +230,14 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByToken(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid", null); + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid", null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -199,14 +248,14 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenValidCodeThenReturnAccessToken() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByToken(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null); + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); @@ -224,20 +273,22 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { } @Test - public void authenticateWhenRequireProofKeyAndMissingPkceCodeChallengeInAuthorizationRequestThenThrowOAuth2AuthenticationException() { - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClientRequiresProofKey).build(); + public void authenticateWhenRequireProofKeyAndMissingCodeChallengeThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(this.registeredClientRequiresProofKey) + .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( AUTHORIZATION_CODE, - registeredClientRequiresProofKey.getClientId(), + this.registeredClientRequiresProofKey.getClientId(), authorizationRequest.getRedirectUri(), Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) ); - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -245,91 +296,16 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } - @Test - public void authenticateWhenRequireProofKeyAndUnsupportedCodeChallengeMethodInAuthorizationRequestThenThrowOAuth2AuthenticationException() { - Map pkceParameters = new HashMap<>(); - pkceParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); - // This should never happen: the Authorization endpoint should not allow it - pkceParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method"); - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClientRequiresProofKey, pkceParameters) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) - .thenReturn(authorization); - - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); - OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken( - AUTHORIZATION_CODE, - registeredClientRequiresProofKey.getClientId(), - authorizationRequest.getRedirectUri(), - Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) - ); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); - } - - @Test - public void authenticateWhenPublicClientAndClientIdNotMatchingThrowOAuth2AuthenticationException() { - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) - .thenReturn(authorization); - - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); - OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken( - AUTHORIZATION_CODE, - otherRegisteredClient.getClientId(), - authorizationRequest.getRedirectUri(), - Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) - ); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - } - - @Test - public void authenticateWhenPublicClientAndUnknownClientIdThrowOAuth2AuthenticationException() { - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) - .thenReturn(authorization); - - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); - OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken( - AUTHORIZATION_CODE, - "invalid-client-id", - authorizationRequest.getRedirectUri(), - Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_CHALLENGE) - ); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); - } - @Test public void authenticateWhenPublicClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .authorization(this.registeredClient, createPkceParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( AUTHORIZATION_CODE, @@ -345,15 +321,17 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { } @Test - public void authenticateWhenPrivateClientAndRequireProofKeyAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + public void authenticateWhenConfidentialClientRequireProofKeyAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .authorization(this.registeredClientRequiresProofKey, createPkceParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + this.registeredClientRequiresProofKey); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( AUTHORIZATION_CODE, @@ -371,12 +349,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenPublicClientAndPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .authorization(this.registeredClient, createPkceParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier"); + OAuth2AuthorizationCodeAuthenticationToken authentication = createAuthorizationCodeAuthentication( + this.registeredClient, "invalid-code-verifier"); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -387,13 +366,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenPublicClientAndS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersS256()) + .authorization(this.registeredClient, createPkceParametersS256()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier"); - + OAuth2AuthorizationCodeAuthenticationToken authentication = createAuthorizationCodeAuthentication( + this.registeredClient, "invalid-code-verifier"); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -401,16 +380,44 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } + @Test + public void authenticateWhenRequireProofKeyAndUnsupportedCodeChallengeMethodThenThrowOAuth2AuthenticationException() { + Map pkceParameters = createPkceParametersPlain(); + // This should never happen: the Authorization endpoint should not allow it + pkceParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method"); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(this.registeredClientRequiresProofKey, pkceParameters) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken( + AUTHORIZATION_CODE, + this.registeredClientRequiresProofKey.getClientId(), + authorizationRequest.getRedirectUri(), + Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER) + ); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + } + @Test public void authenticateWhenPublicClientAndPlainMethodAndValidCodeVerifierThenReturnAccessToken() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersPlain()) + .authorization(this.registeredClient, createPkceParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); - OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER); + OAuth2AuthorizationCodeAuthenticationToken authentication = createAuthorizationCodeAuthentication( + this.registeredClient, PLAIN_CODE_VERIFIER); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -419,22 +426,23 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); - assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); + assertThat(clientAuthentication.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); assertThat(updatedAuthorization.getAccessToken()).isNotNull(); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); } @Test - public void authenticateWhenPublicClientAndNoMethodThenDefaultToPlainAndReturnAccessToken() { + public void authenticateWhenPublicClientAndMissingMethodThenDefaultPlainMethodAndReturnAccessToken() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, Collections.singletonMap(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE)) + .authorization(this.registeredClient, Collections.singletonMap(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE)) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); - OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER); + OAuth2AuthorizationCodeAuthenticationToken authentication = createAuthorizationCodeAuthentication( + this.registeredClient, PLAIN_CODE_VERIFIER); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -443,54 +451,53 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); - assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); + assertThat(clientAuthentication.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); assertThat(updatedAuthorization.getAccessToken()).isNotNull(); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); } - @Test public void authenticateWhenPublicClientAndS256MethodAndValidCodeVerifierThenReturnAccessToken() { OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, getPkceAuthorizationParametersS256()) + .authorization(this.registeredClient, createPkceParametersS256()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); - OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(S256_CODE_VERIFIER); - + OAuth2AuthorizationCodeAuthenticationToken authentication = createAuthorizationCodeAuthentication( + this.registeredClient, S256_CODE_VERIFIER); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); - assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal(); + assertThat(clientAuthentication.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); assertThat(updatedAuthorization.getAccessToken()).isNotNull(); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); } - private Map getPkceAuthorizationParametersPlain() { - Map additionalParameters = new HashMap<>(); - additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); - additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); - return additionalParameters; + private static Map createPkceParametersPlain() { + Map parameters = new HashMap<>(); + parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + parameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); + return parameters; } - private Map getPkceAuthorizationParametersS256() { - Map additionalParameters = new HashMap<>(); - additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); - return additionalParameters; + private static Map createPkceParametersS256() { + Map parameters = new HashMap<>(); + parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + parameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); + return parameters; } - private OAuth2AuthorizationCodeAuthenticationToken makeAuthorizationCodeAuthenticationToken(String codeVerifier) { + private static OAuth2AuthorizationCodeAuthenticationToken createAuthorizationCodeAuthentication( + RegisteredClient registeredClient, String codeVerifier) { return new OAuth2AuthorizationCodeAuthenticationToken( AUTHORIZATION_CODE, registeredClient.getClientId(), @@ -499,7 +506,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { ); } - private Jwt createJwt() { + private static Jwt createJwt() { Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); return Jwt.withTokenValue("token") diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java index a72840c..5d6b3fc 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java @@ -29,15 +29,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}. * * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class OAuth2AuthorizationCodeAuthenticationTokenTests { private String code = "code"; - private String clientPrincipalClientId = "clientPrincipal.clientId"; - private OAuth2ClientAuthenticationToken clientPrincipal = - new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().clientId(clientPrincipalClientId).build()); + private OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + TestRegisteredClients.registeredClient().build()); private String clientId = "clientId"; private String redirectUri = "redirectUri"; - private Map additonalParams = Collections.singletonMap("some_key", "some_value"); + private Map additionalParameters = Collections.singletonMap("param1", "value1"); @Test public void constructorWhenCodeNullThenThrowIllegalArgumentException() { @@ -63,29 +63,29 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { @Test public void constructorWhenClientPrincipalProvidedThenCreated() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientPrincipal, this.redirectUri, this.additonalParams); + this.code, this.clientPrincipal, this.redirectUri, this.additionalParameters); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCode()).isEqualTo(this.code); assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri); - assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters); } @Test public void constructorWhenClientIdProvidedThenCreated() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientId, this.redirectUri, this.additonalParams); + this.code, this.clientId, this.redirectUri, this.additionalParameters); assertThat(authentication.getPrincipal()).isEqualTo(this.clientId); assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCode()).isEqualTo(this.code); assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri); - assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters); } @Test - public void getAdditionalParamsIsImmutableMap() { + public void getAdditionalParametersWhenUpdateThenThrowUnsupportedOperationException() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientId, this.redirectUri, this.additonalParams); + this.code, this.clientId, this.redirectUri, this.additionalParameters); assertThatThrownBy(() -> authentication.getAdditionalParameters().put("another_key", 1)) .isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> authentication.getAdditionalParameters().remove("some_key")) @@ -95,18 +95,10 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { } @Test - public void getClientIdFromClientId() { + public void getClientIdWhenClientPrincipalProvidedThenNotNull() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientId, this.redirectUri, this.additonalParams); - - assertThat(authentication.getClientId()).isEqualTo(this.clientId); - } - - @Test - public void getClientIdFromOAuth2ClientAuthenticationTokenPrincipal() { - OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.code, this.clientPrincipal, this.redirectUri, this.additonalParams); - - assertThat(authentication.getClientId()).isEqualTo(this.clientPrincipalClientId); + this.code, this.clientPrincipal, this.redirectUri, this.additionalParameters); + assertThat(authentication.getClientId()).isNotNull(); + assertThat(authentication.getClientId()).isEqualTo(this.clientPrincipal.getRegisteredClient().getClientId()); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 3c5c03e..71dfd8d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -60,9 +60,12 @@ import static org.mockito.Mockito.when; * * @author Paurav Munshi * @author Joe Grandja + * @author Daniel Garnier-Moiroux * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilterTests { + private static final String DEFAULT_ERROR_URI = "https://tools.ietf.org/html/rfc6749%23section-4.1.2.1"; + private static final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636%23section-4.4.1"; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private OAuth2AuthorizationEndpointFilter filter; @@ -219,21 +222,12 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20response_type&" + - "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + OAuth2ParameterNames.RESPONSE_TYPE, + OAuth2ErrorCodes.INVALID_REQUEST, + DEFAULT_ERROR_URI, + request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE)); } @Test @@ -242,21 +236,12 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20response_type&" + - "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + OAuth2ParameterNames.RESPONSE_TYPE, + OAuth2ErrorCodes.INVALID_REQUEST, + DEFAULT_ERROR_URI, + request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); } @Test @@ -265,191 +250,139 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=unsupported_response_type&" + - "error_description=OAuth%202.0%20Parameter:%20response_type&" + - "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + OAuth2ParameterNames.RESPONSE_TYPE, + OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, + DEFAULT_ERROR_URI, + request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); } @Test - public void doFilterWhenProofKeyRequiredAndMissingPkceCodeChallengeThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientRequireProofKey(); + public void doFilterWhenPkceRequiredAndMissingCodeChallengeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.removeParameter(PkceParameterNames.CODE_CHALLENGE); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.removeParameter(PkceParameterNames.CODE_CHALLENGE); + }); } @Test - public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientRequireProofKey(); + public void doFilterWhenPkceRequiredAndMultipleCodeChallengeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge"); + }); } @Test - public void doFilterWhenProofKeyNotRequiredClientAndMultiplePkceCodeChallengeThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientDoNotRequireProofKey(); + public void doFilterWhenPkceNotRequiredAndMultipleCodeChallengeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); - + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge"); + }); } @Test - public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeMethodThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientRequireProofKey(); + public void doFilterWhenPkceRequiredAndMultipleCodeChallengeMethodThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE_METHOD, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + }); } @Test - public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAnMultiplePkceCodeChallengeMethodThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientDoNotRequireProofKey(); + public void doFilterWhenPkceNotRequiredAndMultipleCodeChallengeMethodThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE_METHOD, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + }); } @Test - public void doFilterWhenProofKeyRequiredAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientRequireProofKey(); + public void doFilterWhenPkceRequiredAndUnsupportedCodeChallengeMethodThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE_METHOD, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported"); + }); } @Test - public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception { - RegisteredClient registeredClient = createClientDoNotRequireProofKey(); + public void doFilterWhenPkceNotRequiredAndUnsupportedCodeChallengeMethodThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); - request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + - "error=invalid_request&" + - "error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" + - "error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" + - "state=state"); - + doFilterWhenAuthorizationRequestInvalidParameterThenRedirected( + registeredClient, + PkceParameterNames.CODE_CHALLENGE_METHOD, + OAuth2ErrorCodes.INVALID_REQUEST, + PKCE_ERROR_URI, + request -> { + addPkceParameters(request); + request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported"); + }); } @Test @@ -510,13 +443,15 @@ public class OAuth2AuthorizationEndpointFilterTests { } @Test - public void doFilterWhenProofKeyRequiredAndAuthorizationRequestValidThenAuthorizationResponse() throws Exception { - RegisteredClient registeredClient = createClientRequireProofKey(); + public void doFilterWhenPkceRequiredAndAuthorizationRequestValidThenAuthorizationResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(new ClientSettings().requireProofKey(true)) + .build(); when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request = addPkceParameters(request); + addPkceParameters(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -564,6 +499,27 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName); } + private void doFilterWhenAuthorizationRequestInvalidParameterThenRedirected(RegisteredClient registeredClient, + String parameterName, String errorCode, String errorUri, + Consumer requestConsumer) throws Exception { + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + requestConsumer.accept(request); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=" + errorCode + "&" + + "error_description=OAuth%202.0%20Parameter:%20" + parameterName + "&" + + "error_uri=" + errorUri + "&" + + "state=state"); + } + private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); @@ -581,29 +537,8 @@ public class OAuth2AuthorizationEndpointFilterTests { return request; } - private static MockHttpServletRequest addPkceParameters(MockHttpServletRequest request) { + private static void addPkceParameters(MockHttpServletRequest request) { request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - - return request; } - - private RegisteredClient createClientRequireProofKey() { - ClientSettings clientSettings = new ClientSettings(); - clientSettings.requireProofKey(true); - - return TestRegisteredClients.registeredClient() - .clientSettings(clientSettings) - .build(); - } - - private RegisteredClient createClientDoNotRequireProofKey() { - ClientSettings clientSettings = new ClientSettings(); - clientSettings.requireProofKey(false); - - return TestRegisteredClients.registeredClient() - .clientSettings(clientSettings) - .build(); - } - } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 67b8bef..409816d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -67,6 +67,7 @@ import static org.mockito.Mockito.when; * * @author Madhu Bhat * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class OAuth2TokenEndpointFilterTests { private AuthenticationManager authenticationManager; @@ -179,12 +180,6 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(clientPrincipal); - SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.removeParameter(OAuth2ParameterNames.CODE); @@ -195,12 +190,6 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(clientPrincipal); - SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.addParameter(OAuth2ParameterNames.CODE, "code-2"); @@ -211,12 +200,6 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(clientPrincipal); - SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); @@ -227,12 +210,8 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestNotAuthenticatedAndMissingCodeVerifierThenInvalidRequestError() throws Exception { - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com"); doFilterWhenTokenRequestInvalidParameterThenError( PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request); @@ -240,14 +219,10 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestNotAuthenticatedAndMultipleCodeVerifierThenInvalidRequestError() throws Exception { - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createAuthorizationCodeTokenRequest( TestRegisteredClients.registeredClient().build()); request.addParameter(PkceParameterNames.CODE_VERIFIER, "one-verifier"); - request.addParameter(PkceParameterNames.CODE_VERIFIER, "two-verifiers"); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com"); + request.addParameter(PkceParameterNames.CODE_VERIFIER, "two-verifier2"); doFilterWhenTokenRequestInvalidParameterThenError( PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request);