Polish OAuth2ClientAuthenticationProvider

Commit 5c31fb1b7e
This commit is contained in:
Joe Grandja 2020-11-05 14:07:13 -05:00
parent 6a2c841d06
commit 7720e275e4
2 changed files with 36 additions and 27 deletions

View File

@ -82,15 +82,22 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide
throwInvalidClient(); throwInvalidClient();
} }
boolean authenticatedCredentials = false;
if (clientAuthentication.getCredentials() != null) { if (clientAuthentication.getCredentials() != null) {
String clientSecret = clientAuthentication.getCredentials().toString(); String clientSecret = clientAuthentication.getCredentials().toString();
// TODO Use PasswordEncoder.matches() // TODO Use PasswordEncoder.matches()
if (!registeredClient.getClientSecret().equals(clientSecret)) { if (!registeredClient.getClientSecret().equals(clientSecret)) {
throwInvalidClient(); throwInvalidClient();
} }
authenticatedCredentials = true;
} }
authenticatedCredentials = authenticatedCredentials ||
authenticatePkceIfAvailable(clientAuthentication, registeredClient); authenticatePkceIfAvailable(clientAuthentication, registeredClient);
if (!authenticatedCredentials) {
throwInvalidClient();
}
return new OAuth2ClientAuthenticationToken(registeredClient); return new OAuth2ClientAuthenticationToken(registeredClient);
} }
@ -100,12 +107,12 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
} }
private void authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication, private boolean authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication,
RegisteredClient registeredClient) { RegisteredClient registeredClient) {
Map<String, Object> parameters = clientAuthentication.getAdditionalParameters(); Map<String, Object> parameters = clientAuthentication.getAdditionalParameters();
if (CollectionUtils.isEmpty(parameters) || !authorizationCodeGrant(parameters)) { if (CollectionUtils.isEmpty(parameters) || !authorizationCodeGrant(parameters)) {
return; return false;
} }
OAuth2Authorization authorization = this.authorizationService.findByToken( OAuth2Authorization authorization = this.authorizationService.findByToken(
@ -120,16 +127,19 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide
String codeChallenge = (String) authorizationRequest.getAdditionalParameters() String codeChallenge = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE); .get(PkceParameterNames.CODE_CHALLENGE);
if (StringUtils.hasText(codeChallenge)) { if (!StringUtils.hasText(codeChallenge) &&
registeredClient.getClientSettings().requireProofKey()) {
throwInvalidClient();
}
String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD); .get(PkceParameterNames.CODE_CHALLENGE_METHOD);
String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
throwInvalidClient(); throwInvalidClient();
} }
} else if (registeredClient.getClientSettings().requireProofKey()) {
throwInvalidClient(); return true;
}
} }
private static boolean authorizationCodeGrant(Map<String, Object> parameters) { private static boolean authorizationCodeGrant(Map<String, Object> parameters) {

View File

@ -37,7 +37,6 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@ -119,6 +118,21 @@ public class OAuth2ClientAuthenticationProviderTests {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
} }
@Test
public void authenticateWhenClientSecretNotProvidedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);
OAuth2ClientAuthenticationToken authentication =
new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
}
@Test @Test
public void authenticateWhenValidCredentialsThenAuthenticated() { public void authenticateWhenValidCredentialsThenAuthenticated() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@ -135,21 +149,6 @@ public class OAuth2ClientAuthenticationProviderTests {
assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient);
} }
@Test
public void authenticateWhenNotPkceThenContinueAuthenticated() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);
OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret(), null);
OAuth2ClientAuthenticationToken authenticationResult =
(OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication);
assertThat(authenticationResult.isAuthenticated()).isTrue();
verifyNoInteractions(this.authorizationService);
}
@Test @Test
public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() { public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();