From 18f8b3afaae3103bbc1ff71afb1a0a6e2808dd0a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 22 Oct 2020 14:03:24 -0400 Subject: [PATCH] Enforce one-time use for authorization code Closes gh-138 --- .../InMemoryOAuth2AuthorizationService.java | 4 +- .../authorization/OAuth2Authorization.java | 2 +- .../OAuth2AuthorizationAttributeNames.java | 1 + ...thorizationCodeAuthenticationProvider.java | 19 +++++++- .../token/OAuth2AuthorizationCode.java | 43 +++++++++++++++++++ .../authorization/token/OAuth2Tokens.java | 18 +++++++- .../OAuth2AuthorizationEndpointFilter.java | 26 +++++++---- .../OAuth2AuthorizationCodeGrantTests.java | 12 +++--- ...MemoryOAuth2AuthorizationServiceTests.java | 22 +++++----- .../OAuth2AuthorizationTests.java | 17 ++++---- .../TestOAuth2Authorizations.java | 6 ++- ...zationCodeAuthenticationProviderTests.java | 35 ++++++++++++++- .../token/OAuth2TokensTests.java | 29 +++++++++++++ ...Auth2AuthorizationEndpointFilterTests.java | 12 +++--- 14 files changed, 201 insertions(+), 45 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2AuthorizationCode.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 4e77698..86e317e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization; import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.util.Assert; import java.io.Serializable; @@ -63,7 +64,8 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) { return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)); } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { - return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); + OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); + return authorizationCode != null && authorizationCode.getTokenValue().equals(token); } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { return authorization.getTokens().getAccessToken() != null && authorization.getTokens().getAccessToken().getTokenValue().equals(token); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index a39e90a..a439124 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -152,7 +152,7 @@ public class OAuth2Authorization implements Serializable { Assert.notNull(authorization, "authorization cannot be null"); return new Builder(authorization.getRegisteredClientId()) .principalName(authorization.getPrincipalName()) - .tokens(authorization.getTokens()) + .tokens(OAuth2Tokens.from(authorization.getTokens()).build()) .attributes(attrs -> attrs.putAll(authorization.getAttributes())); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java index 06440b0..364070c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java @@ -38,6 +38,7 @@ public interface OAuth2AuthorizationAttributeNames { /** * The name of the attribute used for the {@link OAuth2ParameterNames#CODE} parameter. */ + @Deprecated String CODE = OAuth2Authorization.class.getName().concat(".CODE"); /** diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index e5153e2..44652c2 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -35,6 +35,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -102,11 +104,15 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica if (authorization == null) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } + OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { + // Invalidate the authorization code given that a different client is attempting to use it + authorization.getTokens().invalidate(authorizationCode); + this.authorizationService.save(authorization); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } @@ -115,6 +121,12 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } + OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode); + if (authorizationCodeMetadata.isInvalidated()) { + // Prevent the same client from using the authorization code more than once + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); // TODO Allow configuration for issuer claim @@ -142,9 +154,14 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens()) + .accessToken(accessToken) + .build(); + tokens.invalidate(authorizationCode); // Invalidate the authorization code as it can only be used once + authorization = OAuth2Authorization.from(authorization) + .tokens(tokens) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) - .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2AuthorizationCode.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2AuthorizationCode.java new file mode 100644 index 0000000..e367465 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2AuthorizationCode.java @@ -0,0 +1,43 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import org.springframework.security.oauth2.core.AbstractOAuth2Token; + +import java.time.Instant; + +/** + * An implementation of an {@link AbstractOAuth2Token} + * representing an OAuth 2.0 Authorization Code Grant. + * + * @author Joe Grandja + * @since 0.0.3 + * @see AbstractOAuth2Token + * @see Section 4.1 Authorization Code Grant + */ +public class OAuth2AuthorizationCode extends AbstractOAuth2Token { + + /** + * Constructs an {@code OAuth2AuthorizationCode} using the provided parameters. + * @param tokenValue the token value + * @param issuedAt the time at which the token was issued + * @param expiresAt the time at which the token expires + */ + public OAuth2AuthorizationCode(String tokenValue, Instant issuedAt, Instant expiresAt) { + super(tokenValue, issuedAt, expiresAt); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java index 3da46e2..819067d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java @@ -146,14 +146,30 @@ public class OAuth2Tokens implements Serializable { return new Builder(); } + /** + * Returns a new {@link Builder}, initialized with the values from the provided {@code tokens}. + * + * @param tokens the tokens used for initializing the {@link Builder} + * @return the {@link Builder} + */ + public static Builder from(OAuth2Tokens tokens) { + Assert.notNull(tokens, "tokens cannot be null"); + return new Builder(tokens.tokens); + } + /** * A builder for {@link OAuth2Tokens}. */ public static class Builder implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; - private final Map, OAuth2TokenHolder> tokens = new HashMap<>(); + private Map, OAuth2TokenHolder> tokens; protected Builder() { + this.tokens = new HashMap<>(); + } + + protected Builder(Map, OAuth2TokenHolder> tokens) { + this.tokens = new HashMap<>(tokens); } /** diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 06555ea..c3e634b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -36,6 +36,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -53,6 +55,8 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -184,9 +188,12 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { UserConsentPage.displayConsent(request, response, registeredClient, authorization); } else { - String code = this.codeGenerator.generateKey(); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES); // TODO Allow configuration for authorization code time-to-live + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + this.codeGenerator.generateKey(), issuedAt, expiresAt); OAuth2Authorization authorization = builder - .attribute(OAuth2AuthorizationAttributeNames.CODE, code) + .tokens(OAuth2Tokens.builder().token(authorizationCode).build()) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()) .build(); this.authorizationService.save(authorization); @@ -200,7 +207,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { // The authorization code is bound to the client identifier and redirection URI. sendAuthorizationResponse(request, response, - authorizationRequestContext.resolveRedirectUri(), code, authorizationRequest.getState()); + authorizationRequestContext.resolveRedirectUri(), authorizationCode, authorizationRequest.getState()); } } @@ -232,18 +239,21 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { return; } - String code = this.codeGenerator.generateKey(); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES); // TODO Allow configuration for authorization code time-to-live + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + this.codeGenerator.generateKey(), issuedAt, expiresAt); OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) + .tokens(OAuth2Tokens.builder().token(authorizationCode).build()) .attributes(attrs -> { attrs.remove(OAuth2AuthorizationAttributeNames.STATE); - attrs.put(OAuth2AuthorizationAttributeNames.CODE, code); attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes()); }) .build(); this.authorizationService.save(authorization); sendAuthorizationResponse(request, response, userConsentRequestContext.resolveRedirectUri(), - code, userConsentRequestContext.getAuthorizationRequest().getState()); + authorizationCode, userConsentRequestContext.getAuthorizationRequest().getState()); } private void validateAuthorizationRequest(OAuth2AuthorizationRequestContext authorizationRequestContext) { @@ -389,11 +399,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { } private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response, - String redirectUri, String code, String state) throws IOException { + String redirectUri, OAuth2AuthorizationCode authorizationCode, String state) throws IOException { UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(redirectUri) - .queryParam(OAuth2ParameterNames.CODE, code); + .queryParam(OAuth2ParameterNames.CODE, authorizationCode.getTokenValue()); if (StringUtils.hasText(state)) { uriBuilder.queryParam(OAuth2ParameterNames.STATE, state); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 85382d8..19711a8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -34,13 +34,13 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; @@ -153,7 +153,7 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); when(authorizationService.findByToken( - eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), + eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); @@ -167,7 +167,7 @@ public class OAuth2AuthorizationCodeGrantTests { verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).findByToken( - eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), + eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(TokenType.AUTHORIZATION_CODE)); verify(authorizationService).save(any()); } @@ -199,7 +199,7 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); when(authorizationService.findByToken( - eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), + eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); @@ -212,7 +212,7 @@ public class OAuth2AuthorizationCodeGrantTests { verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService, times(2)).findByToken( - eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), + eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(TokenType.AUTHORIZATION_CODE)); verify(authorizationService, times(2)).save(any()); } @@ -232,7 +232,7 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2Authorization authorization) { MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - parameters.set(OAuth2ParameterNames.CODE, authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); + parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()); parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); return parameters; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index 603c7f8..0984450 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -20,9 +20,11 @@ import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; +import java.time.temporal.ChronoUnit; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -36,7 +38,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; public class InMemoryOAuth2AuthorizationServiceTests { private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; - private static final String AUTHORIZATION_CODE = "code"; + private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode( + "code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES)); private InMemoryOAuth2AuthorizationService authorizationService; @Before @@ -55,12 +58,12 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void saveWhenAuthorizationProvidedThenSaved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .build(); this.authorizationService.save(expectedAuthorization); OAuth2Authorization authorization = this.authorizationService.findByToken( - AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); + AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE); assertThat(authorization).isEqualTo(expectedAuthorization); } @@ -75,17 +78,17 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void removeWhenAuthorizationProvidedThenRemoved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .build(); this.authorizationService.save(expectedAuthorization); OAuth2Authorization authorization = this.authorizationService.findByToken( - AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); + AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE); assertThat(authorization).isEqualTo(expectedAuthorization); this.authorizationService.remove(expectedAuthorization); authorization = this.authorizationService.findByToken( - AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); + AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE); assertThat(authorization).isNull(); } @@ -114,12 +117,12 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .build(); this.authorizationService.save(authorization); OAuth2Authorization result = this.authorizationService.findByToken( - AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); + AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE); assertThat(authorization).isEqualTo(result); } @@ -129,8 +132,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { "access-token", Instant.now().minusSeconds(60), Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) - .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(accessToken).build()) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index 2e5a78f..717dec8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -19,13 +19,14 @@ import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; +import java.time.temporal.ChronoUnit; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.data.MapEntry.entry; /** * Tests for {@link OAuth2Authorization}. @@ -38,7 +39,8 @@ public class OAuth2AuthorizationTests { private static final String PRINCIPAL_NAME = "principal"; private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); - private static final String AUTHORIZATION_CODE = "code"; + private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode( + "code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES)); @Test public void withRegisteredClientWhenRegisteredClientNullThenThrowIllegalArgumentException() { @@ -58,14 +60,15 @@ public class OAuth2AuthorizationTests { public void fromWhenAuthorizationProvidedThenCopied() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build()) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build()) .build(); OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build(); assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); + assertThat(authorizationResult.getTokens().getToken(OAuth2AuthorizationCode.class)) + .isEqualTo(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)); assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes()); } @@ -98,14 +101,12 @@ public class OAuth2AuthorizationTests { public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build()) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build()) .build(); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); + assertThat(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isEqualTo(AUTHORIZATION_CODE); assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN); - assertThat(authorization.getAttributes()).containsExactly( - entry(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 48a85b8..26fa9e9 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -19,6 +19,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; @@ -41,6 +42,8 @@ public class TestOAuth2Authorizations { public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map authorizationRequestAdditionalParameters) { + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + "code", Instant.now(), Instant.now().plusSeconds(120)); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() @@ -53,8 +56,7 @@ public class TestOAuth2Authorizations { .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) .principalName("principal") - .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) - .attribute(OAuth2AuthorizationAttributeNames.CODE, "code") + .tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).build()) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 22be631..10d8d7d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -37,6 +37,8 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -153,6 +155,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); } @Test @@ -173,6 +181,30 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } + @Test + public void authenticateWhenInvalidatedCodeThenThrowOAuth2AuthenticationException() { + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120)); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() + .tokens(OAuth2Tokens.builder().token(authorizationCode).build()) + .build(); + authorization.getTokens().invalidate(authorizationCode); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + @Test public void authenticateWhenValidCodeThenReturnAccessToken() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); @@ -203,8 +235,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); - assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); + OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); } private static Jwt createJwt() { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java index 5a94065..533fcc5 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java @@ -94,6 +94,35 @@ public class OAuth2TokensTests { .hasMessage("token cannot be null"); } + @Test + public void fromWhenTokensNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.from(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokens cannot be null"); + } + + @Test + public void fromWhenTokensProvidedThenCopied() { + OAuth2Tokens tokens = OAuth2Tokens.builder() + .accessToken(this.accessToken) + .refreshToken(this.refreshToken) + .token(this.idToken) + .build(); + OAuth2Tokens tokensResult = OAuth2Tokens.from(tokens).build(); + + assertThat(tokensResult.getAccessToken()).isEqualTo(tokens.getAccessToken()); + assertThat(tokensResult.getTokenMetadata(tokensResult.getAccessToken())) + .isEqualTo(tokens.getTokenMetadata(tokens.getAccessToken())); + + assertThat(tokensResult.getRefreshToken()).isEqualTo(tokens.getRefreshToken()); + assertThat(tokensResult.getTokenMetadata(tokensResult.getRefreshToken())) + .isEqualTo(tokens.getTokenMetadata(tokens.getRefreshToken())); + + assertThat(tokensResult.getToken(OidcIdToken.class)).isEqualTo(tokens.getToken(OidcIdToken.class)); + assertThat(tokensResult.getTokenMetadata(tokensResult.getToken(OidcIdToken.class))) + .isEqualTo(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class))); + } + @Test public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() { OAuth2Tokens tokens = OAuth2Tokens.builder() diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 1e64dda..86d9afc 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -40,6 +40,7 @@ import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.util.StringUtils; import javax.servlet.FilterChain; @@ -434,8 +435,8 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); - String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE); - assertThat(code).isNotNull(); + OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); + assertThat(authorizationCode).isNotNull(); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); assertThat(authorizationRequest).isNotNull(); @@ -481,8 +482,8 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); - String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE); - assertThat(code).isNotNull(); + OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); + assertThat(authorizationCode).isNotNull(); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); assertThat(authorizationRequest).isNotNull(); @@ -755,9 +756,8 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); - assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); + assertThat(updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isNotNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); - assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)).isNotNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)) .isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)); assertThat(updatedAuthorization.>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES))