Implement Proof Key for Code Exchange (PKCE) RFC 7636

See https://tools.ietf.org/html/rfc7636

Closes gh-45
This commit is contained in:
Daniel Garnier-Moiroux 2020-06-22 21:35:01 +02:00 committed by Joe Grandja
parent 8541f6be69
commit ab090445b3
13 changed files with 1008 additions and 55 deletions

View File

@ -38,8 +38,8 @@ import java.util.Collections;
*/
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private final RegisteredClient registeredClient;
private final Authentication clientPrincipal;
private RegisteredClient registeredClient;
private Authentication clientPrincipal;
private final OAuth2AccessToken accessToken;
/**

View File

@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@ -33,15 +34,20 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Collections;
/**
@ -85,29 +91,30 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
(OAuth2AuthorizationCodeAuthenticationToken) authentication;
OAuth2ClientAuthenticationToken clientPrincipal = null;
RegisteredClient registeredClient = null;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
}
if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
registeredClient = clientPrincipal.getRegisteredClient();
} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
// When the principal is a string, it is the clientId, REQUIRED for public clients
String clientId = authorizationCodeAuthentication.getClientId();
registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
} else {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
// TODO Authenticate public client
// A client MAY use the "client_id" request parameter to identify itself
// when sending requests to the token endpoint.
// In the "authorization_code" "grant_type" request to the token endpoint,
// an unauthenticated client MUST send its "client_id" to prevent itself
// from inadvertently accepting a code intended for a client with a different "client_id".
// This protects the client from substitution of the authentication code.
if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
OAuth2Authorization authorization = this.authorizationService.findByToken(
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
if (authorization == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@ -116,6 +123,35 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
String codeChallenge;
Object codeChallengeParameter = authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE);
if (codeChallengeParameter != null) {
codeChallenge = (String) codeChallengeParameter;
String codeChallengeMethod = (String) authorizationRequest
.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
String codeVerifier = (String) authorizationCodeAuthentication
.getAdditionalParameters()
.get(PkceParameterNames.CODE_VERIFIER);
if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
} else if (registeredClient.getClientSettings().requireProofKey()){
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
// TODO Allow configuration for issuer claim
@ -130,7 +166,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
.issuer(issuer)
.subject(authorization.getPrincipalName())
.audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId()))
.audience(Collections.singletonList(registeredClient.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.notBefore(issuedAt)
@ -148,8 +184,30 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
.build();
this.authorizationService.save(authorization);
return new OAuth2AccessTokenAuthenticationToken(
clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
return clientPrincipal != null ?
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) :
new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
}
private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
if (codeVerifier == null) {
return false;
} else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) {
return codeVerifier.equals(codeChallenge);
} else if ("S256".equals(codeChallengeMethod)) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
return codeChallenge.equals(encodedVerifier);
} catch (NoSuchAlgorithmException e) {
// It is unlikely that SHA-256 is not available on the server. If it is not available,
// there will likely be bigger issues as well. We default to SERVER_ERROR.
}
}
// Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
}
@Override

View File

@ -22,6 +22,7 @@ import org.springframework.security.core.SpringSecurityCoreVersion2;
import org.springframework.util.Assert;
import java.util.Collections;
import java.util.Map;
/**
* An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
@ -35,10 +36,11 @@ import java.util.Collections;
*/
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
private String code;
private final String code;
private Authentication clientPrincipal;
private String clientId;
private String redirectUri;
private final String clientId;
private final String redirectUri;
private final Map<String, Object> additionalParameters;
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
@ -46,15 +48,24 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
* @param code the authorization code
* @param clientPrincipal the authenticated client principal
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code,
Authentication clientPrincipal, @Nullable String redirectUri) {
Authentication clientPrincipal, @Nullable String redirectUri,
Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
this.code = code;
this.clientPrincipal = clientPrincipal;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) {
this.clientId = (String) this.clientPrincipal.getPrincipal();
} else {
this.clientId = null;
}
}
/**
@ -63,15 +74,18 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
* @param code the authorization code
* @param clientId the client identifier
* @param redirectUri the redirect uri
* @param additionalParameters the additional parameters
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code,
String clientId, @Nullable String redirectUri) {
String clientId, @Nullable String redirectUri,
Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.hasText(clientId, "clientId cannot be empty");
this.code = code;
this.clientId = clientId;
this.redirectUri = redirectUri;
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
}
@Override
@ -101,4 +115,22 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
public @Nullable String getRedirectUri() {
return this.redirectUri;
}
/**
* Returns the additional parameters
*
* @return the additional parameters
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
/**
* Returns the client id
*
* @return the client id
*/
public @Nullable String getClientId() {
return this.clientId;
}
}

View File

@ -367,7 +367,6 @@ public class RegisteredClient implements Serializable {
Assert.hasText(this.clientId, "clientId cannot be empty");
Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty");
if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty");
}
if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) {

View File

@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -78,6 +79,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
private final RequestMatcher authorizationEndpointMatcher;
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";
/**
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
@ -174,6 +176,34 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
return;
}
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
if (StringUtils.hasText(codeChallenge)) {
if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}
if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null &&
parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}
String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}
} else if (registeredClient.getClientSettings().requireProofKey()) {
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
sendErrorResponse(request, response, error, stateParameter, redirectUri);
return;
}
// ---------------
// The request is valid - ensure the resource owner is authenticated
// ---------------
@ -245,8 +275,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
}
private static OAuth2Error createError(String errorCode, String parameterName) {
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
}
private static OAuth2Error createError(String errorCode, String parameterName, String errorUri) {
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
}
private static boolean isPrincipalAuthenticated(Authentication principal) {

View File

@ -30,11 +30,13 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
@ -54,6 +56,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
@ -198,14 +201,22 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// client_id (REQUIRED)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
Authentication clientPrincipal = null;
if (StringUtils.hasText(clientId)) {
if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
String clientId = null;
if (clientPrincipal == null ||
!OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) {
clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId) ||
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
}
} else {
clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
// code_verifier (REQUIRED for public clients)
String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER);
if (!StringUtils.hasText(codeVerifier) ||
parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER);
}
}
// code (REQUIRED)
@ -223,9 +234,19 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
}
return clientPrincipal != null ?
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) :
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri);
Map<String, Object> additionalParameters = parameters
.entrySet()
.stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
!e.getKey().equals(OAuth2ParameterNames.CODE) &&
!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI))
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
return clientId != null ?
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri, additionalParameters) :
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters);
}
}

View File

@ -31,6 +31,7 @@ import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -169,6 +170,7 @@ public class OAuth2AuthorizationCodeGrantTests {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
parameters.set(OAuth2ParameterNames.STATE, "state");
parameters.set(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
return parameters;
}

View File

@ -0,0 +1,178 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.security.config.annotation.web.WebSecurityConfigurer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.builders.WebSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.crypto.keys.KeyManager;
import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
public class OAuth2PkceTests {
private static RegisteredClientRepository registeredClientRepository;
@Rule
public final SpringTestRule spring = new SpringTestRule();
@Autowired
private MockMvc mvc;
@BeforeClass
public static void init() {
registeredClientRepository = mock(RegisteredClientRepository.class);
}
@Before
public void setup() {
reset(registeredClientRepository);
}
@Test
public void requestWhenTokenRequestNotAuthenticatedAndPkceParamatersProvidedThenRedirectToClient() throws Exception {
// See RFC 7636: Appendix B. Example for the S256 code_challenge_method
// https://tools.ietf.org/html/rfc7636#appendix-B
final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
this.spring.register(AuthorizationServerConfiguration.class).autowire();
ClientSettings settings = new ClientSettings();
RegisteredClient registeredClient = TestRegisteredClients
.registeredClient()
.clientSettings(settings.requireProofKey(true))
.build();
when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.param("code_challenge", S256_CODE_CHALLENGE)
.param("code_challenge_method", "S256")
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl())
.doesNotContain("error=")
.contains("code=");
String authorizationCode = UriUtils.decode(UriComponentsBuilder.fromHttpUrl(mvcResult.getResponse().getRedirectedUrl())
.build()
.getQueryParams()
.getFirst("code"), "utf-8");
this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCode))
.param("code_verifier", S256_CODE_VERIFIER))
.andExpect(status().is2xxSuccessful())
.andExpect(jsonPath("$.access_token").isNotEmpty());
}
private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
parameters.set(OAuth2ParameterNames.STATE, "state");
return parameters;
}
private static MultiValueMap<String, String> getTokenRequestParameters(RegisteredClient registeredClient,
String authorizationCode) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
parameters.set(OAuth2ParameterNames.CODE, authorizationCode);
parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
return parameters;
}
@EnableWebSecurity
static class AuthorizationServerConfiguration {
@Bean
RegisteredClientRepository registeredClientRepository() {
return registeredClientRepository;
}
@Bean
OAuth2AuthorizationService authorizationService() {
return new InMemoryOAuth2AuthorizationService();
}
@Bean
KeyManager keyManager() {
return new StaticKeyGeneratingKeyManager();
}
@Bean
WebSecurityConfigurer<WebSecurity> defaultOAuth2AuthorizationServerSecurity() {
return new WebSecurityConfigurerAdapter() {
@Override
public void configure(HttpSecurity http) throws Exception {
http
.authorizeRequests()
.anyRequest()
.permitAll()
.and()
.csrf()
.disable()
.apply(new OAuth2AuthorizationServerConfigurer<>());
}
};
}
}
}

View File

@ -21,6 +21,8 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import java.time.Instant;
import java.util.Collections;
import java.util.Map;
/**
* @author Joe Grandja
@ -32,12 +34,17 @@ public class TestOAuth2Authorizations {
}
public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient) {
return authorization(registeredClient, Collections.emptyMap());
}
public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map<String, Object> additionalParameters) {
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://provider.com/oauth2/authorize")
.clientId(registeredClient.getClientId())
.redirectUri(registeredClient.getRedirectUris().iterator().next())
.additionalParameters(additionalParameters)
.state("state")
.build();
return OAuth2Authorization.withRegisteredClient(registeredClient)

View File

@ -22,6 +22,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.jose.JoseHeaderNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
@ -35,6 +36,11 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
@ -53,7 +59,19 @@ import static org.mockito.Mockito.when;
* @author Joe Grandja
*/
public class OAuth2AuthorizationCodeAuthenticationProviderTests {
private final String PLAIN_CODE_CHALLENGE = "pkce-key";
private final String PLAIN_CODE_VERIFIER = PLAIN_CODE_CHALLENGE;
// See RFC 7636: Appendix B. Example for the S256 code_challenge_method
// https://tools.ietf.org/html/rfc7636#appendix-B
private final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
private final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
private final String AUTHORIZATION_CODE = "code";
private RegisteredClient registeredClient;
private RegisteredClient otherRegisteredClient;
private RegisteredClient registeredClientRequiresProofKey;
private RegisteredClientRepository registeredClientRepository;
private OAuth2AuthorizationService authorizationService;
private JwtEncoder jwtEncoder;
@ -62,7 +80,17 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
@Before
public void setUp() {
this.registeredClient = TestRegisteredClients.registeredClient().build();
this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient);
this.otherRegisteredClient = TestRegisteredClients.registeredClient2().build();
this.registeredClientRequiresProofKey = TestRegisteredClients.registeredClient()
.id("registration-3")
.clientId("client-3")
.clientSettings(new ClientSettings().requireProofKey(true))
.build();
this.registeredClientRepository = new InMemoryRegisteredClientRepository(
this.registeredClient,
this.otherRegisteredClient,
this.registeredClientRequiresProofKey
);
this.authorizationService = mock(OAuth2AuthorizationService.class);
this.jwtEncoder = mock(JwtEncoder.class);
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
@ -100,7 +128,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@ -113,7 +141,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@ -125,7 +153,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@ -142,7 +170,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
TestRegisteredClients.registeredClient2().build());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@ -160,7 +188,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid");
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid", null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@ -178,16 +206,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri());
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null);
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
Jwt jwt = Jwt.withTokenValue("token")
.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.build();
when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -201,4 +222,290 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
assertThat(updatedAuthorization.getAccessToken()).isNotNull();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
}
@Test
public void authenticateWhenRequireProofKeyAndMissingPkceCodeChallengeInAuthorizationRequestThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClientRequiresProofKey).build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
registeredClientRequiresProofKey.getClientId(),
authorizationRequest.getRedirectUri(),
Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER)
);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenRequireProofKeyAndUnsupportedCodeChallengeMethodInAuthorizationRequestThenThrowOAuth2AuthenticationException() {
Map<String, Object> pkceParameters = new HashMap<>();
pkceParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE);
// This should never happen: the Authorization endpoint should not allow it
pkceParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method");
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClientRequiresProofKey, pkceParameters)
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
registeredClientRequiresProofKey.getClientId(),
authorizationRequest.getRedirectUri(),
Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER)
);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
}
@Test
public void authenticateWhenPublicClientAndClientIdNotMatchingThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersPlain())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
otherRegisteredClient.getClientId(),
authorizationRequest.getRedirectUri(),
Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER)
);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenPublicClientAndUnknownClientIdThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersPlain())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
"invalid-client-id",
authorizationRequest.getRedirectUri(),
Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_CHALLENGE)
);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
}
@Test
public void authenticateWhenPublicClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersPlain())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
authorizationRequest.getClientId(),
authorizationRequest.getRedirectUri(),
null
);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenPrivateClientAndRequireProofKeyAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersPlain())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
clientPrincipal,
authorizationRequest.getRedirectUri(),
null
);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenPublicClientAndPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersPlain())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier");
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenPublicClientAndS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersS256())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier");
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenPublicClientAndPlainMethodAndValidCodeVerifierThenReturnAccessToken() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersPlain())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal();
assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
assertThat(updatedAuthorization.getAccessToken()).isNotNull();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
}
@Test
public void authenticateWhenPublicClientAndNoMethodThenDefaultToPlainAndReturnAccessToken() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, Collections.singletonMap(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE))
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal();
assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
assertThat(updatedAuthorization.getAccessToken()).isNotNull();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
}
@Test
public void authenticateWhenPublicClientAndS256MethodAndValidCodeVerifierThenReturnAccessToken() {
OAuth2Authorization authorization = TestOAuth2Authorizations
.authorization(registeredClient, getPkceAuthorizationParametersS256())
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(S256_CODE_VERIFIER);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal();
assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
assertThat(updatedAuthorization.getAccessToken()).isNotNull();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
}
private Map<String, Object> getPkceAuthorizationParametersPlain() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain");
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE);
return additionalParameters;
}
private Map<String, Object> getPkceAuthorizationParametersS256() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
return additionalParameters;
}
private OAuth2AuthorizationCodeAuthenticationToken makeAuthorizationCodeAuthenticationToken(String codeVerifier) {
return new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE,
registeredClient.getClientId(),
registeredClient.getRedirectUris().iterator().next(),
Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, codeVerifier)
);
}
private Jwt createJwt() {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
return Jwt.withTokenValue("token")
.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.build();
}
}

View File

@ -19,6 +19,9 @@ import org.junit.Test;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import java.util.Collections;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -29,28 +32,30 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
*/
public class OAuth2AuthorizationCodeAuthenticationTokenTests {
private String code = "code";
private String clientPrincipalClientId = "clientPrincipal.clientId";
private OAuth2ClientAuthenticationToken clientPrincipal =
new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().clientId(clientPrincipalClientId).build());
private String clientId = "clientId";
private String redirectUri = "redirectUri";
private Map<String, Object> additonalParams = Collections.singletonMap("some_key", "some_value");
@Test
public void constructorWhenCodeNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri))
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("code cannot be empty");
}
@Test
public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri))
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientPrincipal cannot be null");
}
@Test
public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri))
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientId cannot be empty");
}
@ -58,20 +63,50 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests {
@Test
public void constructorWhenClientPrincipalProvidedThenCreated() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientPrincipal, this.redirectUri);
this.code, this.clientPrincipal, this.redirectUri, this.additonalParams);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getCode()).isEqualTo(this.code);
assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams);
}
@Test
public void constructorWhenClientIdProvidedThenCreated() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientId, this.redirectUri);
this.code, this.clientId, this.redirectUri, this.additonalParams);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientId);
assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getCode()).isEqualTo(this.code);
assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams);
}
@Test
public void getAdditionalParamsIsImmutableMap() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientId, this.redirectUri, this.additonalParams);
assertThatThrownBy(() -> authentication.getAdditionalParameters().put("another_key", 1))
.isInstanceOf(UnsupportedOperationException.class);
assertThatThrownBy(() -> authentication.getAdditionalParameters().remove("some_key"))
.isInstanceOf(UnsupportedOperationException.class);
assertThatThrownBy(() -> authentication.getAdditionalParameters().clear())
.isInstanceOf(UnsupportedOperationException.class);
}
@Test
public void getClientIdFromClientId() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientId, this.redirectUri, this.additonalParams);
assertThat(authentication.getClientId()).isEqualTo(this.clientId);
}
@Test
public void getClientIdFromOAuth2ClientAuthenticationTokenPrincipal() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientPrincipal, this.redirectUri, this.additonalParams);
assertThat(authentication.getClientId()).isEqualTo(this.clientPrincipalClientId);
}
}

View File

@ -30,12 +30,14 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
import org.springframework.util.StringUtils;
import javax.servlet.FilterChain;
@ -280,6 +282,176 @@ public class OAuth2AuthorizationEndpointFilterTests {
"state=state");
}
@Test
public void doFilterWhenProofKeyRequiredAndMissingPkceCodeChallengeThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.removeParameter(PkceParameterNames.CODE_CHALLENGE);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenProofKeyNotRequiredClientAndMultiplePkceCodeChallengeThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientDoNotRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeMethodThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAnMultiplePkceCodeChallengeMethodThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientDoNotRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenProofKeyRequiredAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception {
RegisteredClient registeredClient = createClientDoNotRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
"error=invalid_request&" +
"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
"state=state");
}
@Test
public void doFilterWhenAuthorizationRequestValidNotAuthenticatedThenContinueChainToCommenceAuthentication() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@ -337,6 +509,40 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
}
@Test
public void doFilterWhenProofKeyRequiredAndAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
RegisteredClient registeredClient = createClientRequireProofKey();
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request = addPkceParameters(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization authorization = authorizationCaptor.getValue();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
assertThat(authorizationRequest.getAdditionalParameters())
.size()
.isEqualTo(2)
.returnToMap()
.containsEntry(PkceParameterNames.CODE_CHALLENGE, "code-challenge")
.containsEntry(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
}
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
String parameterName, String errorCode) throws Exception {
doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {});
@ -374,4 +580,30 @@ public class OAuth2AuthorizationEndpointFilterTests {
return request;
}
private static MockHttpServletRequest addPkceParameters(MockHttpServletRequest request) {
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
return request;
}
private RegisteredClient createClientRequireProofKey() {
ClientSettings clientSettings = new ClientSettings();
clientSettings.requireProofKey(true);
return TestRegisteredClients.registeredClient()
.clientSettings(clientSettings)
.build();
}
private RegisteredClient createClientDoNotRequireProofKey() {
ClientSettings clientSettings = new ClientSettings();
clientSettings.requireProofKey(false);
return TestRegisteredClients.registeredClient()
.clientSettings(clientSettings)
.build();
}
}

View File

@ -34,6 +34,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -178,6 +179,12 @@ public class OAuth2TokenEndpointFilterTests {
@Test
public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
TestRegisteredClients.registeredClient().build());
request.removeParameter(OAuth2ParameterNames.CODE);
@ -188,6 +195,12 @@ public class OAuth2TokenEndpointFilterTests {
@Test
public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
TestRegisteredClients.registeredClient().build());
request.addParameter(OAuth2ParameterNames.CODE, "code-2");
@ -198,6 +211,12 @@ public class OAuth2TokenEndpointFilterTests {
@Test
public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
TestRegisteredClients.registeredClient().build());
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
@ -206,6 +225,34 @@ public class OAuth2TokenEndpointFilterTests {
OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, request);
}
@Test
public void doFilterWhenTokenRequestNotAuthenticatedAndMissingCodeVerifierThenInvalidRequestError() throws Exception {
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
TestRegisteredClients.registeredClient().build());
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com");
doFilterWhenTokenRequestInvalidParameterThenError(
PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request);
}
@Test
public void doFilterWhenTokenRequestNotAuthenticatedAndMultipleCodeVerifierThenInvalidRequestError() throws Exception {
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
TestRegisteredClients.registeredClient().build());
request.addParameter(PkceParameterNames.CODE_VERIFIER, "one-verifier");
request.addParameter(PkceParameterNames.CODE_VERIFIER, "two-verifiers");
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com");
doFilterWhenTokenRequestInvalidParameterThenError(
PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request);
}
@Test
public void doFilterWhenAuthorizationCodeTokenRequestValidThenAccessTokenResponse() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@ -359,6 +406,8 @@ public class OAuth2TokenEndpointFilterTests {
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
// The client does not need to send the client ID param, but we are resilient in case they do
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
return request;
}