Add OAuth2Authorization.authorizationGrantType

Issue gh-213
This commit is contained in:
Joe Grandja 2021-02-08 14:57:15 -05:00
parent 41541912e6
commit 7261b40cd5
8 changed files with 60 additions and 1 deletions

View File

@ -24,6 +24,7 @@ import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
@ -39,6 +40,7 @@ import org.springframework.util.Assert;
* @author Krisztian Toth
* @since 0.0.1
* @see RegisteredClient
* @see AuthorizationGrantType
* @see AbstractOAuth2Token
* @see OAuth2AccessToken
* @see OAuth2RefreshToken
@ -47,6 +49,7 @@ public class OAuth2Authorization implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private String registeredClientId;
private String principalName;
private AuthorizationGrantType authorizationGrantType;
private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens;
private Map<String, Object> attributes;
@ -71,6 +74,15 @@ public class OAuth2Authorization implements Serializable {
return this.principalName;
}
/**
* Returns the {@link AuthorizationGrantType authorization grant type} used for the authorization.
*
* @return the {@link AuthorizationGrantType} used for the authorization
*/
public AuthorizationGrantType getAuthorizationGrantType() {
return this.authorizationGrantType;
}
/**
* Returns the {@link Token} of type {@link OAuth2AccessToken}.
*
@ -157,13 +169,15 @@ public class OAuth2Authorization implements Serializable {
OAuth2Authorization that = (OAuth2Authorization) obj;
return Objects.equals(this.registeredClientId, that.registeredClientId) &&
Objects.equals(this.principalName, that.principalName) &&
Objects.equals(this.authorizationGrantType, that.authorizationGrantType) &&
Objects.equals(this.tokens, that.tokens) &&
Objects.equals(this.attributes, that.attributes);
}
@Override
public int hashCode() {
return Objects.hash(this.registeredClientId, this.principalName, this.tokens, this.attributes);
return Objects.hash(this.registeredClientId, this.principalName,
this.authorizationGrantType, this.tokens, this.attributes);
}
/**
@ -187,6 +201,7 @@ public class OAuth2Authorization implements Serializable {
Assert.notNull(authorization, "authorization cannot be null");
return new Builder(authorization.getRegisteredClientId())
.principalName(authorization.getPrincipalName())
.authorizationGrantType(authorization.getAuthorizationGrantType())
.tokens(authorization.tokens)
.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
}
@ -292,6 +307,7 @@ public class OAuth2Authorization implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private final String registeredClientId;
private String principalName;
private AuthorizationGrantType authorizationGrantType;
private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens = new HashMap<>();
private final Map<String, Object> attributes = new HashMap<>();
@ -310,6 +326,17 @@ public class OAuth2Authorization implements Serializable {
return this;
}
/**
* Sets the {@link AuthorizationGrantType authorization grant type} used for the authorization.
*
* @param authorizationGrantType the {@link AuthorizationGrantType}
* @return the {@link Builder}
*/
public Builder authorizationGrantType(AuthorizationGrantType authorizationGrantType) {
this.authorizationGrantType = authorizationGrantType;
return this;
}
/**
* Sets the {@link OAuth2AccessToken access token}.
*
@ -401,10 +428,12 @@ public class OAuth2Authorization implements Serializable {
*/
public OAuth2Authorization build() {
Assert.hasText(this.principalName, "principalName cannot be empty");
Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null");
OAuth2Authorization authorization = new OAuth2Authorization();
authorization.registeredClientId = this.registeredClientId;
authorization.principalName = this.principalName;
authorization.authorizationGrantType = this.authorizationGrantType;
authorization.tokens = Collections.unmodifiableMap(this.tokens);
authorization.attributes = Collections.unmodifiableMap(this.attributes);
return authorization;

View File

@ -124,6 +124,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName(clientPrincipal.getName())
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.token(accessToken)
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
.build();

View File

@ -193,6 +193,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest();
OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName(principal.getName())
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, principal)
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);

View File

@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@ -39,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
public class InMemoryOAuth2AuthorizationServiceTests {
private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
private static final String PRINCIPAL_NAME = "principal";
private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE;
private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
"code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES));
private InMemoryOAuth2AuthorizationService authorizationService;
@ -59,6 +61,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void saveWhenAuthorizationProvidedThenSaved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.build();
this.authorizationService.save(expectedAuthorization);
@ -79,6 +82,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void removeWhenAuthorizationProvidedThenRemoved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.build();
@ -105,6 +109,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
String state = "state";
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.attribute(OAuth2AuthorizationAttributeNames.STATE, state)
.build();
this.authorizationService.save(authorization);
@ -120,6 +125,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void findByTokenWhenAuthorizationCodeExistsThenFound() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.build();
this.authorizationService.save(authorization);
@ -137,6 +143,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
"access-token", Instant.now().minusSeconds(60), Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.accessToken(accessToken)
.build();
@ -154,6 +161,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.refreshToken(refreshToken)
.build();
this.authorizationService.save(authorization);

View File

@ -20,6 +20,7 @@ import java.time.temporal.ChronoUnit;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@ -38,6 +39,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
public class OAuth2AuthorizationTests {
private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
private static final String PRINCIPAL_NAME = "principal";
private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE;
private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
private static final OAuth2RefreshToken REFRESH_TOKEN = new OAuth2RefreshToken("refresh-token", Instant.now());
@ -62,6 +64,7 @@ public class OAuth2AuthorizationTests {
public void fromWhenAuthorizationProvidedThenCopied() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.accessToken(ACCESS_TOKEN)
.build();
@ -69,6 +72,7 @@ public class OAuth2AuthorizationTests {
assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
assertThat(authorizationResult.getAuthorizationGrantType()).isEqualTo(authorization.getAuthorizationGrantType());
assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
assertThat(authorizationResult.getToken(OAuth2AuthorizationCode.class))
.isEqualTo(authorization.getToken(OAuth2AuthorizationCode.class));
@ -82,6 +86,13 @@ public class OAuth2AuthorizationTests {
.hasMessage("principalName cannot be empty");
}
@Test
public void buildWhenAuthorizationGrantTypeNotProvidedThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).principalName(PRINCIPAL_NAME).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationGrantType cannot be null");
}
@Test
public void attributeWhenNameNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() ->
@ -104,6 +115,7 @@ public class OAuth2AuthorizationTests {
public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.accessToken(ACCESS_TOKEN)
.refreshToken(REFRESH_TOKEN)
@ -111,6 +123,7 @@ public class OAuth2AuthorizationTests {
assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AUTHORIZATION_GRANT_TYPE);
assertThat(authorization.getToken(OAuth2AuthorizationCode.class).getToken()).isEqualTo(AUTHORIZATION_CODE);
assertThat(authorization.getAccessToken().getToken()).isEqualTo(ACCESS_TOKEN);
assertThat(authorization.getRefreshToken().getToken()).isEqualTo(REFRESH_TOKEN);

View File

@ -21,6 +21,7 @@ import java.util.Collections;
import java.util.Map;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
@ -61,6 +62,7 @@ public class TestOAuth2Authorizations {
.build();
return OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName("principal")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.token(authorizationCode)
.accessToken(accessToken)
.refreshToken(refreshToken)

View File

@ -204,6 +204,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId());
assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName());
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
assertThat(authorization.getAccessToken()).isNotNull();
assertThat(authorization.getAccessToken().getToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);

View File

@ -467,6 +467,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication);
@ -516,6 +517,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication);
@ -563,6 +565,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication);
@ -795,6 +798,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull();
assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))