From 7652d0ebbe2205aeb9b851b7a41580c1b36c9a00 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 11 Feb 2021 09:50:04 -0500 Subject: [PATCH] Propagate additional token request parameters Closes gh-226 --- ...AuthorizationGrantAuthenticationToken.java | 3 +- ...2ClientCredentialsAuthenticationToken.java | 24 +++++------- ...OAuth2RefreshTokenAuthenticationToken.java | 22 ++++------- .../web/OAuth2TokenEndpointFilter.java | 38 +++++++++++++++---- ...redentialsAuthenticationProviderTests.java | 16 +++++--- ...ntCredentialsAuthenticationTokenTests.java | 24 ++++++------ ...freshTokenAuthenticationProviderTests.java | 22 +++++------ ...2RefreshTokenAuthenticationTokenTests.java | 27 ++++++------- .../web/OAuth2TokenEndpointFilterTests.java | 33 ++++++++++------ 9 files changed, 115 insertions(+), 94 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationGrantAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationGrantAuthenticationToken.java index 9ba7758..d527705 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationGrantAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationGrantAuthenticationToken.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.springframework.lang.Nullable; @@ -57,7 +58,7 @@ public class OAuth2AuthorizationGrantAuthenticationToken extends AbstractAuthent this.clientPrincipal = clientPrincipal; this.additionalParameters = Collections.unmodifiableMap( additionalParameters != null ? - additionalParameters : + new HashMap<>(additionalParameters) : Collections.emptyMap()); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java index 27d9091..8384948 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java @@ -16,12 +16,13 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.util.Collections; -import java.util.LinkedHashSet; +import java.util.HashSet; +import java.util.Map; import java.util.Set; +import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.util.Assert; /** * An {@link Authentication} implementation used for the OAuth 2.0 Client Credentials Grant. @@ -34,25 +35,18 @@ import org.springframework.util.Assert; public class OAuth2ClientCredentialsAuthenticationToken extends OAuth2AuthorizationGrantAuthenticationToken { private final Set scopes; - /** - * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters. - * - * @param clientPrincipal the authenticated client principal - */ - public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal) { - this(clientPrincipal, Collections.emptySet()); - } - /** * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters. * * @param clientPrincipal the authenticated client principal * @param scopes the requested scope(s) + * @param additionalParameters the additional parameters */ - public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal, Set scopes) { - super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientPrincipal, null); - Assert.notNull(scopes, "scopes cannot be null"); - this.scopes = Collections.unmodifiableSet(new LinkedHashSet<>(scopes)); + public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal, + @Nullable Set scopes, @Nullable Map additionalParameters) { + super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientPrincipal, additionalParameters); + this.scopes = Collections.unmodifiableSet( + scopes != null ? new HashSet<>(scopes) : Collections.emptySet()); } /** diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java index 9e20426..26ce4ad 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java @@ -16,8 +16,11 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.util.Collections; +import java.util.HashSet; +import java.util.Map; import java.util.Set; +import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; @@ -34,30 +37,21 @@ public class OAuth2RefreshTokenAuthenticationToken extends OAuth2AuthorizationGr private final String refreshToken; private final Set scopes; - /** - * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters. - * - * @param refreshToken the refresh token - * @param clientPrincipal the authenticated client principal - */ - public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal) { - this(refreshToken, clientPrincipal, Collections.emptySet()); - } - /** * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters. * * @param refreshToken the refresh token * @param clientPrincipal the authenticated client principal * @param scopes the requested scope(s) + * @param additionalParameters the additional parameters */ public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal, - Set scopes) { - super(AuthorizationGrantType.REFRESH_TOKEN, clientPrincipal, null); + @Nullable Set scopes, @Nullable Map additionalParameters) { + super(AuthorizationGrantType.REFRESH_TOKEN, clientPrincipal, additionalParameters); Assert.hasText(refreshToken, "refreshToken cannot be empty"); - Assert.notNull(scopes, "scopes cannot be null"); this.refreshToken = refreshToken; - this.scopes = scopes; + this.scopes = Collections.unmodifiableSet( + scopes != null ? new HashSet<>(scopes) : Collections.emptySet()); } /** diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 6766a03..de05fdc 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -229,6 +229,7 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); } + // @formatter:off Map additionalParameters = parameters .entrySet() .stream() @@ -237,8 +238,10 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { !e.getKey().equals(OAuth2ParameterNames.CODE) && !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI)) .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0))); + // @formatter:on - return new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters); + return new OAuth2AuthorizationCodeAuthenticationToken( + code, clientPrincipal, redirectUri, additionalParameters); } } @@ -269,13 +272,24 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE); } + Set requestedScopes = null; if (StringUtils.hasText(scope)) { - Set requestedScopes = new HashSet<>( + requestedScopes = new HashSet<>( Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); - return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal, requestedScopes); } - return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal); + // @formatter:off + Map additionalParameters = parameters + .entrySet() + .stream() + .filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) && + !e.getKey().equals(OAuth2ParameterNames.REFRESH_TOKEN) && + !e.getKey().equals(OAuth2ParameterNames.SCOPE)) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0))); + // @formatter:on + + return new OAuth2RefreshTokenAuthenticationToken( + refreshToken, clientPrincipal, requestedScopes, additionalParameters); } } @@ -299,13 +313,23 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE); } + Set requestedScopes = null; if (StringUtils.hasText(scope)) { - Set requestedScopes = new HashSet<>( + requestedScopes = new HashSet<>( Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); - return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScopes); } - return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + // @formatter:off + Map additionalParameters = parameters + .entrySet() + .stream() + .filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) && + !e.getKey().equals(OAuth2ParameterNames.SCOPE)) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0))); + // @formatter:on + + return new OAuth2ClientCredentialsAuthenticationToken( + clientPrincipal, requestedScopes, additionalParameters); } } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index 22306fe..f85b279 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -108,7 +108,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret()); - OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -122,7 +123,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null); - OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -137,7 +139,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { .authorizationGrantTypes(grantTypes -> grantTypes.remove(AuthorizationGrantType.CLIENT_CREDENTIALS)) .build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -151,7 +154,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( - clientPrincipal, Collections.singleton("invalid-scope")); + clientPrincipal, Collections.singleton("invalid-scope"), null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -166,7 +169,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); Set requestedScope = Collections.singleton("scope1"); OAuth2ClientCredentialsAuthenticationToken authentication = - new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope, null); when(this.jwtEncoder.encode(any(), any())) .thenReturn(createJwt(Collections.singleton("mapped-scoped"))); @@ -180,7 +183,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { public void authenticateWhenValidAuthenticationThenReturnAccessToken() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes())); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java index 73c6151..78d960a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.util.Collections; +import java.util.Map; import java.util.Set; import org.junit.Test; @@ -34,42 +35,39 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; public class OAuth2ClientCredentialsAuthenticationTokenTests { private final OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); + private Set scopes = Collections.singleton("scope1"); + private Map additionalParameters = Collections.singletonMap("param1", "value1"); @Test public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(null)) + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(null, this.scopes, this.additionalParameters)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientPrincipal cannot be null"); } - @Test - public void constructorWhenScopesNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("scopes cannot be null"); - } - @Test public void constructorWhenClientPrincipalProvidedThenCreated() { - OAuth2ClientCredentialsAuthenticationToken authentication = - new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal); + OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( + this.clientPrincipal, this.scopes, this.additionalParameters); assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); - assertThat(authentication.getScopes()).isEmpty(); + assertThat(authentication.getScopes()).isEqualTo(this.scopes); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters); } @Test public void constructorWhenScopesProvidedThenCreated() { Set expectedScopes = Collections.singleton("test-scope"); - OAuth2ClientCredentialsAuthenticationToken authentication = - new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, expectedScopes); + OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( + this.clientPrincipal, expectedScopes, this.additionalParameters); assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getScopes()).isEqualTo(expectedScopes); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index f3b51bc..c12e2e1 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -124,7 +124,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -169,7 +169,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -199,7 +199,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { Set requestedScopes = new HashSet<>(authorizedScopes); requestedScopes.remove("scope1"); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes, null); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -221,7 +221,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { Set requestedScopes = new HashSet<>(authorizedScopes); requestedScopes.add("unauthorized"); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -235,7 +235,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - "invalid", clientPrincipal); + "invalid", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -250,7 +250,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret()); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - "refresh-token", clientPrincipal); + "refresh-token", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -265,7 +265,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - "refresh-token", clientPrincipal); + "refresh-token", clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -286,7 +286,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( TestRegisteredClients.registeredClient2().build()); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -308,7 +308,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -331,7 +331,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -355,7 +355,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java index ef9603b..8cd5c4f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java @@ -15,8 +15,8 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; -import java.util.Arrays; -import java.util.HashSet; +import java.util.Collections; +import java.util.Map; import java.util.Set; import org.junit.Test; @@ -34,42 +34,37 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @since 0.0.3 */ public class OAuth2RefreshTokenAuthenticationTokenTests { - private final OAuth2ClientAuthenticationToken clientPrincipal = + private OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); + private Set scopes = Collections.singleton("scope1"); + private Map additionalParameters = Collections.singletonMap("param1", "value1"); @Test public void constructorWhenRefreshTokenNullOrEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal)) + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal, this.scopes, this.additionalParameters)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("refreshToken cannot be empty"); - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal)) + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal, this.scopes, this.additionalParameters)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("refreshToken cannot be empty"); } @Test public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null)) + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null, this.scopes, this.additionalParameters)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientPrincipal cannot be null"); } - @Test - public void constructorWhenScopesNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", this.clientPrincipal, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("scopes cannot be null"); - } - @Test public void constructorWhenScopesProvidedThenCreated() { - Set expectedScopes = new HashSet<>(Arrays.asList("scope-a", "scope-b")); OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( - "refresh-token", this.clientPrincipal, expectedScopes); + "refresh-token", this.clientPrincipal, this.scopes, this.additionalParameters); assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); assertThat(authentication.getRefreshToken()).isEqualTo("refresh-token"); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getCredentials().toString()).isEmpty(); - assertThat(authentication.getScopes()).isEqualTo(expectedScopes); + assertThat(authentication.getScopes()).isEqualTo(this.scopes); + assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index d2a97d5..5049b11 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,22 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.mock.http.client.MockClientHttpResponse; @@ -47,16 +59,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.util.StringUtils; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.entry; @@ -232,6 +234,8 @@ public class OAuth2TokenEndpointFilterTests { assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo( request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); + assertThat(authorizationCodeAuthentication.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); @@ -292,6 +296,8 @@ public class OAuth2TokenEndpointFilterTests { clientCredentialsAuthenticationCaptor.getValue(); assertThat(clientCredentialsAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(clientCredentialsAuthentication.getScopes()).isEqualTo(registeredClient.getScopes()); + assertThat(clientCredentialsAuthentication.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); @@ -372,6 +378,8 @@ public class OAuth2TokenEndpointFilterTests { assertThat(refreshTokenAuthenticationToken.getRefreshToken()).isEqualTo(refreshToken.getTokenValue()); assertThat(refreshTokenAuthenticationToken.getPrincipal()).isEqualTo(clientPrincipal); assertThat(refreshTokenAuthenticationToken.getScopes()).isEqualTo(registeredClient.getScopes()); + assertThat(refreshTokenAuthenticationToken.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); @@ -429,6 +437,7 @@ public class OAuth2TokenEndpointFilterTests { request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); // The client does not need to send the client ID param, but we are resilient in case they do request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + request.addParameter("custom-param-1", "custom-value-1"); return request; } @@ -441,6 +450,7 @@ public class OAuth2TokenEndpointFilterTests { request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); request.addParameter(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + request.addParameter("custom-param-1", "custom-value-1"); return request; } @@ -454,6 +464,7 @@ public class OAuth2TokenEndpointFilterTests { request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, "refresh-token"); request.addParameter(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + request.addParameter("custom-param-1", "custom-value-1"); return request; }