diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index e35ea09..408b39a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -80,6 +81,10 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica } RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); + if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.CLIENT_CREDENTIALS)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)); + } + Set scopes = registeredClient.getScopes(); // Default to configured scopes if (!CollectionUtils.isEmpty(clientCredentialsAuthentication.getScopes())) { Set unauthorizedScopes = clientCredentialsAuthentication.getScopes().stream() diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index da2f58d..dbef6db 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -19,6 +19,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.jose.JoseHeaderNames; @@ -49,14 +50,12 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class OAuth2ClientCredentialsAuthenticationProviderTests { - private RegisteredClient registeredClient; private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider; @Before public void setUp() { - this.registeredClient = TestRegisteredClients.registeredClient().build(); this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( @@ -89,8 +88,9 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { @Test public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( - this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + registeredClient.getClientId(), registeredClient.getClientSecret()); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) @@ -102,8 +102,9 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { @Test public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( - this.registeredClient.getClientId(), this.registeredClient.getClientSecret(), null); + registeredClient.getClientId(), registeredClient.getClientSecret(), null); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) @@ -113,9 +114,25 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } + @Test + public void authenticateWhenClientNotAuthorizedToRequestTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2() + .authorizationGrantTypes(grantTypes -> grantTypes.remove(AuthorizationGrantType.CLIENT_CREDENTIALS)) + .build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT); + } + @Test public void authenticateWhenInvalidScopeThenThrowOAuth2AuthenticationException() { - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken( clientPrincipal, Collections.singleton("invalid-scope")); @@ -128,7 +145,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { @Test public void authenticateWhenScopeRequestedThenAccessTokenContainsScope() { - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); Set requestedScope = Collections.singleton("openid"); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); @@ -142,7 +160,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { @Test public void authenticateWhenValidAuthenticationThenReturnAccessToken() { - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());