From ea828fb2bfba931696113883e6ea684e9fa48f7b Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 10 Jul 2020 09:16:19 -0400 Subject: [PATCH] Polish gh-88 --- .../OAuth2AuthorizationServerConfigurer.java | 21 ++- .../authorization/OAuth2Authorization.java | 1 - ...ientCredentialsAuthenticationProvider.java | 68 +++++--- ...2ClientCredentialsAuthenticationToken.java | 50 ++++-- ...orizationGrantAuthenticationConverter.java | 31 ++-- .../web/OAuth2TokenEndpointFilter.java | 47 +++-- .../OAuth2AuthorizationTests.java | 9 - ...redentialsAuthenticationProviderTests.java | 136 +++++++++------ ...ntCredentialsAuthenticationTokenTests.java | 24 ++- .../client/TestRegisteredClients.java | 5 +- ...tionGrantAuthenticationConverterTests.java | 98 ++++++----- .../OAuth2ClientCredentialsGrantTests.java | 70 +++++--- .../web/OAuth2TokenEndpointFilterTests.java | 165 ++++++++++-------- 13 files changed, 430 insertions(+), 295 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index ef3ac5f..1e01e87 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -19,9 +19,11 @@ import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; +import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; @@ -32,6 +34,7 @@ import org.springframework.security.oauth2.server.authorization.web.OAuth2Author import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; +import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.util.Assert; @@ -89,12 +92,22 @@ public final class OAuth2AuthorizationServerConfigurer exceptionHandling = builder.getConfigurer(ExceptionHandlingConfigurer.class); + if (exceptionHandling != null) { + // Register the default AuthenticationEntryPoint for the token endpoint + exceptionHandling.defaultAuthenticationEntryPointFor( + new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED), + new AntPathRequestMatcher( + OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI, + HttpMethod.POST.name())); + } } @Override diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index f1c0abf..bc05a07 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -210,7 +210,6 @@ public class OAuth2Authorization implements Serializable { */ public OAuth2Authorization build() { Assert.hasText(this.principalName, "principalName cannot be empty"); - Assert.notNull(this.attributes.get(OAuth2AuthorizationAttributeNames.CODE), "authorization code cannot be null"); OAuth2Authorization authorization = new OAuth2Authorization(); authorization.registeredClientId = this.registeredClientId; diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index 4037717..29d0fc2 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -15,11 +15,6 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Base64; -import java.util.Set; - import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -29,6 +24,18 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.LinkedHashSet; +import java.util.Set; +import java.util.stream.Collectors; /** * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Client Credentials Grant. @@ -36,46 +43,63 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; * @author Alexey Nesterov * @since 0.0.1 * @see OAuth2ClientCredentialsAuthenticationToken + * @see OAuth2AccessTokenAuthenticationToken + * @see OAuth2AuthorizationService * @see Section 4.4 Client Credentials Grant * @see Section 4.4.2 Access Token Request */ - public class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider { - + private final OAuth2AuthorizationService authorizationService; private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + /** + * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters. + * + * @param authorizationService the authorization service + */ + public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.authorizationService = authorizationService; + } + @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthenticationToken = + OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication = (OAuth2ClientCredentialsAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientPrincipal = null; - if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientCredentialsAuthenticationToken.getPrincipal().getClass())) { - clientPrincipal = (OAuth2ClientAuthenticationToken) clientCredentialsAuthenticationToken.getPrincipal(); + if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientCredentialsAuthentication.getPrincipal().getClass())) { + clientPrincipal = (OAuth2ClientAuthenticationToken) clientCredentialsAuthentication.getPrincipal(); } - if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } + RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); - Set clientScopes = clientPrincipal.getRegisteredClient().getScopes(); - Set requestedScopes = clientCredentialsAuthenticationToken.getScopes(); - if (!clientScopes.containsAll(requestedScopes)) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE)); - } - - if (requestedScopes == null || requestedScopes.isEmpty()) { - requestedScopes = clientScopes; + Set scopes = registeredClient.getScopes(); // Default to configured scopes + if (!CollectionUtils.isEmpty(clientCredentialsAuthentication.getScopes())) { + Set unauthorizedScopes = clientCredentialsAuthentication.getScopes().stream() + .filter(requestedScope -> !registeredClient.getScopes().contains(requestedScope)) + .collect(Collectors.toSet()); + if (!CollectionUtils.isEmpty(unauthorizedScopes)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE)); + } + scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes()); } String tokenValue = this.accessTokenGenerator.generateKey(); Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token lifespan OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - tokenValue, issuedAt, expiresAt, requestedScopes); + tokenValue, issuedAt, expiresAt, scopes); - return new OAuth2AccessTokenAuthenticationToken( - clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) + .principalName(clientPrincipal.getName()) + .accessToken(accessToken) + .build(); + this.authorizationService.save(authorization); + + return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken); } @Override diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java index 575cb0d..7b27ea2 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java @@ -13,48 +13,52 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.authentication; -import java.util.Collections; -import java.util.Set; - import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.util.Assert; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + /** * An {@link Authentication} implementation used for the OAuth 2.0 Client Credentials Grant. * * @author Alexey Nesterov * @since 0.0.1 - * @see Authentication + * @see AbstractAuthenticationToken * @see OAuth2ClientCredentialsAuthenticationProvider + * @see OAuth2ClientAuthenticationToken */ public class OAuth2ClientCredentialsAuthenticationToken extends AbstractAuthenticationToken { - private static final long serialVersionUID = Version.SERIAL_VERSION_UID; - private final Authentication clientPrincipal; private final Set scopes; + /** + * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters. + * + * @param clientPrincipal the authenticated client principal + */ + public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal) { + this(clientPrincipal, Collections.emptySet()); + } + + /** + * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters. + * + * @param clientPrincipal the authenticated client principal + * @param scopes the requested scope(s) + */ public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal, Set scopes) { super(Collections.emptyList()); Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); Assert.notNull(scopes, "scopes cannot be null"); this.clientPrincipal = clientPrincipal; - this.scopes = scopes; - } - - @SuppressWarnings("unchecked") - public OAuth2ClientCredentialsAuthenticationToken(OAuth2ClientAuthenticationToken clientPrincipal) { - this(clientPrincipal, Collections.EMPTY_SET); - } - - @Override - public Object getCredentials() { - return ""; + this.scopes = Collections.unmodifiableSet(new LinkedHashSet<>(scopes)); } @Override @@ -62,6 +66,16 @@ public class OAuth2ClientCredentialsAuthenticationToken extends AbstractAuthenti return this.clientPrincipal; } + @Override + public Object getCredentials() { + return ""; + } + + /** + * Returns the requested scope(s). + * + * @return the requested scope(s), or an empty {@code Set} if not available + */ public Set getScopes() { return this.scopes; } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverter.java index 30cf0d8..ddd70f3 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverter.java @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.web; -import javax.servlet.http.HttpServletRequest; -import java.util.Collections; -import java.util.Map; - import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -27,31 +22,43 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import javax.servlet.http.HttpServletRequest; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + /** - * A {@link Converter} that delegates actual conversion to one of the provided converters based on grant_type param of a request. - * Returns null is grant type is not specified or not supported. + * A {@link Converter} that selects (and delegates) to one of the internal {@code Map} of {@link Converter}'s + * using the {@link OAuth2ParameterNames#GRANT_TYPE} request parameter. * * @author Alexey Nesterov * @since 0.0.1 */ public final class DelegatingAuthorizationGrantAuthenticationConverter implements Converter { - private final Map> converters; - public DelegatingAuthorizationGrantAuthenticationConverter(Map> converters) { + /** + * Constructs a {@code DelegatingAuthorizationGrantAuthenticationConverter} using the provided parameters. + * + * @param converters a {@code Map} of {@link Converter}(s) + */ + public DelegatingAuthorizationGrantAuthenticationConverter( + Map> converters) { Assert.notEmpty(converters, "converters cannot be empty"); - - this.converters = Collections.unmodifiableMap(converters); + this.converters = Collections.unmodifiableMap(new HashMap<>(converters)); } @Override public Authentication convert(HttpServletRequest request) { + Assert.notNull(request, "request cannot be null"); + String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); if (StringUtils.isEmpty(grantType)) { return null; } - Converter converter = this.converters.get(new AuthorizationGrantType(grantType)); + Converter converter = + this.converters.get(new AuthorizationGrantType(grantType)); if (converter == null) { return null; } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 7356f65..3680d66 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -35,7 +35,6 @@ import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMe import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; -import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -94,7 +93,6 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { private final OAuth2AuthorizationService authorizationService; private final RequestMatcher tokenEndpointMatcher; private final Converter authorizationGrantAuthenticationConverter; - private final HttpMessageConverter accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); private final HttpMessageConverter errorHttpResponseConverter = @@ -126,7 +124,6 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { this.authenticationManager = authenticationManager; this.authorizationService = authorizationService; this.tokenEndpointMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name()); - Map> converters = new HashMap<>(); converters.put(AuthorizationGrantType.AUTHORIZATION_CODE, new AuthorizationCodeAuthenticationConverter()); converters.put(AuthorizationGrantType.CLIENT_CREDENTIALS, new ClientCredentialsAuthenticationConverter()); @@ -144,18 +141,19 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { try { String[] grantTypes = request.getParameterValues(OAuth2ParameterNames.GRANT_TYPE); - if (grantTypes == null || grantTypes.length == 0) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, "grant_type"); + if (grantTypes == null || grantTypes.length != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE); } Authentication authorizationGrantAuthentication = this.authorizationGrantAuthenticationConverter.convert(request); if (authorizationGrantAuthentication == null) { - throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, "grant_type"); + throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); } OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication); sendAccessTokenResponse(response, accessTokenAuthentication.getAccessToken()); + } catch (OAuth2AuthenticationException ex) { SecurityContextHolder.clearContext(); sendErrorResponse(response, ex.getError()); @@ -191,18 +189,14 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { @Override public Authentication convert(HttpServletRequest request) { - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // grant_type (REQUIRED) - String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); - if (!StringUtils.hasText(grantType) || - parameters.get(OAuth2ParameterNames.GRANT_TYPE).size() != 1) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE); - } + String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) { - throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); + return null; } + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + // client_id (REQUIRED) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); Authentication clientPrincipal = null; @@ -239,24 +233,29 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { @Override public Authentication convert(HttpServletRequest request) { - final Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - final OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) authentication; - // grant_type (REQUIRED) String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) { - throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); + return null; } + Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + // scope (OPTIONAL) - // https://tools.ietf.org/html/rfc6749#section-4.4.2 - String scopeParameter = request.getParameter(OAuth2ParameterNames.SCOPE); - if (StringUtils.isEmpty(scopeParameter)) { - return new OAuth2ClientCredentialsAuthenticationToken(clientAuthenticationToken); + String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope) && + parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE); + } + if (StringUtils.hasText(scope)) { + Set requestedScopes = new HashSet<>( + Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); + return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScopes); } - Set requestedScopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scopeParameter, " "))); - return new OAuth2ClientCredentialsAuthenticationToken(clientAuthenticationToken, requestedScopes); + return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); } } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index d3daef0..6d9a041 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -75,15 +75,6 @@ public class OAuth2AuthorizationTests { .hasMessage("principalName cannot be empty"); } - @Test - public void buildWhenAuthorizationCodeNotProvidedThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) - .principalName(PRINCIPAL_NAME).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorization code cannot be null"); - } - @Test public void attributeWhenNameNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index 621d027..34990f4 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -13,104 +13,130 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.authentication; -import java.util.Collections; - import org.junit.Before; import org.junit.Test; - -import org.springframework.security.core.Authentication; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import java.util.Collections; +import java.util.Set; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** + * Tests for {@link OAuth2ClientCredentialsAuthenticationProvider}. + * * @author Alexey Nesterov + * @author Joe Grandja */ public class OAuth2ClientCredentialsAuthenticationProviderTests { - - private static final RegisteredClient EXISTING_CLIENT = TestRegisteredClients.registeredClient().build(); + private RegisteredClient registeredClient; + private OAuth2AuthorizationService authorizationService; private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider; @Before public void setUp() { - this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(); + this.registeredClient = TestRegisteredClients.registeredClient().build(); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService); } @Test - public void supportsWhenSupportedClassThenTrue() { + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void supportsWhenSupportedAuthenticationThenTrue() { assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isTrue(); } @Test - public void supportsWhenUnsupportedClassThenFalse() { - assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationProvider.class)).isFalse(); + public void supportsWhenUnsupportedAuthenticationThenFalse() { + assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isFalse(); } @Test - public void authenticateWhenValidAuthenticationThenReturnTokenWithClient() { - Authentication authentication = this.authenticationProvider.authenticate(getAuthentication()); - assertThat(authentication).isInstanceOf(OAuth2AccessTokenAuthenticationToken.class); - - OAuth2AccessTokenAuthenticationToken token = (OAuth2AccessTokenAuthenticationToken) authentication; - assertThat(token.getRegisteredClient()).isEqualTo(EXISTING_CLIENT); - } - - @Test - public void authenticateWhenValidAuthenticationThenGenerateTokenValue() { - Authentication authentication = this.authenticationProvider.authenticate(getAuthentication()); - OAuth2AccessTokenAuthenticationToken token = (OAuth2AccessTokenAuthenticationToken) authentication; - assertThat(token.getAccessToken().getTokenValue()).isNotBlank(); - } - - @Test - public void authenticateWhenValidateScopeThenReturnTokenWithScopes() { - Authentication authentication = this.authenticationProvider.authenticate(getAuthentication()); - OAuth2AccessTokenAuthenticationToken token = (OAuth2AccessTokenAuthenticationToken) authentication; - assertThat(token.getAccessToken().getScopes()).containsAll(EXISTING_CLIENT.getScopes()); - } - - @Test - public void authenticateWhenNoScopeRequestedThenUseDefaultScopes() { - OAuth2ClientCredentialsAuthenticationToken authenticationToken = new OAuth2ClientCredentialsAuthenticationToken(new OAuth2ClientAuthenticationToken(EXISTING_CLIENT)); - Authentication authentication = this.authenticationProvider.authenticate(authenticationToken); - OAuth2AccessTokenAuthenticationToken token = (OAuth2AccessTokenAuthenticationToken) authentication; - assertThat(token.getAccessToken().getScopes()).containsAll(EXISTING_CLIENT.getScopes()); - } - - @Test - public void authenticateWhenInvalidSecretThenThrowException() { - OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( - new OAuth2ClientAuthenticationToken(EXISTING_CLIENT.getClientId(), "not-a-valid-secret")); + public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class); + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test - public void authenticateWhenNonExistingClientThenThrowException() { - OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( - new OAuth2ClientAuthenticationToken("another-client-id", "another-secret")); + public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class); + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test - public void authenticateWhenInvalidScopesThenThrowException() { + public void authenticateWhenInvalidScopeThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( - new OAuth2ClientAuthenticationToken(EXISTING_CLIENT), Collections.singleton("non-existing-scope")); + clientPrincipal, Collections.singleton("invalid-scope")); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class); + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_SCOPE); } - private OAuth2ClientCredentialsAuthenticationToken getAuthentication() { - return new OAuth2ClientCredentialsAuthenticationToken(new OAuth2ClientAuthenticationToken(EXISTING_CLIENT), EXISTING_CLIENT.getScopes()); + @Test + public void authenticateWhenScopeRequestedThenAccessTokenContainsScope() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + Set requestedScope = Collections.singleton("openid"); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScope); + } + + @Test + public void authenticateWhenValidAuthenticationThenReturnAccessToken() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization authorization = authorizationCaptor.getValue(); + + assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName()); + assertThat(authorization.getAccessToken()).isNotNull(); + assertThat(authorization.getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes()); + assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getAccessToken()); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java index c54e872..1df80ea 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java @@ -13,24 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.authentication; +import org.junit.Test; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + import java.util.Collections; import java.util.Set; -import org.junit.Test; - -import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** + * Tests for {@link OAuth2ClientCredentialsAuthenticationToken}. + * * @author Alexey Nesterov */ public class OAuth2ClientCredentialsAuthenticationTokenTests { - private final OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); @@ -43,15 +42,15 @@ public class OAuth2ClientCredentialsAuthenticationTokenTests { @Test public void constructorWhenScopesNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null)) + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("scopes cannot be null"); } @Test public void constructorWhenClientPrincipalProvidedThenCreated() { - OAuth2ClientCredentialsAuthenticationToken authentication - = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); @@ -62,12 +61,11 @@ public class OAuth2ClientCredentialsAuthenticationTokenTests { public void constructorWhenScopesProvidedThenCreated() { Set expectedScopes = Collections.singleton("test-scope"); - OAuth2ClientCredentialsAuthenticationToken authentication - = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, expectedScopes); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, expectedScopes); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); - assertThat(authentication.getScopes()).containsAll(expectedScopes); + assertThat(authentication.getScopes()).isEqualTo(expectedScopes); } - } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index 5aa7bc0..e6421d4 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -40,10 +40,13 @@ public class TestRegisteredClients { .clientId("client-2") .clientSecret("secret") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .redirectUri("https://example.com") .scope("openid") .scope("profile") - .scope("email"); + .scope("email") + .scope("scope1") + .scope("scope2"); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverterTests.java index b8326f4..dd47ad5 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverterTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DelegatingAuthorizationGrantAuthenticationConverterTests.java @@ -13,84 +13,102 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.web; -import javax.servlet.http.HttpServletRequest; -import java.util.Collections; -import java.util.Map; - import org.junit.Before; import org.junit.Test; - import org.springframework.core.convert.converter.Converter; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import javax.servlet.http.HttpServletRequest; +import java.util.Collections; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** + * Tests for {@link DelegatingAuthorizationGrantAuthenticationConverter}. + * * @author Alexey Nesterov */ public class DelegatingAuthorizationGrantAuthenticationConverterTests { - + private Converter clientCredentialsAuthenticationConverter; private DelegatingAuthorizationGrantAuthenticationConverter authenticationConverter; - private Converter clientCredentialsConverterMock; @Before public void setUp() { - clientCredentialsConverterMock = mock(Converter.class); - Map> converters - = Collections.singletonMap(AuthorizationGrantType.CLIENT_CREDENTIALS, clientCredentialsConverterMock); - authenticationConverter = new DelegatingAuthorizationGrantAuthenticationConverter(converters); + this.clientCredentialsAuthenticationConverter = mock(Converter.class); + Map> converters = + Collections.singletonMap(AuthorizationGrantType.CLIENT_CREDENTIALS, this.clientCredentialsAuthenticationConverter); + this.authenticationConverter = new DelegatingAuthorizationGrantAuthenticationConverter(converters); } @Test - public void convertWhenAuthorizationGrantTypeSupportedThenConverterCalled() { - MockHttpServletRequest request = MockMvcRequestBuilders - .post("/oauth/token") - .param("grant_type", "client_credentials") - .buildRequest(new MockServletContext()); - - OAuth2ClientAuthenticationToken expectedAuthentication = new OAuth2ClientAuthenticationToken("id", "secret"); - when(clientCredentialsConverterMock.convert(request)).thenReturn(expectedAuthentication); - - Authentication actualAuthentication = authenticationConverter.convert(request); - - verify(clientCredentialsConverterMock).convert(request); - assertThat(actualAuthentication).isEqualTo(expectedAuthentication); + public void constructorWhenConvertersEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DelegatingAuthorizationGrantAuthenticationConverter(Collections.emptyMap())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("converters cannot be empty"); } @Test - public void convertWhenAuthorizationGrantTypeNotSupportedThenNull() { - MockHttpServletRequest request = MockMvcRequestBuilders - .post("/oauth/token") - .param("grant_type", "authorization_code") - .buildRequest(new MockServletContext()); - - Authentication actualAuthentication = authenticationConverter.convert(request); - - verifyNoInteractions(clientCredentialsConverterMock); - assertThat(actualAuthentication).isNull(); + public void convertWhenRequestNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationConverter.convert(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("request cannot be null"); } @Test - public void convertWhenNoAuthorizationGrantTypeThenNull() { + public void convertWhenGrantTypeMissingThenNull() { MockHttpServletRequest request = MockMvcRequestBuilders - .post("/oauth/token") + .post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) .buildRequest(new MockServletContext()); - Authentication actualAuthentication = authenticationConverter.convert(request); + Authentication authentication = this.authenticationConverter.convert(request); + assertThat(authentication).isNull(); + verifyNoInteractions(this.clientCredentialsAuthenticationConverter); + } - verifyNoInteractions(clientCredentialsConverterMock); - assertThat(actualAuthentication).isNull(); + @Test + public void convertWhenGrantTypeUnsupportedThenNull() { + MockHttpServletRequest request = MockMvcRequestBuilders + .post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .param(OAuth2ParameterNames.GRANT_TYPE, "extension_grant_type") + .buildRequest(new MockServletContext()); + + Authentication authentication = this.authenticationConverter.convert(request); + assertThat(authentication).isNull(); + verifyNoInteractions(this.clientCredentialsAuthenticationConverter); + } + + @Test + public void convertWhenGrantTypeSupportedThenConverterCalled() { + OAuth2ClientCredentialsAuthenticationToken expectedAuthentication = + new OAuth2ClientCredentialsAuthenticationToken( + new OAuth2ClientAuthenticationToken( + TestRegisteredClients.registeredClient().build())); + when(this.clientCredentialsAuthenticationConverter.convert(any())).thenReturn(expectedAuthentication); + + MockHttpServletRequest request = MockMvcRequestBuilders + .post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .buildRequest(new MockServletContext()); + + Authentication authentication = this.authenticationConverter.convert(request); + assertThat(authentication).isEqualTo(expectedAuthentication); + verify(this.clientCredentialsAuthenticationConverter).convert(request); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientCredentialsGrantTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientCredentialsGrantTests.java index 0249599..3cf938f 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientCredentialsGrantTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientCredentialsGrantTests.java @@ -13,42 +13,49 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.web; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.http.HttpHeaders; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.oauth2.server.authorization.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.test.web.servlet.MockMvc; -import static org.hamcrest.CoreMatchers.endsWith; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** + * Integration tests for the OAuth 2.0 Client Credentials Grant. + * * @author Alexey Nesterov */ public class OAuth2ClientCredentialsGrantTests { - private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; @@ -71,35 +78,46 @@ public class OAuth2ClientCredentialsGrantTests { } @Test - public void requestWhenTokenRequestAuthenticatedThenThenReturnTokenAndScope() throws Exception { + public void requestWhenTokenRequestNotAuthenticatedThenUnauthorized() throws Exception { this.spring.register(AuthorizationServerConfiguration.class).autowire(); - RegisteredClient client = TestRegisteredClients.registeredClient().build(); - when(registeredClientRepository.findByClientId(client.getClientId())) - .thenReturn(client); this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) - .with(httpBasic(client.getClientId(), client.getClientSecret())) - .with(csrf()) - .param("grant_type", "client_credentials") - .param("scope", "email openid")) - .andExpect(status().isOk()) - .andExpect(jsonPath("$.access_token").isNotEmpty()) - .andExpect(jsonPath("$.scope").value("openid email")); + .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .with(csrf())) + .andExpect(status().isUnauthorized()); + + verifyNoInteractions(registeredClientRepository); + verifyNoInteractions(authorizationService); } @Test - public void requestWhenTokenRequestNotAuthenticatedThenRedirect() throws Exception { + public void requestWhenTokenRequestValidThenTokenResponse() throws Exception { this.spring.register(AuthorizationServerConfiguration.class).autowire(); - RegisteredClient client = TestRegisteredClients.registeredClient().build(); - when(registeredClientRepository.findByClientId(client.getClientId())) - .thenReturn(client); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) - .with(csrf()) - .param("grant_type", "client_credentials") - .param("scope", "email openid")) - .andExpect(status().isFound()) - .andExpect(header().string("Location", endsWith("/login"))); + .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .param(OAuth2ParameterNames.SCOPE, "scope1 scope2") + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret())) + .with(csrf())) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.scope").value("scope1 scope2")); + + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); + verify(authorizationService).save(any()); + } + + private static String encodeBasicAuth(String clientId, String secret) throws Exception { + clientId = URLEncoder.encode(clientId, StandardCharsets.UTF_8.name()); + secret = URLEncoder.encode(secret, StandardCharsets.UTF_8.name()); + String credentialsString = clientId + ":" + secret; + byte[] encodedBytes = Base64.getEncoder().encode(credentialsString.getBytes(StandardCharsets.UTF_8)); + return new String(encodedBytes, StandardCharsets.UTF_8); } @EnableWebSecurity diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 7c9c440..0ecd147 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -19,7 +19,6 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; - import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.mock.http.client.MockClientHttpResponse; @@ -44,17 +43,15 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.util.StringUtils; import javax.servlet.FilterChain; -import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.IOException; import java.time.Duration; import java.time.Instant; import java.util.Arrays; import java.util.HashSet; -import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -140,58 +137,77 @@ public class OAuth2TokenEndpointFilterTests { @Test public void doFilterWhenTokenRequestMissingGrantTypeThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.removeParameter(OAuth2ParameterNames.GRANT_TYPE); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(OAuth2ParameterNames.GRANT_TYPE)); + OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test public void doFilterWhenTokenRequestMultipleGrantTypeThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue())); + OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test public void doFilterWhenTokenRequestInvalidGrantTypeThenUnsupportedGrantTypeError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type"); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, - request -> request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type")); + OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, request); } @Test public void doFilterWhenTokenRequestMultipleClientIdThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1"); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, - request -> { - request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1"); - request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"); - }); + OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.removeParameter(OAuth2ParameterNames.CODE); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(OAuth2ParameterNames.CODE)); + OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.CODE, "code-2"); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.CODE, "code-2")); + OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createAuthorizationCodeTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); + doFilterWhenTokenRequestInvalidParameterThenError( - OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com")); + OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test - public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception { + public void doFilterWhenAuthorizationCodeTokenRequestValidThenAccessTokenResponse() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AccessToken accessToken = new OAuth2AccessToken( @@ -208,7 +224,7 @@ public class OAuth2TokenEndpointFilterTests { securityContext.setAuthentication(clientPrincipal); SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createTokenRequest(registeredClient); + MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -242,38 +258,24 @@ public class OAuth2TokenEndpointFilterTests { } @Test - public void doFilterWhenGrantTypeIsClientCredentialsThenAuthenticateWithClientCredentialsToken() throws ServletException, IOException { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - doFilterForClientCredentialsGrant(registeredClient, null); + public void doFilterWhenTokenRequestMultipleScopeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - ArgumentCaptor captor = ArgumentCaptor.forClass(Authentication.class); - verify(this.authenticationManager).authenticate(captor.capture()); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); - assertThat(captor.getValue()).isInstanceOf(OAuth2ClientCredentialsAuthenticationToken.class); - OAuth2ClientCredentialsAuthenticationToken clientAuthenticationToken = (OAuth2ClientCredentialsAuthenticationToken) captor.getValue(); + MockHttpServletRequest request = createClientCredentialsTokenRequest(registeredClient); + request.addParameter(OAuth2ParameterNames.SCOPE, "profile"); - assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(new OAuth2ClientAuthenticationToken(registeredClient)); + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.SCOPE, OAuth2ErrorCodes.INVALID_REQUEST, request); } @Test - public void doFilterWhenGrantTypeIsClientCredentialsWithScopeThenIncludeScopeInResponse() throws ServletException, IOException { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - doFilterForClientCredentialsGrant(registeredClient, "openid email"); - - ArgumentCaptor captor = ArgumentCaptor.forClass(Authentication.class); - verify(this.authenticationManager).authenticate(captor.capture()); - - assertThat(captor.getValue()).isInstanceOf(OAuth2ClientCredentialsAuthenticationToken.class); - OAuth2ClientCredentialsAuthenticationToken clientAuthenticationToken = (OAuth2ClientCredentialsAuthenticationToken) captor.getValue(); - - HashSet expectedScopes = new HashSet<>(); - expectedScopes.add("openid"); - expectedScopes.add("email"); - - assertThat(clientAuthenticationToken.getScopes()).isEqualTo(expectedScopes); - } - - private void doFilterForClientCredentialsGrant(RegisteredClient registeredClient, String scope) throws ServletException, IOException { + public void doFilterWhenClientCredentialsTokenRequestValidThenAccessTokenResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "token", @@ -282,35 +284,46 @@ public class OAuth2TokenEndpointFilterTests { OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = new OAuth2AccessTokenAuthenticationToken( registeredClient, clientPrincipal, accessToken); - final String clientId = registeredClient.getClientId(); - final String clientSecret = registeredClient.getClientSecret(); - - MockHttpServletRequest request = new MockHttpServletRequest("POST", OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI); - request.setServletPath(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI); - request.addParameter("client_id", clientId); - request.addParameter("client_secret", clientSecret); - request.addParameter("grant_type", AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); - if (scope != null) { - request.addParameter("scope", scope); - } when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication); - SecurityContext context = SecurityContextHolder.createEmptyContext(); - context.setAuthentication(new OAuth2ClientAuthenticationToken(registeredClient)); - SecurityContextHolder.setContext(context); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + MockHttpServletRequest request = createClientCredentialsTokenRequest(registeredClient); MockHttpServletResponse response = new MockHttpServletResponse(); - filter.doFilter(request, response, mock(FilterChain.class)); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + ArgumentCaptor clientCredentialsAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2ClientCredentialsAuthenticationToken.class); + verify(this.authenticationManager).authenticate(clientCredentialsAuthenticationCaptor.capture()); + + OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication = + clientCredentialsAuthenticationCaptor.getValue(); + assertThat(clientCredentialsAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(clientCredentialsAuthentication.getScopes()).isEqualTo(registeredClient.getScopes()); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); + + OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken(); + assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType()); + assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue()); + assertThat(accessTokenResult.getIssuedAt()).isBetween( + accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1)); + assertThat(accessTokenResult.getExpiresAt()).isBetween( + accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1)); + assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); } private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode, - Consumer requestConsumer) throws Exception { + MockHttpServletRequest request) throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - - MockHttpServletRequest request = createTokenRequest(registeredClient); - requestConsumer.accept(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -336,7 +349,7 @@ public class OAuth2TokenEndpointFilterTests { return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); } - private static MockHttpServletRequest createTokenRequest(RegisteredClient registeredClient) { + private static MockHttpServletRequest createAuthorizationCodeTokenRequest(RegisteredClient registeredClient) { String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; @@ -349,4 +362,16 @@ public class OAuth2TokenEndpointFilterTests { return request; } + + private static MockHttpServletRequest createClientCredentialsTokenRequest(RegisteredClient registeredClient) { + String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); + request.addParameter(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + + return request; + } }