Propagate additional token request parameters

Closes gh-226
This commit is contained in:
Joe Grandja 2021-02-11 09:50:04 -05:00
parent b5d47366ad
commit 7652d0ebbe
9 changed files with 115 additions and 94 deletions

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
@ -57,7 +58,7 @@ public class OAuth2AuthorizationGrantAuthenticationToken extends AbstractAuthent
this.clientPrincipal = clientPrincipal; this.clientPrincipal = clientPrincipal;
this.additionalParameters = Collections.unmodifiableMap( this.additionalParameters = Collections.unmodifiableMap(
additionalParameters != null ? additionalParameters != null ?
additionalParameters : new HashMap<>(additionalParameters) :
Collections.emptyMap()); Collections.emptyMap());
} }

View File

@ -16,12 +16,13 @@
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashSet; import java.util.HashSet;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
/** /**
* An {@link Authentication} implementation used for the OAuth 2.0 Client Credentials Grant. * An {@link Authentication} implementation used for the OAuth 2.0 Client Credentials Grant.
@ -34,25 +35,18 @@ import org.springframework.util.Assert;
public class OAuth2ClientCredentialsAuthenticationToken extends OAuth2AuthorizationGrantAuthenticationToken { public class OAuth2ClientCredentialsAuthenticationToken extends OAuth2AuthorizationGrantAuthenticationToken {
private final Set<String> scopes; private final Set<String> scopes;
/**
* Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters.
*
* @param clientPrincipal the authenticated client principal
*/
public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal) {
this(clientPrincipal, Collections.emptySet());
}
/** /**
* Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters. * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters.
* *
* @param clientPrincipal the authenticated client principal * @param clientPrincipal the authenticated client principal
* @param scopes the requested scope(s) * @param scopes the requested scope(s)
* @param additionalParameters the additional parameters
*/ */
public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal, Set<String> scopes) { public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal,
super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientPrincipal, null); @Nullable Set<String> scopes, @Nullable Map<String, Object> additionalParameters) {
Assert.notNull(scopes, "scopes cannot be null"); super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientPrincipal, additionalParameters);
this.scopes = Collections.unmodifiableSet(new LinkedHashSet<>(scopes)); this.scopes = Collections.unmodifiableSet(
scopes != null ? new HashSet<>(scopes) : Collections.emptySet());
} }
/** /**

View File

@ -16,8 +16,11 @@
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -34,30 +37,21 @@ public class OAuth2RefreshTokenAuthenticationToken extends OAuth2AuthorizationGr
private final String refreshToken; private final String refreshToken;
private final Set<String> scopes; private final Set<String> scopes;
/**
* Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters.
*
* @param refreshToken the refresh token
* @param clientPrincipal the authenticated client principal
*/
public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal) {
this(refreshToken, clientPrincipal, Collections.emptySet());
}
/** /**
* Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters. * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters.
* *
* @param refreshToken the refresh token * @param refreshToken the refresh token
* @param clientPrincipal the authenticated client principal * @param clientPrincipal the authenticated client principal
* @param scopes the requested scope(s) * @param scopes the requested scope(s)
* @param additionalParameters the additional parameters
*/ */
public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal, public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal,
Set<String> scopes) { @Nullable Set<String> scopes, @Nullable Map<String, Object> additionalParameters) {
super(AuthorizationGrantType.REFRESH_TOKEN, clientPrincipal, null); super(AuthorizationGrantType.REFRESH_TOKEN, clientPrincipal, additionalParameters);
Assert.hasText(refreshToken, "refreshToken cannot be empty"); Assert.hasText(refreshToken, "refreshToken cannot be empty");
Assert.notNull(scopes, "scopes cannot be null");
this.refreshToken = refreshToken; this.refreshToken = refreshToken;
this.scopes = scopes; this.scopes = Collections.unmodifiableSet(
scopes != null ? new HashSet<>(scopes) : Collections.emptySet());
} }
/** /**

View File

@ -229,6 +229,7 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
} }
// @formatter:off
Map<String, Object> additionalParameters = parameters Map<String, Object> additionalParameters = parameters
.entrySet() .entrySet()
.stream() .stream()
@ -237,8 +238,10 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
!e.getKey().equals(OAuth2ParameterNames.CODE) && !e.getKey().equals(OAuth2ParameterNames.CODE) &&
!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI)) !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI))
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0))); .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
// @formatter:on
return new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters); return new OAuth2AuthorizationCodeAuthenticationToken(
code, clientPrincipal, redirectUri, additionalParameters);
} }
} }
@ -269,13 +272,24 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) { parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE); throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE);
} }
Set<String> requestedScopes = null;
if (StringUtils.hasText(scope)) { if (StringUtils.hasText(scope)) {
Set<String> requestedScopes = new HashSet<>( requestedScopes = new HashSet<>(
Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal, requestedScopes);
} }
return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal); // @formatter:off
Map<String, Object> additionalParameters = parameters
.entrySet()
.stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
!e.getKey().equals(OAuth2ParameterNames.REFRESH_TOKEN) &&
!e.getKey().equals(OAuth2ParameterNames.SCOPE))
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
// @formatter:on
return new OAuth2RefreshTokenAuthenticationToken(
refreshToken, clientPrincipal, requestedScopes, additionalParameters);
} }
} }
@ -299,13 +313,23 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) { parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE); throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE);
} }
Set<String> requestedScopes = null;
if (StringUtils.hasText(scope)) { if (StringUtils.hasText(scope)) {
Set<String> requestedScopes = new HashSet<>( requestedScopes = new HashSet<>(
Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScopes);
} }
return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); // @formatter:off
Map<String, Object> additionalParameters = parameters
.entrySet()
.stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
!e.getKey().equals(OAuth2ParameterNames.SCOPE))
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
// @formatter:on
return new OAuth2ClientCredentialsAuthenticationToken(
clientPrincipal, requestedScopes, additionalParameters);
} }
} }
} }

View File

@ -108,7 +108,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret()); registeredClient.getClientId(), registeredClient.getClientSecret());
OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); OAuth2ClientCredentialsAuthenticationToken authentication =
new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -122,7 +123,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null); registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); OAuth2ClientCredentialsAuthenticationToken authentication =
new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -137,7 +139,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
.authorizationGrantTypes(grantTypes -> grantTypes.remove(AuthorizationGrantType.CLIENT_CREDENTIALS)) .authorizationGrantTypes(grantTypes -> grantTypes.remove(AuthorizationGrantType.CLIENT_CREDENTIALS))
.build(); .build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); OAuth2ClientCredentialsAuthenticationToken authentication =
new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -151,7 +154,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(
clientPrincipal, Collections.singleton("invalid-scope")); clientPrincipal, Collections.singleton("invalid-scope"), null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -166,7 +169,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
Set<String> requestedScope = Collections.singleton("scope1"); Set<String> requestedScope = Collections.singleton("scope1");
OAuth2ClientCredentialsAuthenticationToken authentication = OAuth2ClientCredentialsAuthenticationToken authentication =
new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope, null);
when(this.jwtEncoder.encode(any(), any())) when(this.jwtEncoder.encode(any(), any()))
.thenReturn(createJwt(Collections.singleton("mapped-scoped"))); .thenReturn(createJwt(Collections.singleton("mapped-scoped")));
@ -180,7 +183,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
public void authenticateWhenValidAuthenticationThenReturnAccessToken() { public void authenticateWhenValidAuthenticationThenReturnAccessToken() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); OAuth2ClientCredentialsAuthenticationToken authentication =
new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes())); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes()));

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.junit.Test; import org.junit.Test;
@ -34,42 +35,39 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
public class OAuth2ClientCredentialsAuthenticationTokenTests { public class OAuth2ClientCredentialsAuthenticationTokenTests {
private final OAuth2ClientAuthenticationToken clientPrincipal = private final OAuth2ClientAuthenticationToken clientPrincipal =
new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
private Set<String> scopes = Collections.singleton("scope1");
private Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
@Test @Test
public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(null)) assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(null, this.scopes, this.additionalParameters))
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientPrincipal cannot be null"); .hasMessage("clientPrincipal cannot be null");
} }
@Test
public void constructorWhenScopesNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("scopes cannot be null");
}
@Test @Test
public void constructorWhenClientPrincipalProvidedThenCreated() { public void constructorWhenClientPrincipalProvidedThenCreated() {
OAuth2ClientCredentialsAuthenticationToken authentication = OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(
new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal); this.clientPrincipal, this.scopes, this.additionalParameters);
assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getScopes()).isEmpty(); assertThat(authentication.getScopes()).isEqualTo(this.scopes);
assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters);
} }
@Test @Test
public void constructorWhenScopesProvidedThenCreated() { public void constructorWhenScopesProvidedThenCreated() {
Set<String> expectedScopes = Collections.singleton("test-scope"); Set<String> expectedScopes = Collections.singleton("test-scope");
OAuth2ClientCredentialsAuthenticationToken authentication = OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(
new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, expectedScopes); this.clientPrincipal, expectedScopes, this.additionalParameters);
assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getScopes()).isEqualTo(expectedScopes); assertThat(authentication.getScopes()).isEqualTo(expectedScopes);
assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters);
} }
} }

View File

@ -124,7 +124,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -169,7 +169,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -199,7 +199,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
Set<String> requestedScopes = new HashSet<>(authorizedScopes); Set<String> requestedScopes = new HashSet<>(authorizedScopes);
requestedScopes.remove("scope1"); requestedScopes.remove("scope1");
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes, null);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -221,7 +221,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
Set<String> requestedScopes = new HashSet<>(authorizedScopes); Set<String> requestedScopes = new HashSet<>(authorizedScopes);
requestedScopes.add("unauthorized"); requestedScopes.add("unauthorized");
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -235,7 +235,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
"invalid", clientPrincipal); "invalid", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -250,7 +250,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret()); registeredClient.getClientId(), registeredClient.getClientSecret());
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
"refresh-token", clientPrincipal); "refresh-token", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -265,7 +265,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null); registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
"refresh-token", clientPrincipal); "refresh-token", clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -286,7 +286,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
TestRegisteredClients.registeredClient2().build()); TestRegisteredClients.registeredClient2().build());
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -308,7 +308,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -331,7 +331,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -355,7 +355,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)

View File

@ -15,8 +15,8 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Arrays; import java.util.Collections;
import java.util.HashSet; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.junit.Test; import org.junit.Test;
@ -34,42 +34,37 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
* @since 0.0.3 * @since 0.0.3
*/ */
public class OAuth2RefreshTokenAuthenticationTokenTests { public class OAuth2RefreshTokenAuthenticationTokenTests {
private final OAuth2ClientAuthenticationToken clientPrincipal = private OAuth2ClientAuthenticationToken clientPrincipal =
new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
private Set<String> scopes = Collections.singleton("scope1");
private Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
@Test @Test
public void constructorWhenRefreshTokenNullOrEmptyThenThrowIllegalArgumentException() { public void constructorWhenRefreshTokenNullOrEmptyThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal)) assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal, this.scopes, this.additionalParameters))
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("refreshToken cannot be empty"); .hasMessage("refreshToken cannot be empty");
assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal)) assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal, this.scopes, this.additionalParameters))
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("refreshToken cannot be empty"); .hasMessage("refreshToken cannot be empty");
} }
@Test @Test
public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null)) assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null, this.scopes, this.additionalParameters))
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientPrincipal cannot be null"); .hasMessage("clientPrincipal cannot be null");
} }
@Test
public void constructorWhenScopesNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", this.clientPrincipal, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("scopes cannot be null");
}
@Test @Test
public void constructorWhenScopesProvidedThenCreated() { public void constructorWhenScopesProvidedThenCreated() {
Set<String> expectedScopes = new HashSet<>(Arrays.asList("scope-a", "scope-b"));
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
"refresh-token", this.clientPrincipal, expectedScopes); "refresh-token", this.clientPrincipal, this.scopes, this.additionalParameters);
assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
assertThat(authentication.getRefreshToken()).isEqualTo("refresh-token"); assertThat(authentication.getRefreshToken()).isEqualTo("refresh-token");
assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
assertThat(authentication.getCredentials().toString()).isEmpty(); assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getScopes()).isEqualTo(expectedScopes); assertThat(authentication.getScopes()).isEqualTo(this.scopes);
assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,10 +15,22 @@
*/ */
package org.springframework.security.oauth2.server.authorization.web; package org.springframework.security.oauth2.server.authorization.web;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.http.client.MockClientHttpResponse;
@ -47,16 +59,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.Assertions.entry;
@ -232,6 +234,8 @@ public class OAuth2TokenEndpointFilterTests {
assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo( assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo(
request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); request.getParameter(OAuth2ParameterNames.REDIRECT_URI));
assertThat(authorizationCodeAuthentication.getAdditionalParameters())
.containsExactly(entry("custom-param-1", "custom-value-1"));
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
@ -292,6 +296,8 @@ public class OAuth2TokenEndpointFilterTests {
clientCredentialsAuthenticationCaptor.getValue(); clientCredentialsAuthenticationCaptor.getValue();
assertThat(clientCredentialsAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(clientCredentialsAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(clientCredentialsAuthentication.getScopes()).isEqualTo(registeredClient.getScopes()); assertThat(clientCredentialsAuthentication.getScopes()).isEqualTo(registeredClient.getScopes());
assertThat(clientCredentialsAuthentication.getAdditionalParameters())
.containsExactly(entry("custom-param-1", "custom-value-1"));
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
@ -372,6 +378,8 @@ public class OAuth2TokenEndpointFilterTests {
assertThat(refreshTokenAuthenticationToken.getRefreshToken()).isEqualTo(refreshToken.getTokenValue()); assertThat(refreshTokenAuthenticationToken.getRefreshToken()).isEqualTo(refreshToken.getTokenValue());
assertThat(refreshTokenAuthenticationToken.getPrincipal()).isEqualTo(clientPrincipal); assertThat(refreshTokenAuthenticationToken.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(refreshTokenAuthenticationToken.getScopes()).isEqualTo(registeredClient.getScopes()); assertThat(refreshTokenAuthenticationToken.getScopes()).isEqualTo(registeredClient.getScopes());
assertThat(refreshTokenAuthenticationToken.getAdditionalParameters())
.containsExactly(entry("custom-param-1", "custom-value-1"));
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
@ -429,6 +437,7 @@ public class OAuth2TokenEndpointFilterTests {
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
// The client does not need to send the client ID param, but we are resilient in case they do // The client does not need to send the client ID param, but we are resilient in case they do
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
request.addParameter("custom-param-1", "custom-value-1");
return request; return request;
} }
@ -441,6 +450,7 @@ public class OAuth2TokenEndpointFilterTests {
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
request.addParameter(OAuth2ParameterNames.SCOPE, request.addParameter(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
request.addParameter("custom-param-1", "custom-value-1");
return request; return request;
} }
@ -454,6 +464,7 @@ public class OAuth2TokenEndpointFilterTests {
request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, "refresh-token"); request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, "refresh-token");
request.addParameter(OAuth2ParameterNames.SCOPE, request.addParameter(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
request.addParameter("custom-param-1", "custom-value-1");
return request; return request;
} }