diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java index 1738029..73357a1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java @@ -82,15 +82,22 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide throwInvalidClient(); } + boolean authenticatedCredentials = false; + if (clientAuthentication.getCredentials() != null) { String clientSecret = clientAuthentication.getCredentials().toString(); // TODO Use PasswordEncoder.matches() if (!registeredClient.getClientSecret().equals(clientSecret)) { throwInvalidClient(); } + authenticatedCredentials = true; } - authenticatePkceIfAvailable(clientAuthentication, registeredClient); + authenticatedCredentials = authenticatedCredentials || + authenticatePkceIfAvailable(clientAuthentication, registeredClient); + if (!authenticatedCredentials) { + throwInvalidClient(); + } return new OAuth2ClientAuthenticationToken(registeredClient); } @@ -100,12 +107,12 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); } - private void authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication, + private boolean authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication, RegisteredClient registeredClient) { Map parameters = clientAuthentication.getAdditionalParameters(); if (CollectionUtils.isEmpty(parameters) || !authorizationCodeGrant(parameters)) { - return; + return false; } OAuth2Authorization authorization = this.authorizationService.findByToken( @@ -120,16 +127,19 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide String codeChallenge = (String) authorizationRequest.getAdditionalParameters() .get(PkceParameterNames.CODE_CHALLENGE); - if (StringUtils.hasText(codeChallenge)) { - String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() - .get(PkceParameterNames.CODE_CHALLENGE_METHOD); - String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); - if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { - throwInvalidClient(); - } - } else if (registeredClient.getClientSettings().requireProofKey()) { + if (!StringUtils.hasText(codeChallenge) && + registeredClient.getClientSettings().requireProofKey()) { throwInvalidClient(); } + + String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() + .get(PkceParameterNames.CODE_CHALLENGE_METHOD); + String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); + if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { + throwInvalidClient(); + } + + return true; } private static boolean authorizationCodeGrant(Map parameters) { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java index 66bb4f6..dd3566f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java @@ -37,7 +37,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** @@ -119,6 +118,21 @@ public class OAuth2ClientAuthenticationProviderTests { .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } + @Test + public void authenticateWhenClientSecretNotProvidedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + @Test public void authenticateWhenValidCredentialsThenAuthenticated() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -135,21 +149,6 @@ public class OAuth2ClientAuthenticationProviderTests { assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); } - @Test - public void authenticateWhenNotPkceThenContinueAuthenticated() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( - registeredClient.getClientId(), registeredClient.getClientSecret(), null); - OAuth2ClientAuthenticationToken authenticationResult = - (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - - verifyNoInteractions(this.authorizationService); - } - @Test public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();