diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index a23ac92..5bf8f23 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -24,6 +24,7 @@ import java.util.function.Consumer; import org.springframework.lang.Nullable; import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; @@ -39,6 +40,7 @@ import org.springframework.util.Assert; * @author Krisztian Toth * @since 0.0.1 * @see RegisteredClient + * @see AuthorizationGrantType * @see AbstractOAuth2Token * @see OAuth2AccessToken * @see OAuth2RefreshToken @@ -47,6 +49,7 @@ public class OAuth2Authorization implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String registeredClientId; private String principalName; + private AuthorizationGrantType authorizationGrantType; private Map, Token> tokens; private Map attributes; @@ -71,6 +74,15 @@ public class OAuth2Authorization implements Serializable { return this.principalName; } + /** + * Returns the {@link AuthorizationGrantType authorization grant type} used for the authorization. + * + * @return the {@link AuthorizationGrantType} used for the authorization + */ + public AuthorizationGrantType getAuthorizationGrantType() { + return this.authorizationGrantType; + } + /** * Returns the {@link Token} of type {@link OAuth2AccessToken}. * @@ -157,13 +169,15 @@ public class OAuth2Authorization implements Serializable { OAuth2Authorization that = (OAuth2Authorization) obj; return Objects.equals(this.registeredClientId, that.registeredClientId) && Objects.equals(this.principalName, that.principalName) && + Objects.equals(this.authorizationGrantType, that.authorizationGrantType) && Objects.equals(this.tokens, that.tokens) && Objects.equals(this.attributes, that.attributes); } @Override public int hashCode() { - return Objects.hash(this.registeredClientId, this.principalName, this.tokens, this.attributes); + return Objects.hash(this.registeredClientId, this.principalName, + this.authorizationGrantType, this.tokens, this.attributes); } /** @@ -187,6 +201,7 @@ public class OAuth2Authorization implements Serializable { Assert.notNull(authorization, "authorization cannot be null"); return new Builder(authorization.getRegisteredClientId()) .principalName(authorization.getPrincipalName()) + .authorizationGrantType(authorization.getAuthorizationGrantType()) .tokens(authorization.tokens) .attributes(attrs -> attrs.putAll(authorization.getAttributes())); } @@ -292,6 +307,7 @@ public class OAuth2Authorization implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private final String registeredClientId; private String principalName; + private AuthorizationGrantType authorizationGrantType; private Map, Token> tokens = new HashMap<>(); private final Map attributes = new HashMap<>(); @@ -310,6 +326,17 @@ public class OAuth2Authorization implements Serializable { return this; } + /** + * Sets the {@link AuthorizationGrantType authorization grant type} used for the authorization. + * + * @param authorizationGrantType the {@link AuthorizationGrantType} + * @return the {@link Builder} + */ + public Builder authorizationGrantType(AuthorizationGrantType authorizationGrantType) { + this.authorizationGrantType = authorizationGrantType; + return this; + } + /** * Sets the {@link OAuth2AccessToken access token}. * @@ -401,10 +428,12 @@ public class OAuth2Authorization implements Serializable { */ public OAuth2Authorization build() { Assert.hasText(this.principalName, "principalName cannot be empty"); + Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null"); OAuth2Authorization authorization = new OAuth2Authorization(); authorization.registeredClientId = this.registeredClientId; authorization.principalName = this.principalName; + authorization.authorizationGrantType = this.authorizationGrantType; authorization.tokens = Collections.unmodifiableMap(this.tokens); authorization.attributes = Collections.unmodifiableMap(this.attributes); return authorization; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index a242d8a..f0d3a3a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -124,6 +124,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(clientPrincipal.getName()) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .token(accessToken) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .build(); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 8823da8..c3a5a2e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -193,6 +193,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest(); OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(principal.getName()) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, principal) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index c1acb60..b263fd7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit; import org.junit.Before; import org.junit.Test; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -39,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; public class InMemoryOAuth2AuthorizationServiceTests { private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; + private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE; private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode( "code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES)); private InMemoryOAuth2AuthorizationService authorizationService; @@ -59,6 +61,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void saveWhenAuthorizationProvidedThenSaved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) .build(); this.authorizationService.save(expectedAuthorization); @@ -79,6 +82,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void removeWhenAuthorizationProvidedThenRemoved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) .build(); @@ -105,6 +109,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { String state = "state"; OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .attribute(OAuth2AuthorizationAttributeNames.STATE, state) .build(); this.authorizationService.save(authorization); @@ -120,6 +125,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void findByTokenWhenAuthorizationCodeExistsThenFound() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) .build(); this.authorizationService.save(authorization); @@ -137,6 +143,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { "access-token", Instant.now().minusSeconds(60), Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) .accessToken(accessToken) .build(); @@ -154,6 +161,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .refreshToken(refreshToken) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index 34359f5..079f8f4 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -20,6 +20,7 @@ import java.time.temporal.ChronoUnit; import org.junit.Test; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -38,6 +39,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; public class OAuth2AuthorizationTests { private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; + private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE; private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); private static final OAuth2RefreshToken REFRESH_TOKEN = new OAuth2RefreshToken("refresh-token", Instant.now()); @@ -62,6 +64,7 @@ public class OAuth2AuthorizationTests { public void fromWhenAuthorizationProvidedThenCopied() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) .accessToken(ACCESS_TOKEN) .build(); @@ -69,6 +72,7 @@ public class OAuth2AuthorizationTests { assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); + assertThat(authorizationResult.getAuthorizationGrantType()).isEqualTo(authorization.getAuthorizationGrantType()); assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken()); assertThat(authorizationResult.getToken(OAuth2AuthorizationCode.class)) .isEqualTo(authorization.getToken(OAuth2AuthorizationCode.class)); @@ -82,6 +86,13 @@ public class OAuth2AuthorizationTests { .hasMessage("principalName cannot be empty"); } + @Test + public void buildWhenAuthorizationGrantTypeNotProvidedThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).principalName(PRINCIPAL_NAME).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationGrantType cannot be null"); + } + @Test public void attributeWhenNameNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> @@ -104,6 +115,7 @@ public class OAuth2AuthorizationTests { public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) .accessToken(ACCESS_TOKEN) .refreshToken(REFRESH_TOKEN) @@ -111,6 +123,7 @@ public class OAuth2AuthorizationTests { assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AUTHORIZATION_GRANT_TYPE); assertThat(authorization.getToken(OAuth2AuthorizationCode.class).getToken()).isEqualTo(AUTHORIZATION_CODE); assertThat(authorization.getAccessToken().getToken()).isEqualTo(ACCESS_TOKEN); assertThat(authorization.getRefreshToken().getToken()).isEqualTo(REFRESH_TOKEN); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 4629f5c..e2437f4 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.Map; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; @@ -61,6 +62,7 @@ public class TestOAuth2Authorizations { .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) .principalName("principal") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .token(authorizationCode) .accessToken(accessToken) .refreshToken(refreshToken) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index 2690553..ac612e8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -204,6 +204,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId()); assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); assertThat(authorization.getAccessToken()).isNotNull(); assertThat(authorization.getAccessToken().getToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes()); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 6e8e759..d75dca7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -467,6 +467,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) .isEqualTo(this.authentication); @@ -516,6 +517,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) .isEqualTo(this.authentication); @@ -563,6 +565,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) .isEqualTo(this.authentication); @@ -795,6 +798,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))