diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java new file mode 100644 index 0000000..1bf6808 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java @@ -0,0 +1,48 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.springframework.util.Assert; + +/** + * An {@link OAuth2TokenRevocationService} that revokes tokens. + * + * @author Vivek Babu + * @see OAuth2AuthorizationService + * @since 0.0.1 + */ +public final class DefaultOAuth2TokenRevocationService implements OAuth2TokenRevocationService { + + private OAuth2AuthorizationService authorizationService; + + /** + * Constructs an {@code DefaultOAuth2TokenRevocationService}. + */ + public DefaultOAuth2TokenRevocationService(OAuth2AuthorizationService authorizationService) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.authorizationService = authorizationService; + } + + @Override + public void revoke(String token, TokenType tokenType) { + final OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(token, tokenType); + if (authorization != null) { + final OAuth2Authorization revokedAuthorization = OAuth2Authorization.from(authorization) + .revoked(true).build(); + this.authorizationService.save(revokedAuthorization); + } + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java new file mode 100644 index 0000000..7ad02a4 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java @@ -0,0 +1,34 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +/** + * Implementations of this interface are responsible for the revocation of + * OAuth2 tokens. + * + * @author Vivek Babu + * @since 0.0.1 + */ +public interface OAuth2TokenRevocationService { + + /** + * Revokes the given token. + * + * @param token the token to be revoked + * @param tokenType the type of token to be revoked + */ + void revoke(String token, TokenType tokenType); +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java new file mode 100644 index 0000000..8e0cb75 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java @@ -0,0 +1,99 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenRevocationService; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Token Revocation. + * + * @author Vivek Babu + * @since 0.0.1 + * @see OAuth2TokenRevocationAuthenticationToken + * @see OAuth2AuthorizationService + * @see OAuth2TokenRevocationService + * @see Section 2.1 Revocation Request + */ +public class OAuth2TokenRevocationAuthenticationProvider implements AuthenticationProvider { + + private OAuth2AuthorizationService authorizationService; + private OAuth2TokenRevocationService tokenRevocationService; + + /** + * Constructs an {@code OAuth2TokenRevocationAuthenticationProvider} using the provided parameters. + * + * @param authorizationService the authorization service + * @param tokenRevocationService the token revocation service + */ + public OAuth2TokenRevocationAuthenticationProvider(OAuth2AuthorizationService authorizationService, + OAuth2TokenRevocationService tokenRevocationService) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(tokenRevocationService, "tokenRevocationService cannot be null"); + this.authorizationService = authorizationService; + this.tokenRevocationService = tokenRevocationService; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthenticationToken = + (OAuth2TokenRevocationAuthenticationToken) authentication; + + OAuth2ClientAuthenticationToken clientPrincipal = null; + if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(tokenRevocationAuthenticationToken.getPrincipal() + .getClass())) { + clientPrincipal = (OAuth2ClientAuthenticationToken) tokenRevocationAuthenticationToken.getPrincipal(); + } + if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + final RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); + final String tokenTypeHint = tokenRevocationAuthenticationToken.getTokenTypeHint(); + final String token = tokenRevocationAuthenticationToken.getToken(); + final OAuth2Authorization authorization = authorizationService.findByTokenAndTokenType(token, + TokenType.ACCESS_TOKEN); + + OAuth2TokenRevocationAuthenticationToken successfulAuthentication = + new OAuth2TokenRevocationAuthenticationToken(token, registeredClient, tokenTypeHint); + + if (authorization == null) { + return successfulAuthentication; + } + + if (!registeredClient.getClientId().equals(authorization.getRegisteredClientId())) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + tokenRevocationService.revoke(token, TokenType.ACCESS_TOKEN); + return successfulAuthentication; + } + + @Override + public boolean supports(Class authentication) { + return OAuth2TokenRevocationAuthenticationToken.class.isAssignableFrom(authentication); + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java new file mode 100644 index 0000000..d42bbf8 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java @@ -0,0 +1,102 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.springframework.lang.Nullable; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.Version; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + +import java.util.Collections; + +/** + * An {@link Authentication} implementation used for OAuth 2.0 Client Authentication. + * + * @author Vivek Babu + * @since 0.0.1 + * @see AbstractAuthenticationToken + * @see RegisteredClient + * @see OAuth2TokenRevocationAuthenticationProvider + */ +public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final String tokenTypeHint; + private Authentication clientPrincipal; + private String token; + private RegisteredClient registeredClient; + + public OAuth2TokenRevocationAuthenticationToken(String token, + Authentication clientPrincipal, @Nullable String tokenTypeHint) { + super(Collections.emptyList()); + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); + Assert.hasText(token, "token cannot be empty"); + this.token = token; + this.clientPrincipal = clientPrincipal; + this.tokenTypeHint = tokenTypeHint; + } + + public OAuth2TokenRevocationAuthenticationToken(String token, + RegisteredClient registeredClient, @Nullable String tokenTypeHint) { + super(Collections.emptyList()); + Assert.notNull(registeredClient, "registeredClient cannot be null"); + Assert.hasText(token, "token cannot be empty"); + this.token = token; + this.registeredClient = registeredClient; + this.tokenTypeHint = tokenTypeHint; + setAuthenticated(true); + } + + @Override + public Object getPrincipal() { + return this.clientPrincipal != null ? this.clientPrincipal : this.registeredClient + .getClientId(); + } + + @Override + public Object getCredentials() { + return ""; + } + + /** + * Returns the token. + * + * @return the token + */ + public String getToken() { + return this.token; + } + + /** + * Returns the token type hint. + * + * @return the token type hint + */ + public String getTokenTypeHint() { + return tokenTypeHint; + } + + /** + * Returns the {@link RegisteredClient registered client}. + * + * @return the {@link RegisteredClient} + */ + public @Nullable + RegisteredClient getRegisteredClient() { + return this.registeredClient; + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java new file mode 100644 index 0000000..a8ba267 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java @@ -0,0 +1,149 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.filter.OncePerRequestFilter; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + +/** + * A {@code Filter} for the OAuth 2.0 Token Revocation, + * which handles the processing of the OAuth 2.0 Token Revocation Request. + * + * @author Vivek Babu + * @see OAuth2AuthorizationService + * @see OAuth2Authorization + * @see Section 2 Token Revocation + * @see Section 2.1 Revocation Request + * @since 0.0.1 + */ +public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter { + + /** + * The default endpoint {@code URI} for token revocation request. + */ + public static final String DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI = "/oauth2/revoke"; + private static final String TOKEN_TYPE_HINT = "token_type_hint"; + private static final String TOKEN = "token"; + private final AntPathRequestMatcher revocationEndpointMatcher; + + private final Converter tokenRevocationAuthenticationConverter = + new OAuth2TokenRevocationEndpointFilter.TokenRevocationAuthenticationConverter(); + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); + private final AuthenticationManager authenticationManager; + + /** + * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters. + * + * @param authenticationManager the authentication manager + */ + public OAuth2TokenRevocationEndpointFilter(AuthenticationManager authenticationManager) { + this(authenticationManager, DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI); + } + + /** + * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters. + * + * @param authenticationManager the authentication manager + * @param revocationEndpointUri the endpoint {@code URI} for revocation requests + */ + public OAuth2TokenRevocationEndpointFilter(AuthenticationManager authenticationManager, + String revocationEndpointUri) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + Assert.hasText(revocationEndpointUri, "revocationEndpointUri cannot be empty"); + this.authenticationManager = authenticationManager; + this.revocationEndpointMatcher = new AntPathRequestMatcher( + revocationEndpointUri, HttpMethod.POST.name()); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (!this.revocationEndpointMatcher.matches(request)) { + filterChain.doFilter(request, response); + return; + } + + try { + Authentication tokenRevocationRequestAuthentication = + this.tokenRevocationAuthenticationConverter.convert(request); + this.authenticationManager.authenticate(tokenRevocationRequestAuthentication); + } catch (OAuth2AuthenticationException ex) { + SecurityContextHolder.clearContext(); + sendErrorResponse(response, ex.getError()); + } + } + + private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + this.errorHttpResponseConverter.write(error, null, httpResponse); + } + + private static OAuth2AuthenticationException throwError(String errorCode, String parameterName) { + OAuth2Error error = new OAuth2Error(errorCode, "Token Revocation Request Parameter: " + parameterName, + "https://tools.ietf.org/html/rfc7009#section-2.1"); + throw new OAuth2AuthenticationException(error); + } + + private static class TokenRevocationAuthenticationConverter implements + Converter { + + @Override + public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + + Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + + // token (REQUIRED) + String token = parameters.getFirst(TOKEN); + if (!StringUtils.hasText(token) || + parameters.get(TOKEN).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN); + } + + // token_type_hint (OPTIONAL) + String tokenTypeHint = parameters.getFirst(TOKEN_TYPE_HINT); + + return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal, tokenTypeHint); + } + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java new file mode 100644 index 0000000..bd1fb77 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultOAuth2TokenRevocationService}. + * + * @author Vivek Babu + */ +public class DefaultOAuth2TokenRevocationServiceTests { + private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); + private static final String PRINCIPAL_NAME = "principal"; + private static final String AUTHORIZATION_CODE = "code"; + private DefaultOAuth2TokenRevocationService revocationService; + private OAuth2AuthorizationService authorizationService; + + @Before + public void setup() { + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.revocationService = new DefaultOAuth2TokenRevocationService(authorizationService); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2TokenRevocationService(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void revokeWhenTokenNotFound() { + this.revocationService.revoke("token", TokenType.ACCESS_TOKEN); + verify(authorizationService, times(1)).findByTokenAndTokenType(eq("token"), + eq(TokenType.ACCESS_TOKEN)); + verify(authorizationService, times(0)).save(any()); + } + + @Test + public void revokeWhenTokenFound() { + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "token", Instant.now().minusSeconds(60), Instant.now()); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .principalName(PRINCIPAL_NAME) + .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .accessToken(accessToken) + .build(); + when(authorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + this.revocationService.revoke("token", TokenType.ACCESS_TOKEN); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + final OAuth2Authorization savedAuthorization = authorizationCaptor.getValue(); + assertThat(savedAuthorization.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); + assertThat((String) savedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)) + .isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); + assertThat(savedAuthorization.getAccessToken()).isEqualTo(authorization.getAccessToken()); + assertThat(savedAuthorization.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); + assertThat(savedAuthorization.isRevoked()).isTrue(); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java new file mode 100644 index 0000000..bec949f --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenRevocationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2TokenRevocationAuthenticationProvider}. + * + * @author Vivek Babu + */ +public class OAuth2TokenRevocationAuthenticationProviderTests { + private RegisteredClient registeredClient; + private OAuth2AuthorizationService oAuth2AuthorizationService; + private OAuth2TokenRevocationService oAuth2TokenRevocationService; + private OAuth2TokenRevocationAuthenticationProvider authenticationProvider; + + @Before + public void setUp() { + this.registeredClient = TestRegisteredClients.registeredClient().build(); + this.oAuth2AuthorizationService = mock(OAuth2AuthorizationService.class); + this.oAuth2TokenRevocationService = mock(OAuth2TokenRevocationService.class); + this.authenticationProvider = new OAuth2TokenRevocationAuthenticationProvider(oAuth2AuthorizationService, + oAuth2TokenRevocationService); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(null, + oAuth2TokenRevocationService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void constructorWhenRevocationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(oAuth2AuthorizationService, + null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokenRevocationService cannot be null"); + } + + @Test + public void supportsWhenTypeOAuth2TokenRevocationAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2TokenRevocationAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", clientPrincipal, "access_token"); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", clientPrincipal, "access_token"); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenInvalidTokenThenAuthenticate() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", clientPrincipal, "access_token"); + OAuth2TokenRevocationAuthenticationToken authenticationResult = + (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId()); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient); + } + + @Test + public void authenticateWhenAuthorizationIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.oAuth2AuthorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + TestRegisteredClients.registeredClient2().build()); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", clientPrincipal, "access_token"); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenValidAccessTokenThenInvalidateTokenAndAuthenticate() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", clientPrincipal, "access_token"); + OAuth2Authorization mockAuthorization = mock(OAuth2Authorization.class); + when(oAuth2AuthorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))). + thenReturn(mockAuthorization); + when(mockAuthorization.getRegisteredClientId()).thenReturn(this.registeredClient.getClientId()); + OAuth2TokenRevocationAuthenticationToken authenticationResult = + (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + verify(this.oAuth2TokenRevocationService).revoke(eq("token"), eq(TokenType.ACCESS_TOKEN)); + + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId()); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java new file mode 100644 index 0000000..f34cfc0 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Test; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2TokenRevocationAuthenticationToken}. + * + * @author Vivek Babu + */ +public class OAuth2TokenRevocationAuthenticationTokenTests { + private OAuth2TokenRevocationAuthenticationToken clientPrincipal = new OAuth2TokenRevocationAuthenticationToken( + "Token", TestRegisteredClients.registeredClient().build(), null); + private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + + @Test + public void constructorWhenTokenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, + this.clientPrincipal, "hint")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be empty"); + } + + @Test + public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken("token", + (Authentication) null, "hint")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientPrincipal cannot be null"); + } + + @Test + public void constructorWhenTokenNullRegisteredClientPresentThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, registeredClient, "hint")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be empty"); + } + + @Test + public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken("token", + (RegisteredClient) null, "hint")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClient cannot be null"); + } + + @Test + public void constructorWhenTokenAndClientPrincipalProvidedThenCreated() { + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", this.clientPrincipal, "token_hint"); + assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getToken()).isEqualTo("token"); + assertThat(authentication.getTokenTypeHint()).isEqualTo("token_hint"); + assertThat(authentication.isAuthenticated()).isFalse(); + } + + @Test + public void constructorWhenTokenAndRegisteredProvidedThenCreated() { + OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( + "token", this.registeredClient, "token_hint"); + assertThat(authentication.getPrincipal()).isEqualTo(this.registeredClient.getClientId()); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getToken()).isEqualTo("token"); + assertThat(authentication.getTokenTypeHint()).isEqualTo("token_hint"); + assertThat(authentication.isAuthenticated()).isTrue(); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java new file mode 100644 index 0000000..101feb8 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java @@ -0,0 +1,196 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2TokenRevocationEndpointFilter}. + * + * @author Vivek Babu + */ +public class OAuth2TokenRevocationEndpointFilterTests { + private static final String TOKEN = "token"; + private static final String TOKEN_TYPE_HINT = "token_type_hint"; + private AuthenticationManager authenticationManager; + private OAuth2TokenRevocationEndpointFilter filter; + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); + + @Before + public void setUp() { + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = new OAuth2TokenRevocationEndpointFilter(this.authenticationManager); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationEndpointFilter(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationManager cannot be null"); + } + + @Test + public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenRevocationEndpointFilter(this.authenticationManager, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("revocationEndpointUri cannot be empty"); + } + + @Test + public void doFilterWhenNotRevocationRequestThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenRevocationRequestGetThenNotProcessed() throws Exception { + String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenRevocationRequestMissingTokenThenInvalidRequestError() throws Exception { + doFilterWhenRevocationRequestInvalidParameterThenError( + TOKEN, OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(TOKEN)); + } + + @Test + public void doFilterWhenRevocationRequestMultipleTokenThenInvalidRequestError() throws Exception { + doFilterWhenRevocationRequestInvalidParameterThenError( + TOKEN, OAuth2ErrorCodes.INVALID_REQUEST, + request -> { + request.addParameter(TOKEN, "token-1"); + request.addParameter(TOKEN, "token-2"); + }); + } + + @Test + public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + + Authentication tokenRevocationAuthenticationSuccess = mock(Authentication.class); + + when(this.authenticationManager.authenticate(any())).thenReturn(tokenRevocationAuthenticationSuccess); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletRequest request = createRevocationRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + ArgumentCaptor tokenRevocationAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2TokenRevocationAuthenticationToken.class); + verify(this.authenticationManager).authenticate(tokenRevocationAuthenticationCaptor.capture()); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + } + + private void doFilterWhenRevocationRequestInvalidParameterThenError(String parameterName, String errorCode, + Consumer requestConsumer) throws Exception { + + MockHttpServletRequest request = createRevocationRequest(); + requestConsumer.accept(request); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(errorCode); + assertThat(error.getDescription()).isEqualTo("Token Revocation Request Parameter: " + parameterName); + } + + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); + } + + private static MockHttpServletRequest createRevocationRequest() { + + String requestUri = OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(TOKEN, "token"); + request.addParameter(TOKEN_TYPE_HINT, "access_token"); + + return request; + } +}