/* * Copyright 2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.oauth2.server.authorization.authentication; import java.util.HashMap; import java.util.Map; import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** * Tests for {@link OAuth2ClientAuthenticationProvider}. * * @author Patryk Kostrzewa * @author Joe Grandja * @author Daniel Garnier-Moiroux * @author Anoop Garlapati */ public class OAuth2ClientAuthenticationProviderTests { private static final String PLAIN_CODE_VERIFIER = "pkce-key"; private static final String PLAIN_CODE_CHALLENGE = PLAIN_CODE_VERIFIER; // See RFC 7636: Appendix B. Example for the S256 code_challenge_method // https://tools.ietf.org/html/rfc7636#appendix-B private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; private static final String AUTHORIZATION_CODE = "code"; private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private OAuth2ClientAuthenticationProvider authenticationProvider; @Before public void setUp() { this.registeredClientRepository = mock(RegisteredClientRepository.class); this.authorizationService = mock(OAuth2AuthorizationService.class); this.authenticationProvider = new OAuth2ClientAuthenticationProvider( this.registeredClientRepository, this.authorizationService); } @Test public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(null, this.authorizationService)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("registeredClientRepository cannot be null"); } @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(this.registeredClientRepository, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizationService cannot be null"); } @Test public void supportsWhenTypeOAuth2ClientAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2ClientAuthenticationToken.class)).isTrue(); } @Test public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( registeredClient.getClientId() + "-invalid", registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenInvalidClientSecretThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret() + "-invalid", ClientAuthenticationMethod.BASIC, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenClientSecretNotProvidedThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenValidCredentialsThenAuthenticated() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null); OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); assertThat(authenticationResult.getCredentials()).isNull(); assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); } @Test public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, createPkceAuthorizationParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); parameters.put(OAuth2ParameterNames.CODE, "invalid-code"); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenPkceAndRequireProofKeyAndMissingCodeChallengeThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSettings(clientSettings -> clientSettings.requireProofKey(true)) .build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenPkceAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, createPkceAuthorizationParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); parameters.remove(PkceParameterNames.CODE_VERIFIER); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenPkceAndPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, createPkceAuthorizationParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters("invalid-code-verifier"); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenPkceAndS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, createPkceAuthorizationParametersS256()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters("invalid-code-verifier"); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test public void authenticateWhenPkceAndPlainMethodAndValidCodeVerifierThenAuthenticated() { RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, createPkceAuthorizationParametersPlain()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); assertThat(authenticationResult.getCredentials()).isNull(); assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); } @Test public void authenticateWhenPkceAndMissingMethodThenDefaultPlainMethodAndAuthenticated() { RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); Map authorizationRequestAdditionalParameters = createPkceAuthorizationParametersPlain(); authorizationRequestAdditionalParameters.remove(PkceParameterNames.CODE_CHALLENGE_METHOD); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, authorizationRequestAdditionalParameters) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); assertThat(authenticationResult.getCredentials()).isNull(); assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); } @Test public void authenticateWhenPkceAndS256MethodAndValidCodeVerifierThenAuthenticated() { RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, createPkceAuthorizationParametersS256()) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(S256_CODE_VERIFIER); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); assertThat(authenticationResult.getCredentials()).isNull(); assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); } @Test public void authenticateWhenPkceAndUnsupportedCodeChallengeMethodThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); Map authorizationRequestAdditionalParameters = createPkceAuthorizationParametersPlain(); // This should never happen: the Authorization endpoint should not allow it authorizationRequestAdditionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method"); OAuth2Authorization authorization = TestOAuth2Authorizations .authorization(registeredClient, authorizationRequestAdditionalParameters) .build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), parameters); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); } @Test public void authenticateWhenClientAuthenticationMethodNotConfiguredThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.POST, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } private static Map createPkceTokenParameters(String codeVerifier) { Map parameters = new HashMap<>(); parameters.put(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); parameters.put(OAuth2ParameterNames.CODE, AUTHORIZATION_CODE); parameters.put(PkceParameterNames.CODE_VERIFIER, codeVerifier); return parameters; } private static Map createPkceAuthorizationParametersPlain() { Map parameters = new HashMap<>(); parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); parameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); return parameters; } private static Map createPkceAuthorizationParametersS256() { Map parameters = new HashMap<>(); parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); parameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); return parameters; } }