diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java deleted file mode 100644 index 1bf6808..0000000 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization; - -import org.springframework.util.Assert; - -/** - * An {@link OAuth2TokenRevocationService} that revokes tokens. - * - * @author Vivek Babu - * @see OAuth2AuthorizationService - * @since 0.0.1 - */ -public final class DefaultOAuth2TokenRevocationService implements OAuth2TokenRevocationService { - - private OAuth2AuthorizationService authorizationService; - - /** - * Constructs an {@code DefaultOAuth2TokenRevocationService}. - */ - public DefaultOAuth2TokenRevocationService(OAuth2AuthorizationService authorizationService) { - Assert.notNull(authorizationService, "authorizationService cannot be null"); - this.authorizationService = authorizationService; - } - - @Override - public void revoke(String token, TokenType tokenType) { - final OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(token, tokenType); - if (authorization != null) { - final OAuth2Authorization revokedAuthorization = OAuth2Authorization.from(authorization) - .revoked(true).build(); - this.authorizationService.save(revokedAuthorization); - } - } -} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java deleted file mode 100644 index 7ad02a4..0000000 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization; - -/** - * Implementations of this interface are responsible for the revocation of - * OAuth2 tokens. - * - * @author Vivek Babu - * @since 0.0.1 - */ -public interface OAuth2TokenRevocationService { - - /** - * Revokes the given token. - * - * @param token the token to be revoked - * @param tokenType the type of token to be revoked - */ - void revoke(String token, TokenType tokenType); -} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java deleted file mode 100644 index bd1fb77..0000000 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; - -import java.time.Instant; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -/** - * Tests for {@link DefaultOAuth2TokenRevocationService}. - * - * @author Vivek Babu - */ -public class DefaultOAuth2TokenRevocationServiceTests { - private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); - private static final String PRINCIPAL_NAME = "principal"; - private static final String AUTHORIZATION_CODE = "code"; - private DefaultOAuth2TokenRevocationService revocationService; - private OAuth2AuthorizationService authorizationService; - - @Before - public void setup() { - this.authorizationService = mock(OAuth2AuthorizationService.class); - this.revocationService = new DefaultOAuth2TokenRevocationService(authorizationService); - } - - @Test - public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2TokenRevocationService(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationService cannot be null"); - } - - @Test - public void revokeWhenTokenNotFound() { - this.revocationService.revoke("token", TokenType.ACCESS_TOKEN); - verify(authorizationService, times(1)).findByTokenAndTokenType(eq("token"), - eq(TokenType.ACCESS_TOKEN)); - verify(authorizationService, times(0)).save(any()); - } - - @Test - public void revokeWhenTokenFound() { - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", Instant.now().minusSeconds(60), Instant.now()); - OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) - .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) - .accessToken(accessToken) - .build(); - when(authorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))) - .thenReturn(authorization); - this.revocationService.revoke("token", TokenType.ACCESS_TOKEN); - - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); - verify(this.authorizationService).save(authorizationCaptor.capture()); - final OAuth2Authorization savedAuthorization = authorizationCaptor.getValue(); - assertThat(savedAuthorization.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); - assertThat((String) savedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)) - .isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); - assertThat(savedAuthorization.getAccessToken()).isEqualTo(authorization.getAccessToken()); - assertThat(savedAuthorization.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); - assertThat(savedAuthorization.isRevoked()).isTrue(); - } -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index 4104c80..80fe069 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -31,11 +31,13 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.web.JwkSetEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; +import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; @@ -73,6 +75,8 @@ public final class OAuth2AuthorizationServerConfigurer getEndpointMatchers() { - return Arrays.asList(this.authorizationEndpointMatcher, - this.tokenEndpointMatcher, this.jwkSetEndpointMatcher); + return Arrays.asList(this.authorizationEndpointMatcher, this.tokenEndpointMatcher, + this.tokenRevocationEndpointMatcher, this.jwkSetEndpointMatcher); } @Override @@ -145,11 +149,17 @@ public final class OAuth2AuthorizationServerConfigurer exceptionHandling = builder.getConfigurer(ExceptionHandlingConfigurer.class); if (exceptionHandling != null) { - // Register the default AuthenticationEntryPoint for the token endpoint + // Register the default AuthenticationEntryPoint for the token endpoint and token revocation endpoint exceptionHandling.defaultAuthenticationEntryPointFor( - new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED), this.tokenEndpointMatcher); + new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED), + new OrRequestMatcher(this.tokenEndpointMatcher, this.tokenRevocationEndpointMatcher)); } } @@ -160,8 +170,10 @@ public final class OAuth2AuthorizationServerConfigurer> RegisteredClientRepository getRegisteredClientRepository(B builder) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/TokenType.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/TokenType.java index 14e586a..0410277 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/TokenType.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/TokenType.java @@ -15,7 +15,6 @@ */ package org.springframework.security.oauth2.server.authorization; -import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.util.Assert; import java.io.Serializable; @@ -26,6 +25,7 @@ import java.io.Serializable; public final class TokenType implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; public static final TokenType ACCESS_TOKEN = new TokenType("access_token"); + public static final TokenType REFRESH_TOKEN = new TokenType("refresh_token"); public static final TokenType AUTHORIZATION_CODE = new TokenType("authorization_code"); private final String value; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java new file mode 100644 index 0000000..8dfc720 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java @@ -0,0 +1,61 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; + +/** + * Utility methods for the OAuth 2.0 {@link AuthenticationProvider}'s. + * + * @author Joe Grandja + * @since 0.0.3 + */ +final class OAuth2AuthenticationProviderUtils { + + private OAuth2AuthenticationProviderUtils() { + } + + static OAuth2Authorization invalidate( + OAuth2Authorization authorization, T token) { + + OAuth2Tokens.Builder builder = OAuth2Tokens.from(authorization.getTokens()) + .token(token, OAuth2TokenMetadata.builder().invalidated().build()); + + if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) { + builder.token( + authorization.getTokens().getAccessToken(), + OAuth2TokenMetadata.builder().invalidated().build()); + OAuth2AuthorizationCode authorizationCode = + authorization.getTokens().getToken(OAuth2AuthorizationCode.class); + if (authorizationCode != null && + !authorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()) { + builder.token( + authorizationCode, + OAuth2TokenMetadata.builder().invalidated().build()); + } + } + + return OAuth2Authorization.from(authorization) + .tokens(builder.build()) + .build(); + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 44652c2..ad6fa61 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -105,14 +105,17 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); + OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { - // Invalidate the authorization code given that a different client is attempting to use it - authorization.getTokens().invalidate(authorizationCode); - this.authorizationService.save(authorization); + if (!authorizationCodeMetadata.isInvalidated()) { + // Invalidate the authorization code given that a different client is attempting to use it + authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode); + this.authorizationService.save(authorization); + } throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } @@ -121,9 +124,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode); if (authorizationCodeMetadata.isInvalidated()) { - // Prevent the same client from using the authorization code more than once throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } @@ -154,15 +155,16 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); - OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens()) - .accessToken(accessToken) - .build(); - tokens.invalidate(authorizationCode); // Invalidate the authorization code as it can only be used once - authorization = OAuth2Authorization.from(authorization) - .tokens(tokens) + .tokens(OAuth2Tokens.from(authorization.getTokens()) + .accessToken(accessToken) + .build()) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .build(); + + // Invalidate the authorization code as it can only be used once + authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode); + this.authorizationService.save(authorization); return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken); diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java similarity index 63% rename from core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java rename to oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java index 8e0cb75..7816ed8 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java @@ -18,78 +18,83 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2TokenRevocationService; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** - * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Token Revocation. + * An {@link AuthenticationProvider} implementation for OAuth 2.0 Token Revocation. * * @author Vivek Babu - * @since 0.0.1 + * @author Joe Grandja + * @since 0.0.3 * @see OAuth2TokenRevocationAuthenticationToken * @see OAuth2AuthorizationService - * @see OAuth2TokenRevocationService * @see Section 2.1 Revocation Request */ public class OAuth2TokenRevocationAuthenticationProvider implements AuthenticationProvider { - - private OAuth2AuthorizationService authorizationService; - private OAuth2TokenRevocationService tokenRevocationService; + private final OAuth2AuthorizationService authorizationService; /** * Constructs an {@code OAuth2TokenRevocationAuthenticationProvider} using the provided parameters. * * @param authorizationService the authorization service - * @param tokenRevocationService the token revocation service */ - public OAuth2TokenRevocationAuthenticationProvider(OAuth2AuthorizationService authorizationService, - OAuth2TokenRevocationService tokenRevocationService) { + public OAuth2TokenRevocationAuthenticationProvider(OAuth2AuthorizationService authorizationService) { Assert.notNull(authorizationService, "authorizationService cannot be null"); - Assert.notNull(tokenRevocationService, "tokenRevocationService cannot be null"); this.authorizationService = authorizationService; - this.tokenRevocationService = tokenRevocationService; } @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthenticationToken = + OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication = (OAuth2TokenRevocationAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientPrincipal = null; - if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(tokenRevocationAuthenticationToken.getPrincipal() - .getClass())) { - clientPrincipal = (OAuth2ClientAuthenticationToken) tokenRevocationAuthenticationToken.getPrincipal(); + if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(tokenRevocationAuthentication.getPrincipal().getClass())) { + clientPrincipal = (OAuth2ClientAuthenticationToken) tokenRevocationAuthentication.getPrincipal(); } if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } + RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); - final RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); - final String tokenTypeHint = tokenRevocationAuthenticationToken.getTokenTypeHint(); - final String token = tokenRevocationAuthenticationToken.getToken(); - final OAuth2Authorization authorization = authorizationService.findByTokenAndTokenType(token, - TokenType.ACCESS_TOKEN); - - OAuth2TokenRevocationAuthenticationToken successfulAuthentication = - new OAuth2TokenRevocationAuthenticationToken(token, registeredClient, tokenTypeHint); + TokenType tokenType = null; + String tokenTypeHint = tokenRevocationAuthentication.getTokenTypeHint(); + if (StringUtils.hasText(tokenTypeHint)) { + if (TokenType.REFRESH_TOKEN.getValue().equals(tokenTypeHint)) { + tokenType = TokenType.REFRESH_TOKEN; + } else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) { + tokenType = TokenType.ACCESS_TOKEN; + } else { + // TODO Add OAuth2ErrorCodes.UNSUPPORTED_TOKEN_TYPE + throw new OAuth2AuthenticationException(new OAuth2Error("unsupported_token_type")); + } + } + OAuth2Authorization authorization = this.authorizationService.findByToken( + tokenRevocationAuthentication.getToken(), tokenType); if (authorization == null) { - return successfulAuthentication; + // Return the authentication request when token not found + return tokenRevocationAuthentication; } - if (!registeredClient.getClientId().equals(authorization.getRegisteredClientId())) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } - tokenRevocationService.revoke(token, TokenType.ACCESS_TOKEN); - return successfulAuthentication; + AbstractOAuth2Token token = authorization.getTokens().getToken(tokenRevocationAuthentication.getToken()); + authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token); + this.authorizationService.save(authorization); + + return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal); } @Override diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java similarity index 63% rename from core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java rename to oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java index d42bbf8..6e383fe 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java @@ -18,53 +18,64 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.server.authorization.Version; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.util.Assert; import java.util.Collections; /** - * An {@link Authentication} implementation used for OAuth 2.0 Client Authentication. + * An {@link Authentication} implementation used for OAuth 2.0 Token Revocation. * * @author Vivek Babu - * @since 0.0.1 + * @author Joe Grandja + * @since 0.0.3 * @see AbstractAuthenticationToken - * @see RegisteredClient * @see OAuth2TokenRevocationAuthenticationProvider */ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final String token; + private final Authentication clientPrincipal; private final String tokenTypeHint; - private Authentication clientPrincipal; - private String token; - private RegisteredClient registeredClient; + /** + * Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided parameters. + * + * @param token the token + * @param clientPrincipal the authenticated client principal + * @param tokenTypeHint the token type hint + */ public OAuth2TokenRevocationAuthenticationToken(String token, Authentication clientPrincipal, @Nullable String tokenTypeHint) { super(Collections.emptyList()); - Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); Assert.hasText(token, "token cannot be empty"); + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); this.token = token; this.clientPrincipal = clientPrincipal; this.tokenTypeHint = tokenTypeHint; } - public OAuth2TokenRevocationAuthenticationToken(String token, - RegisteredClient registeredClient, @Nullable String tokenTypeHint) { + /** + * Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided parameters. + * + * @param revokedToken the revoked token + * @param clientPrincipal the authenticated client principal + */ + public OAuth2TokenRevocationAuthenticationToken(AbstractOAuth2Token revokedToken, + Authentication clientPrincipal) { super(Collections.emptyList()); - Assert.notNull(registeredClient, "registeredClient cannot be null"); - Assert.hasText(token, "token cannot be empty"); - this.token = token; - this.registeredClient = registeredClient; - this.tokenTypeHint = tokenTypeHint; - setAuthenticated(true); + Assert.notNull(revokedToken, "revokedToken cannot be null"); + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); + this.token = revokedToken.getTokenValue(); + this.clientPrincipal = clientPrincipal; + this.tokenTypeHint = null; + setAuthenticated(true); // Indicates that the token was authenticated and revoked } @Override public Object getPrincipal() { - return this.clientPrincipal != null ? this.clientPrincipal : this.registeredClient - .getClientId(); + return this.clientPrincipal; } @Override @@ -86,17 +97,8 @@ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthentica * * @return the token type hint */ + @Nullable public String getTokenTypeHint() { - return tokenTypeHint; - } - - /** - * Returns the {@link RegisteredClient registered client}. - * - * @return the {@link RegisteredClient} - */ - public @Nullable - RegisteredClient getRegisteredClient() { - return this.registeredClient; + return this.tokenTypeHint; } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java index 819067d..3f91f52 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java @@ -82,6 +82,24 @@ public class OAuth2Tokens implements Serializable { return tokenHolder != null ? (T) tokenHolder.getToken() : null; } + /** + * Returns the token specified by {@code token}. + * + * @param token the token + * @param the type of the token + * @return the token, or {@code null} if not available + */ + @Nullable + @SuppressWarnings("unchecked") + public T getToken(String token) { + Assert.hasText(token, "token cannot be empty"); + OAuth2TokenHolder tokenHolder = this.tokens.values().stream() + .filter(holder -> holder.getToken().getTokenValue().equals(token)) + .findFirst() + .orElse(null); + return tokenHolder != null ? (T) tokenHolder.getToken() : null; + } + /** * Returns the token metadata associated to the provided {@code token}. * @@ -97,29 +115,6 @@ public class OAuth2Tokens implements Serializable { tokenHolder.getTokenMetadata() : null; } - /** - * Invalidates all tokens. - */ - public void invalidate() { - this.tokens.values().forEach(tokenHolder -> invalidate(tokenHolder.getToken())); - } - - /** - * Invalidates the token matching the provided {@code token}. - * - * @param token the token - * @param the type of the token - */ - public void invalidate(T token) { - Assert.notNull(token, "token cannot be null"); - this.tokens.computeIfPresent(token.getClass(), - (tokenType, tokenHolder) -> - new OAuth2TokenHolder( - tokenHolder.getToken(), - OAuth2TokenMetadata.builder().invalidated().build()) - ); - } - @Override public boolean equals(Object obj) { if (this == obj) { diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java similarity index 73% rename from core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java rename to oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java index a8ba267..f87012a 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java @@ -27,10 +27,10 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -43,31 +43,30 @@ import javax.servlet.http.HttpServletResponse; import java.io.IOException; /** - * A {@code Filter} for the OAuth 2.0 Token Revocation, - * which handles the processing of the OAuth 2.0 Token Revocation Request. + * A {@code Filter} for the OAuth 2.0 Token Revocation endpoint. * * @author Vivek Babu - * @see OAuth2AuthorizationService - * @see OAuth2Authorization + * @author Joe Grandja + * @see OAuth2TokenRevocationAuthenticationProvider * @see Section 2 Token Revocation * @see Section 2.1 Revocation Request - * @since 0.0.1 + * @since 0.0.3 */ public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter { + static final String TOKEN_PARAM_NAME = "token"; + static final String TOKEN_TYPE_HINT_PARAM_NAME = "token_type_hint"; /** - * The default endpoint {@code URI} for token revocation request. + * The default endpoint {@code URI} for token revocation requests. */ public static final String DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI = "/oauth2/revoke"; - private static final String TOKEN_TYPE_HINT = "token_type_hint"; - private static final String TOKEN = "token"; - private final AntPathRequestMatcher revocationEndpointMatcher; + private final AuthenticationManager authenticationManager; + private final RequestMatcher tokenRevocationEndpointMatcher; private final Converter tokenRevocationAuthenticationConverter = - new OAuth2TokenRevocationEndpointFilter.TokenRevocationAuthenticationConverter(); + new DefaultTokenRevocationAuthenticationConverter(); private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); - private final AuthenticationManager authenticationManager; /** * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters. @@ -82,30 +81,30 @@ public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter { * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters. * * @param authenticationManager the authentication manager - * @param revocationEndpointUri the endpoint {@code URI} for revocation requests + * @param tokenRevocationEndpointUri the endpoint {@code URI} for token revocation requests */ public OAuth2TokenRevocationEndpointFilter(AuthenticationManager authenticationManager, - String revocationEndpointUri) { + String tokenRevocationEndpointUri) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); - Assert.hasText(revocationEndpointUri, "revocationEndpointUri cannot be empty"); + Assert.hasText(tokenRevocationEndpointUri, "tokenRevocationEndpointUri cannot be empty"); this.authenticationManager = authenticationManager; - this.revocationEndpointMatcher = new AntPathRequestMatcher( - revocationEndpointUri, HttpMethod.POST.name()); + this.tokenRevocationEndpointMatcher = new AntPathRequestMatcher( + tokenRevocationEndpointUri, HttpMethod.POST.name()); } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (!this.revocationEndpointMatcher.matches(request)) { + if (!this.tokenRevocationEndpointMatcher.matches(request)) { filterChain.doFilter(request, response); return; } try { - Authentication tokenRevocationRequestAuthentication = - this.tokenRevocationAuthenticationConverter.convert(request); - this.authenticationManager.authenticate(tokenRevocationRequestAuthentication); + this.authenticationManager.authenticate( + this.tokenRevocationAuthenticationConverter.convert(request)); + response.setStatus(HttpStatus.OK.value()); } catch (OAuth2AuthenticationException ex) { SecurityContextHolder.clearContext(); sendErrorResponse(response, ex.getError()); @@ -118,30 +117,34 @@ public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter { this.errorHttpResponseConverter.write(error, null, httpResponse); } - private static OAuth2AuthenticationException throwError(String errorCode, String parameterName) { - OAuth2Error error = new OAuth2Error(errorCode, "Token Revocation Request Parameter: " + parameterName, + private static void throwError(String errorCode, String parameterName) { + OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Token Revocation Parameter: " + parameterName, "https://tools.ietf.org/html/rfc7009#section-2.1"); throw new OAuth2AuthenticationException(error); } - private static class TokenRevocationAuthenticationConverter implements - Converter { + private static class DefaultTokenRevocationAuthenticationConverter + implements Converter { @Override public Authentication convert(HttpServletRequest request) { - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + // token (REQUIRED) - String token = parameters.getFirst(TOKEN); + String token = parameters.getFirst(TOKEN_PARAM_NAME); if (!StringUtils.hasText(token) || - parameters.get(TOKEN).size() != 1) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN); + parameters.get(TOKEN_PARAM_NAME).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN_PARAM_NAME); } // token_type_hint (OPTIONAL) - String tokenTypeHint = parameters.getFirst(TOKEN_TYPE_HINT); + String tokenTypeHint = parameters.getFirst(TOKEN_TYPE_HINT_PARAM_NAME); + if (StringUtils.hasText(tokenTypeHint) && + parameters.get(TOKEN_TYPE_HINT_PARAM_NAME).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN_TYPE_HINT_PARAM_NAME); + } return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal, tokenTypeHint); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java new file mode 100644 index 0000000..23e1d79 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java @@ -0,0 +1,188 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.http.HttpHeaders; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Integration tests for the OAuth 2.0 Token Revocation endpoint. + * + * @author Joe Grandja + */ +public class OAuth2TokenRevocationTests { + private static RegisteredClientRepository registeredClientRepository; + private static OAuth2AuthorizationService authorizationService; + private static KeyManager keyManager; + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Autowired + private MockMvc mvc; + + @BeforeClass + public static void init() { + registeredClientRepository = mock(RegisteredClientRepository.class); + authorizationService = mock(OAuth2AuthorizationService.class); + keyManager = new StaticKeyGeneratingKeyManager(); + } + + @Before + public void setup() { + reset(registeredClientRepository); + reset(authorizationService); + } + + @Test + public void requestWhenRevokeRefreshTokenThenRevoked() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + OAuth2RefreshToken token = authorization.getTokens().getRefreshToken(); + TokenType tokenType = TokenType.REFRESH_TOKEN; + when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization); + + this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI) + .params(getTokenRevocationRequestParameters(token, tokenType)) + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret()))) + .andExpect(status().isOk()); + + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); + verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType)); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue(); + OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); + } + + @Test + public void requestWhenRevokeAccessTokenThenRevoked() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + OAuth2AccessToken token = authorization.getTokens().getAccessToken(); + TokenType tokenType = TokenType.ACCESS_TOKEN; + when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization); + + this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI) + .params(getTokenRevocationRequestParameters(token, tokenType)) + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret()))) + .andExpect(status().isOk()); + + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); + verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType)); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); + OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse(); + } + + private static MultiValueMap getTokenRevocationRequestParameters(AbstractOAuth2Token token, TokenType tokenType) { + MultiValueMap parameters = new LinkedMultiValueMap<>(); + // TODO Use OAuth2ParameterNames + parameters.set("token", token.getTokenValue()); + parameters.set("token_type_hint", tokenType.getValue()); + return parameters; + } + + private static String encodeBasicAuth(String clientId, String secret) throws Exception { + clientId = URLEncoder.encode(clientId, StandardCharsets.UTF_8.name()); + secret = URLEncoder.encode(secret, StandardCharsets.UTF_8.name()); + String credentialsString = clientId + ":" + secret; + byte[] encodedBytes = Base64.getEncoder().encode(credentialsString.getBytes(StandardCharsets.UTF_8)); + return new String(encodedBytes, StandardCharsets.UTF_8); + } + + @EnableWebSecurity + @Import(OAuth2AuthorizationServerConfiguration.class) + static class AuthorizationServerConfiguration { + + @Bean + RegisteredClientRepository registeredClientRepository() { + return registeredClientRepository; + } + + @Bean + OAuth2AuthorizationService authorizationService() { + return authorizationService; + } + + @Bean + KeyManager keyManager() { + return keyManager; + } + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 26fa9e9..cd1436c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -23,6 +24,7 @@ import org.springframework.security.oauth2.server.authorization.token.OAuth2Auth import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.Map; @@ -46,6 +48,8 @@ public class TestOAuth2Authorizations { "code", Instant.now(), Instant.now().plusSeconds(120)); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken( + "refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS)); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) @@ -56,7 +60,7 @@ public class TestOAuth2Authorizations { .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) .principalName("principal") - .tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).build()) + .tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build()) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 10d8d7d..c9cd9ce 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -38,6 +38,7 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; @@ -186,9 +187,10 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120)); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() - .tokens(OAuth2Tokens.builder().token(authorizationCode).build()) + .tokens(OAuth2Tokens.builder() + .token(authorizationCode, OAuth2TokenMetadata.builder().invalidated().build()) + .build()) .build(); - authorization.getTokens().invalidate(authorizationCode); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java similarity index 54% rename from core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java rename to oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java index bec949f..ffdb52c 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java @@ -17,12 +17,14 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.OAuth2TokenRevocationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -30,8 +32,10 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,38 +43,27 @@ import static org.mockito.Mockito.when; * Tests for {@link OAuth2TokenRevocationAuthenticationProvider}. * * @author Vivek Babu + * @author Joe Grandja */ public class OAuth2TokenRevocationAuthenticationProviderTests { private RegisteredClient registeredClient; - private OAuth2AuthorizationService oAuth2AuthorizationService; - private OAuth2TokenRevocationService oAuth2TokenRevocationService; + private OAuth2AuthorizationService authorizationService; private OAuth2TokenRevocationAuthenticationProvider authenticationProvider; @Before public void setUp() { this.registeredClient = TestRegisteredClients.registeredClient().build(); - this.oAuth2AuthorizationService = mock(OAuth2AuthorizationService.class); - this.oAuth2TokenRevocationService = mock(OAuth2TokenRevocationService.class); - this.authenticationProvider = new OAuth2TokenRevocationAuthenticationProvider(oAuth2AuthorizationService, - oAuth2TokenRevocationService); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new OAuth2TokenRevocationAuthenticationProvider(this.authorizationService); } @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(null, - oAuth2TokenRevocationService)) + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizationService cannot be null"); } - @Test - public void constructorWhenRevocationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(oAuth2AuthorizationService, - null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("tokenRevocationService cannot be null"); - } - @Test public void supportsWhenTypeOAuth2TokenRevocationAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2TokenRevocationAuthenticationToken.class)).isTrue(); @@ -81,7 +74,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests { TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", clientPrincipal, "access_token"); + "token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue()); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -92,9 +85,9 @@ public class OAuth2TokenRevocationAuthenticationProviderTests { @Test public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( - this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + this.registeredClient.getClientId(), this.registeredClient.getClientSecret(), null); OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", clientPrincipal, "access_token"); + "token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue()); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) @@ -103,48 +96,99 @@ public class OAuth2TokenRevocationAuthenticationProviderTests { } @Test - public void authenticateWhenInvalidTokenThenAuthenticate() { + public void authenticateWhenInvalidTokenTypeThenThrowOAuth2AuthenticationException() { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", clientPrincipal, "access_token"); - OAuth2TokenRevocationAuthenticationToken authenticationResult = - (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId()); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient); - } - - @Test - public void authenticateWhenAuthorizationIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.oAuth2AuthorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))) - .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( - TestRegisteredClients.registeredClient2().build()); - OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", clientPrincipal, "access_token"); + "token", clientPrincipal, "unsupported_token_type"); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + .isEqualTo("unsupported_token_type"); } @Test - public void authenticateWhenValidAccessTokenThenInvalidateTokenAndAuthenticate() { + public void authenticateWhenInvalidTokenThenNotRevoked() { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", clientPrincipal, "access_token"); - OAuth2Authorization mockAuthorization = mock(OAuth2Authorization.class); - when(oAuth2AuthorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))). - thenReturn(mockAuthorization); - when(mockAuthorization.getRegisteredClientId()).thenReturn(this.registeredClient.getClientId()); + "token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue()); OAuth2TokenRevocationAuthenticationToken authenticationResult = (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); - verify(this.oAuth2TokenRevocationService).revoke(eq("token"), eq(TokenType.ACCESS_TOKEN)); + assertThat(authenticationResult.isAuthenticated()).isFalse(); + verify(this.authorizationService, never()).save(any()); + } + @Test + public void authenticateWhenTokenIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + TestRegisteredClients.registeredClient2().build()).build(); + when(this.authorizationService.findByToken( + eq("token"), + eq(TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue()); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenValidRefreshTokenThenRevoked() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + this.registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, TokenType.REFRESH_TOKEN.getValue()); + + OAuth2TokenRevocationAuthenticationToken authenticationResult = + (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId()); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue(); + OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); + } + + @Test + public void authenticateWhenValidAccessTokenThenRevoked() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + this.registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getAccessToken().getTokenValue()), + eq(TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + authorization.getTokens().getAccessToken().getTokenValue(), clientPrincipal, TokenType.ACCESS_TOKEN.getValue()); + + OAuth2TokenRevocationAuthenticationToken authenticationResult = + (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); + OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); + assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse(); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java similarity index 60% rename from core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java rename to oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java index f34cfc0..0fdba31 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java @@ -16,10 +16,13 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.junit.Test; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import java.time.Duration; +import java.time.Instant; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -27,62 +30,64 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * Tests for {@link OAuth2TokenRevocationAuthenticationToken}. * * @author Vivek Babu + * @author Joe Grandja */ public class OAuth2TokenRevocationAuthenticationTokenTests { - private OAuth2TokenRevocationAuthenticationToken clientPrincipal = new OAuth2TokenRevocationAuthenticationToken( - "Token", TestRegisteredClients.registeredClient().build(), null); - private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + private String token = "token"; + private OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + TestRegisteredClients.registeredClient().build()); + private String tokenTypeHint = TokenType.ACCESS_TOKEN.getValue(); + private OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, this.token, + Instant.now(), Instant.now().plus(Duration.ofHours(1))); @Test public void constructorWhenTokenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, - this.clientPrincipal, "hint")) + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, this.clientPrincipal, this.tokenTypeHint)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("token cannot be empty"); } @Test public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken("token", - (Authentication) null, "hint")) + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(this.token, null, this.tokenTypeHint)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientPrincipal cannot be null"); } @Test - public void constructorWhenTokenNullRegisteredClientPresentThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, registeredClient, "hint")) + public void constructorWhenRevokedTokenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, this.clientPrincipal)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("token cannot be empty"); + .hasMessage("revokedToken cannot be null"); } @Test - public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken("token", - (RegisteredClient) null, "hint")) + public void constructorWhenRevokedTokenAndClientPrincipalNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(this.accessToken, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("registeredClient cannot be null"); + .hasMessage("clientPrincipal cannot be null"); } @Test - public void constructorWhenTokenAndClientPrincipalProvidedThenCreated() { + public void constructorWhenTokenProvidedThenCreated() { OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", this.clientPrincipal, "token_hint"); + this.token, this.clientPrincipal, this.tokenTypeHint); + assertThat(authentication.getToken()).isEqualTo(this.token); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(authentication.getTokenTypeHint()).isEqualTo(this.tokenTypeHint); assertThat(authentication.getCredentials().toString()).isEmpty(); - assertThat(authentication.getToken()).isEqualTo("token"); - assertThat(authentication.getTokenTypeHint()).isEqualTo("token_hint"); assertThat(authentication.isAuthenticated()).isFalse(); } @Test - public void constructorWhenTokenAndRegisteredProvidedThenCreated() { + public void constructorWhenRevokedTokenProvidedThenCreated() { OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( - "token", this.registeredClient, "token_hint"); - assertThat(authentication.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + this.accessToken, this.clientPrincipal); + assertThat(authentication.getToken()).isEqualTo(this.accessToken.getTokenValue()); + assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(authentication.getTokenTypeHint()).isNull(); assertThat(authentication.getCredentials().toString()).isEmpty(); - assertThat(authentication.getToken()).isEqualTo("token"); - assertThat(authentication.getTokenTypeHint()).isEqualTo("token_hint"); assertThat(authentication.isAuthenticated()).isTrue(); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java index 533fcc5..a795c03 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java @@ -82,11 +82,18 @@ public class OAuth2TokensTests { @Test public void getTokenWhenTokenTypeNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken(null)) + assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((Class) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("tokenType cannot be null"); } + @Test + public void getTokenWhenTokenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((String) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be empty"); + } + @Test public void getTokenMetadataWhenTokenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> OAuth2Tokens.builder().build().getTokenMetadata(null)) @@ -185,32 +192,4 @@ public class OAuth2TokensTests { this.accessToken.getScopes()); assertThat(tokens.getTokenMetadata(otherAccessToken)).isNull(); } - - @Test - public void invalidateWhenAllTokensThenAllInvalidated() { - OAuth2Tokens tokens = OAuth2Tokens.builder() - .accessToken(this.accessToken) - .refreshToken(this.refreshToken) - .token(this.idToken) - .build(); - tokens.invalidate(); - - assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue(); - assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isTrue(); - assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isTrue(); - } - - @Test - public void invalidateWhenTokenProvidedThenInvalidated() { - OAuth2Tokens tokens = OAuth2Tokens.builder() - .accessToken(this.accessToken) - .refreshToken(this.refreshToken) - .token(this.idToken) - .build(); - tokens.invalidate(this.accessToken); - - assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue(); - assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isFalse(); - assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isFalse(); - } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java similarity index 67% rename from core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java rename to oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java index 101feb8..ee7b11a 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java @@ -18,7 +18,6 @@ package org.springframework.security.oauth2.server.authorization.web; import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.mock.http.client.MockClientHttpResponse; @@ -28,9 +27,11 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -39,6 +40,10 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; @@ -48,15 +53,16 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter.TOKEN_PARAM_NAME; +import static org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter.TOKEN_TYPE_HINT_PARAM_NAME; /** * Tests for {@link OAuth2TokenRevocationEndpointFilter}. * * @author Vivek Babu + * @author Joe Grandja */ public class OAuth2TokenRevocationEndpointFilterTests { - private static final String TOKEN = "token"; - private static final String TOKEN_TYPE_HINT = "token_type_hint"; private AuthenticationManager authenticationManager; private OAuth2TokenRevocationEndpointFilter filter; private final HttpMessageConverter errorHttpResponseConverter = @@ -81,14 +87,14 @@ public class OAuth2TokenRevocationEndpointFilterTests { } @Test - public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() { + public void constructorWhenTokenRevocationEndpointUriNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2TokenRevocationEndpointFilter(this.authenticationManager, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("revocationEndpointUri cannot be empty"); + .hasMessage("tokenRevocationEndpointUri cannot be empty"); } @Test - public void doFilterWhenNotRevocationRequestThenNotProcessed() throws Exception { + public void doFilterWhenNotTokenRevocationRequestThenNotProcessed() throws Exception { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); @@ -101,8 +107,8 @@ public class OAuth2TokenRevocationEndpointFilterTests { } @Test - public void doFilterWhenRevocationRequestGetThenNotProcessed() throws Exception { - String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + public void doFilterWhenTokenRevocationRequestGetThenNotProcessed() throws Exception { + String requestUri = OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -114,54 +120,63 @@ public class OAuth2TokenRevocationEndpointFilterTests { } @Test - public void doFilterWhenRevocationRequestMissingTokenThenInvalidRequestError() throws Exception { - doFilterWhenRevocationRequestInvalidParameterThenError( - TOKEN, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(TOKEN)); + public void doFilterWhenTokenRevocationRequestMissingTokenThenInvalidRequestError() throws Exception { + doFilterWhenTokenRevocationRequestInvalidParameterThenError( + TOKEN_PARAM_NAME, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(TOKEN_PARAM_NAME)); } @Test - public void doFilterWhenRevocationRequestMultipleTokenThenInvalidRequestError() throws Exception { - doFilterWhenRevocationRequestInvalidParameterThenError( - TOKEN, OAuth2ErrorCodes.INVALID_REQUEST, - request -> { - request.addParameter(TOKEN, "token-1"); - request.addParameter(TOKEN, "token-2"); - }); + public void doFilterWhenTokenRevocationRequestMultipleTokenThenInvalidRequestError() throws Exception { + doFilterWhenTokenRevocationRequestInvalidParameterThenError( + TOKEN_PARAM_NAME, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(TOKEN_PARAM_NAME, "token-2")); } @Test - public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception { + public void doFilterWhenTokenRevocationRequestMultipleTokenTypeHintThenInvalidRequestError() throws Exception { + doFilterWhenTokenRevocationRequestInvalidParameterThenError( + TOKEN_TYPE_HINT_PARAM_NAME, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(TOKEN_TYPE_HINT_PARAM_NAME, TokenType.ACCESS_TOKEN.getValue())); + } + + @Test + public void doFilterWhenTokenRevocationRequestValidThenSuccessResponse() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofHours(1)), + new HashSet<>(Arrays.asList("scope1", "scope2"))); + OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication = + new OAuth2TokenRevocationAuthenticationToken( + accessToken, clientPrincipal); - Authentication tokenRevocationAuthenticationSuccess = mock(Authentication.class); - - when(this.authenticationManager.authenticate(any())).thenReturn(tokenRevocationAuthenticationSuccess); + when(this.authenticationManager.authenticate(any())).thenReturn(tokenRevocationAuthentication); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication(clientPrincipal); SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createRevocationRequest(); + MockHttpServletRequest request = createTokenRevocationRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); verifyNoInteractions(filterChain); - - ArgumentCaptor tokenRevocationAuthenticationCaptor = - ArgumentCaptor.forClass(OAuth2TokenRevocationAuthenticationToken.class); - verify(this.authenticationManager).authenticate(tokenRevocationAuthenticationCaptor.capture()); + verify(this.authenticationManager).authenticate(any()); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); } - private void doFilterWhenRevocationRequestInvalidParameterThenError(String parameterName, String errorCode, + private void doFilterWhenTokenRevocationRequestInvalidParameterThenError(String parameterName, String errorCode, Consumer requestConsumer) throws Exception { - MockHttpServletRequest request = createRevocationRequest(); + MockHttpServletRequest request = createTokenRevocationRequest(); requestConsumer.accept(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -173,7 +188,7 @@ public class OAuth2TokenRevocationEndpointFilterTests { assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); OAuth2Error error = readError(response); assertThat(error.getErrorCode()).isEqualTo(errorCode); - assertThat(error.getDescription()).isEqualTo("Token Revocation Request Parameter: " + parameterName); + assertThat(error.getDescription()).isEqualTo("OAuth 2.0 Token Revocation Parameter: " + parameterName); } private OAuth2Error readError(MockHttpServletResponse response) throws Exception { @@ -182,14 +197,13 @@ public class OAuth2TokenRevocationEndpointFilterTests { return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); } - private static MockHttpServletRequest createRevocationRequest() { - + private static MockHttpServletRequest createTokenRevocationRequest() { String requestUri = OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); - request.addParameter(TOKEN, "token"); - request.addParameter(TOKEN_TYPE_HINT, "access_token"); + request.addParameter(TOKEN_PARAM_NAME, "token"); + request.addParameter(TOKEN_TYPE_HINT_PARAM_NAME, TokenType.ACCESS_TOKEN.getValue()); return request; }