Enforce one-time use for authorization code

Closes gh-138
This commit is contained in:
Joe Grandja 2020-10-22 14:03:24 -04:00
parent 601640e4fa
commit 18f8b3afaa
14 changed files with 201 additions and 45 deletions

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.io.Serializable; import java.io.Serializable;
@ -63,7 +64,8 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) { if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) {
return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)); return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE));
} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
return authorization.getTokens().getAccessToken() != null && return authorization.getTokens().getAccessToken() != null &&
authorization.getTokens().getAccessToken().getTokenValue().equals(token); authorization.getTokens().getAccessToken().getTokenValue().equals(token);

View File

@ -152,7 +152,7 @@ public class OAuth2Authorization implements Serializable {
Assert.notNull(authorization, "authorization cannot be null"); Assert.notNull(authorization, "authorization cannot be null");
return new Builder(authorization.getRegisteredClientId()) return new Builder(authorization.getRegisteredClientId())
.principalName(authorization.getPrincipalName()) .principalName(authorization.getPrincipalName())
.tokens(authorization.getTokens()) .tokens(OAuth2Tokens.from(authorization.getTokens()).build())
.attributes(attrs -> attrs.putAll(authorization.getAttributes())); .attributes(attrs -> attrs.putAll(authorization.getAttributes()));
} }

View File

@ -38,6 +38,7 @@ public interface OAuth2AuthorizationAttributeNames {
/** /**
* The name of the attribute used for the {@link OAuth2ParameterNames#CODE} parameter. * The name of the attribute used for the {@link OAuth2ParameterNames#CODE} parameter.
*/ */
@Deprecated
String CODE = OAuth2Authorization.class.getName().concat(".CODE"); String CODE = OAuth2Authorization.class.getName().concat(".CODE");
/** /**

View File

@ -35,6 +35,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -102,11 +104,15 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
if (authorization == null) { if (authorization == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
// Invalidate the authorization code given that a different client is attempting to use it
authorization.getTokens().invalidate(authorizationCode);
this.authorizationService.save(authorization);
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
@ -115,6 +121,12 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode);
if (authorizationCodeMetadata.isInvalidated()) {
// Prevent the same client from using the authorization code more than once
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
// TODO Allow configuration for issuer claim // TODO Allow configuration for issuer claim
@ -142,9 +154,14 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens())
.accessToken(accessToken)
.build();
tokens.invalidate(authorizationCode); // Invalidate the authorization code as it can only be used once
authorization = OAuth2Authorization.from(authorization) authorization = OAuth2Authorization.from(authorization)
.tokens(tokens)
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);

View File

@ -0,0 +1,43 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import java.time.Instant;
/**
* An implementation of an {@link AbstractOAuth2Token}
* representing an OAuth 2.0 Authorization Code Grant.
*
* @author Joe Grandja
* @since 0.0.3
* @see AbstractOAuth2Token
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
*/
public class OAuth2AuthorizationCode extends AbstractOAuth2Token {
/**
* Constructs an {@code OAuth2AuthorizationCode} using the provided parameters.
* @param tokenValue the token value
* @param issuedAt the time at which the token was issued
* @param expiresAt the time at which the token expires
*/
public OAuth2AuthorizationCode(String tokenValue, Instant issuedAt, Instant expiresAt) {
super(tokenValue, issuedAt, expiresAt);
}
}

View File

@ -146,14 +146,30 @@ public class OAuth2Tokens implements Serializable {
return new Builder(); return new Builder();
} }
/**
* Returns a new {@link Builder}, initialized with the values from the provided {@code tokens}.
*
* @param tokens the tokens used for initializing the {@link Builder}
* @return the {@link Builder}
*/
public static Builder from(OAuth2Tokens tokens) {
Assert.notNull(tokens, "tokens cannot be null");
return new Builder(tokens.tokens);
}
/** /**
* A builder for {@link OAuth2Tokens}. * A builder for {@link OAuth2Tokens}.
*/ */
public static class Builder implements Serializable { public static class Builder implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private final Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens = new HashMap<>(); private Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
protected Builder() { protected Builder() {
this.tokens = new HashMap<>();
}
protected Builder(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
this.tokens = new HashMap<>(tokens);
} }
/** /**

View File

@ -36,6 +36,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@ -53,6 +55,8 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
@ -184,9 +188,12 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
UserConsentPage.displayConsent(request, response, registeredClient, authorization); UserConsentPage.displayConsent(request, response, registeredClient, authorization);
} else { } else {
String code = this.codeGenerator.generateKey(); Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES); // TODO Allow configuration for authorization code time-to-live
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
this.codeGenerator.generateKey(), issuedAt, expiresAt);
OAuth2Authorization authorization = builder OAuth2Authorization authorization = builder
.attribute(OAuth2AuthorizationAttributeNames.CODE, code) .tokens(OAuth2Tokens.builder().token(authorizationCode).build())
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes())
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
@ -200,7 +207,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
// The authorization code is bound to the client identifier and redirection URI. // The authorization code is bound to the client identifier and redirection URI.
sendAuthorizationResponse(request, response, sendAuthorizationResponse(request, response,
authorizationRequestContext.resolveRedirectUri(), code, authorizationRequest.getState()); authorizationRequestContext.resolveRedirectUri(), authorizationCode, authorizationRequest.getState());
} }
} }
@ -232,18 +239,21 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
return; return;
} }
String code = this.codeGenerator.generateKey(); Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES); // TODO Allow configuration for authorization code time-to-live
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
this.codeGenerator.generateKey(), issuedAt, expiresAt);
OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization())
.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
.attributes(attrs -> { .attributes(attrs -> {
attrs.remove(OAuth2AuthorizationAttributeNames.STATE); attrs.remove(OAuth2AuthorizationAttributeNames.STATE);
attrs.put(OAuth2AuthorizationAttributeNames.CODE, code);
attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes()); attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes());
}) })
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
sendAuthorizationResponse(request, response, userConsentRequestContext.resolveRedirectUri(), sendAuthorizationResponse(request, response, userConsentRequestContext.resolveRedirectUri(),
code, userConsentRequestContext.getAuthorizationRequest().getState()); authorizationCode, userConsentRequestContext.getAuthorizationRequest().getState());
} }
private void validateAuthorizationRequest(OAuth2AuthorizationRequestContext authorizationRequestContext) { private void validateAuthorizationRequest(OAuth2AuthorizationRequestContext authorizationRequestContext) {
@ -389,11 +399,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
} }
private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response, private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
String redirectUri, String code, String state) throws IOException { String redirectUri, OAuth2AuthorizationCode authorizationCode, String state) throws IOException {
UriComponentsBuilder uriBuilder = UriComponentsBuilder UriComponentsBuilder uriBuilder = UriComponentsBuilder
.fromUriString(redirectUri) .fromUriString(redirectUri)
.queryParam(OAuth2ParameterNames.CODE, code); .queryParam(OAuth2ParameterNames.CODE, authorizationCode.getTokenValue());
if (StringUtils.hasText(state)) { if (StringUtils.hasText(state)) {
uriBuilder.queryParam(OAuth2ParameterNames.STATE, state); uriBuilder.queryParam(OAuth2ParameterNames.STATE, state);
} }

View File

@ -34,13 +34,13 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -153,7 +153,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -167,7 +167,7 @@ public class OAuth2AuthorizationCodeGrantTests {
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
} }
@ -199,7 +199,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -212,7 +212,7 @@ public class OAuth2AuthorizationCodeGrantTests {
verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService, times(2)).findByToken( verify(authorizationService, times(2)).findByToken(
eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService, times(2)).save(any()); verify(authorizationService, times(2)).save(any());
} }
@ -232,7 +232,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization) { OAuth2Authorization authorization) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
parameters.set(OAuth2ParameterNames.CODE, authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue());
parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
return parameters; return parameters;
} }

View File

@ -20,9 +20,11 @@ import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -36,7 +38,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
public class InMemoryOAuth2AuthorizationServiceTests { public class InMemoryOAuth2AuthorizationServiceTests {
private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
private static final String PRINCIPAL_NAME = "principal"; private static final String PRINCIPAL_NAME = "principal";
private static final String AUTHORIZATION_CODE = "code"; private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
"code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES));
private InMemoryOAuth2AuthorizationService authorizationService; private InMemoryOAuth2AuthorizationService authorizationService;
@Before @Before
@ -55,12 +58,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void saveWhenAuthorizationProvidedThenSaved() { public void saveWhenAuthorizationProvidedThenSaved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
.build(); .build();
this.authorizationService.save(expectedAuthorization); this.authorizationService.save(expectedAuthorization);
OAuth2Authorization authorization = this.authorizationService.findByToken( OAuth2Authorization authorization = this.authorizationService.findByToken(
AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
assertThat(authorization).isEqualTo(expectedAuthorization); assertThat(authorization).isEqualTo(expectedAuthorization);
} }
@ -75,17 +78,17 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void removeWhenAuthorizationProvidedThenRemoved() { public void removeWhenAuthorizationProvidedThenRemoved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
.build(); .build();
this.authorizationService.save(expectedAuthorization); this.authorizationService.save(expectedAuthorization);
OAuth2Authorization authorization = this.authorizationService.findByToken( OAuth2Authorization authorization = this.authorizationService.findByToken(
AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
assertThat(authorization).isEqualTo(expectedAuthorization); assertThat(authorization).isEqualTo(expectedAuthorization);
this.authorizationService.remove(expectedAuthorization); this.authorizationService.remove(expectedAuthorization);
authorization = this.authorizationService.findByToken( authorization = this.authorizationService.findByToken(
AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
assertThat(authorization).isNull(); assertThat(authorization).isNull();
} }
@ -114,12 +117,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() { public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
OAuth2Authorization result = this.authorizationService.findByToken( OAuth2Authorization result = this.authorizationService.findByToken(
AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);
} }
@ -129,8 +132,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
"access-token", Instant.now().minusSeconds(60), Instant.now()); "access-token", Instant.now().minusSeconds(60), Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(accessToken).build())
.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);

View File

@ -19,13 +19,14 @@ import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.data.MapEntry.entry;
/** /**
* Tests for {@link OAuth2Authorization}. * Tests for {@link OAuth2Authorization}.
@ -38,7 +39,8 @@ public class OAuth2AuthorizationTests {
private static final String PRINCIPAL_NAME = "principal"; private static final String PRINCIPAL_NAME = "principal";
private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken( private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
private static final String AUTHORIZATION_CODE = "code"; private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
"code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES));
@Test @Test
public void withRegisteredClientWhenRegisteredClientNullThenThrowIllegalArgumentException() { public void withRegisteredClientWhenRegisteredClientNullThenThrowIllegalArgumentException() {
@ -58,14 +60,15 @@ public class OAuth2AuthorizationTests {
public void fromWhenAuthorizationProvidedThenCopied() { public void fromWhenAuthorizationProvidedThenCopied() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build()) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build())
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
.build(); .build();
OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build(); OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
assertThat(authorizationResult.getTokens().getToken(OAuth2AuthorizationCode.class))
.isEqualTo(authorization.getTokens().getToken(OAuth2AuthorizationCode.class));
assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes()); assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
} }
@ -98,14 +101,12 @@ public class OAuth2AuthorizationTests {
public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build()) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build())
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
.build(); .build();
assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
assertThat(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isEqualTo(AUTHORIZATION_CODE);
assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN); assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN);
assertThat(authorization.getAttributes()).containsExactly(
entry(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE));
} }
} }

View File

@ -19,6 +19,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant; import java.time.Instant;
@ -41,6 +42,8 @@ public class TestOAuth2Authorizations {
public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
Map<String, Object> authorizationRequestAdditionalParameters) { Map<String, Object> authorizationRequestAdditionalParameters) {
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
"code", Instant.now(), Instant.now().plusSeconds(120));
OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
@ -53,8 +56,7 @@ public class TestOAuth2Authorizations {
.build(); .build();
return OAuth2Authorization.withRegisteredClient(registeredClient) return OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName("principal") .principalName("principal")
.tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).build())
.attribute(OAuth2AuthorizationAttributeNames.CODE, "code")
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes());
} }

View File

@ -37,6 +37,8 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
@ -153,6 +155,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode") .extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
} }
@Test @Test
@ -173,6 +181,30 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
} }
@Test
public void authenticateWhenInvalidatedCodeThenThrowOAuth2AuthenticationException() {
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120));
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization()
.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
.build();
authorization.getTokens().invalidate(authorizationCode);
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test @Test
public void authenticateWhenValidCodeThenReturnAccessToken() { public void authenticateWhenValidCodeThenReturnAccessToken() {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
@ -203,8 +235,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
} }
private static Jwt createJwt() { private static Jwt createJwt() {

View File

@ -94,6 +94,35 @@ public class OAuth2TokensTests {
.hasMessage("token cannot be null"); .hasMessage("token cannot be null");
} }
@Test
public void fromWhenTokensNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.from(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("tokens cannot be null");
}
@Test
public void fromWhenTokensProvidedThenCopied() {
OAuth2Tokens tokens = OAuth2Tokens.builder()
.accessToken(this.accessToken)
.refreshToken(this.refreshToken)
.token(this.idToken)
.build();
OAuth2Tokens tokensResult = OAuth2Tokens.from(tokens).build();
assertThat(tokensResult.getAccessToken()).isEqualTo(tokens.getAccessToken());
assertThat(tokensResult.getTokenMetadata(tokensResult.getAccessToken()))
.isEqualTo(tokens.getTokenMetadata(tokens.getAccessToken()));
assertThat(tokensResult.getRefreshToken()).isEqualTo(tokens.getRefreshToken());
assertThat(tokensResult.getTokenMetadata(tokensResult.getRefreshToken()))
.isEqualTo(tokens.getTokenMetadata(tokens.getRefreshToken()));
assertThat(tokensResult.getToken(OidcIdToken.class)).isEqualTo(tokens.getToken(OidcIdToken.class));
assertThat(tokensResult.getTokenMetadata(tokensResult.getToken(OidcIdToken.class)))
.isEqualTo(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)));
}
@Test @Test
public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() { public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() {
OAuth2Tokens tokens = OAuth2Tokens.builder() OAuth2Tokens tokens = OAuth2Tokens.builder()

View File

@ -40,6 +40,7 @@ import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
@ -434,8 +435,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
assertThat(code).isNotNull(); assertThat(authorizationCode).isNotNull();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
@ -481,8 +482,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
assertThat(code).isNotNull(); assertThat(authorizationCode).isNotNull();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
@ -755,9 +756,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); assertThat(updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isNotNull();
assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.CODE)).isNotNull();
assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)) assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))
.isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)); .isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST));
assertThat(updatedAuthorization.<Set<String>>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES)) assertThat(updatedAuthorization.<Set<String>>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES))