Add authorization_code AuthenticationProvider

Fixes gh-68
This commit is contained in:
Joe Grandja 2020-06-03 16:08:52 -04:00 committed by Joe Grandja
parent 7a5585f1fc
commit fe286b5994
9 changed files with 563 additions and 17 deletions

View File

@ -127,6 +127,20 @@ public class OAuth2Authorization implements Serializable {
return new Builder(registeredClient.getId()); return new Builder(registeredClient.getId());
} }
/**
* Returns a new {@link Builder}, initialized with the values from the provided {@code authorization}.
*
* @param authorization the authorization used for initializing the {@link Builder}
* @return the {@link Builder}
*/
public static Builder from(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null");
return new Builder(authorization.getRegisteredClientId())
.principalName(authorization.getPrincipalName())
.accessToken(authorization.getAccessToken())
.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
}
/** /**
* A builder for {@link OAuth2Authorization}. * A builder for {@link OAuth2Authorization}.
*/ */

View File

@ -20,35 +20,63 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
import java.util.Collections; import java.util.Collections;
/** /**
* An {@link Authentication} implementation used when issuing an OAuth 2.0 Access Token.
*
* @author Joe Grandja * @author Joe Grandja
* @author Madhu Bhat * @author Madhu Bhat
* @since 0.0.1
* @see AbstractAuthenticationToken
* @see OAuth2AuthorizationCodeAuthenticationProvider
* @see RegisteredClient
* @see OAuth2AccessToken
* @see OAuth2ClientAuthenticationToken
*/ */
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken { public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private RegisteredClient registeredClient; private final RegisteredClient registeredClient;
private Authentication clientPrincipal; private final Authentication clientPrincipal;
private OAuth2AccessToken accessToken; private final OAuth2AccessToken accessToken;
/**
* Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
*
* @param registeredClient the registered client
* @param clientPrincipal the authenticated client principal
* @param accessToken the access token
*/
public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient, public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient,
Authentication clientPrincipal, OAuth2AccessToken accessToken) { Authentication clientPrincipal, OAuth2AccessToken accessToken) {
super(Collections.emptyList()); super(Collections.emptyList());
Assert.notNull(registeredClient, "registeredClient cannot be null");
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
Assert.notNull(accessToken, "accessToken cannot be null");
this.registeredClient = registeredClient; this.registeredClient = registeredClient;
this.clientPrincipal = clientPrincipal; this.clientPrincipal = clientPrincipal;
this.accessToken = accessToken; this.accessToken = accessToken;
} }
@Override @Override
public Object getCredentials() { public Object getPrincipal() {
return null; return this.clientPrincipal;
} }
@Override @Override
public Object getPrincipal() { public Object getCredentials() {
return null; return "";
}
/**
* Returns the {@link RegisteredClient registered client}.
*
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return this.registeredClient;
} }
/** /**

View File

@ -18,21 +18,106 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
/** /**
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant.
*
* @author Joe Grandja * @author Joe Grandja
* @since 0.0.1
* @see OAuth2AuthorizationCodeAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken
* @see RegisteredClientRepository
* @see OAuth2AuthorizationService
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
*/ */
public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
private RegisteredClientRepository registeredClientRepository; private final RegisteredClientRepository registeredClientRepository;
private OAuth2AuthorizationService authorizationService; private final OAuth2AuthorizationService authorizationService;
private StringKeyGenerator accessTokenGenerator; private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
*
* @param registeredClientRepository the repository of registered clients
* @param authorizationService the authorization service
*/
public OAuth2AuthorizationCodeAuthenticationProvider(RegisteredClientRepository registeredClientRepository,
OAuth2AuthorizationService authorizationService) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
Assert.notNull(authorizationService, "authorizationService cannot be null");
this.registeredClientRepository = registeredClientRepository;
this.authorizationService = authorizationService;
}
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
return authentication; OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
(OAuth2AuthorizationCodeAuthenticationToken) authentication;
OAuth2ClientAuthenticationToken clientPrincipal = null;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
}
if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}
// TODO Authenticate public client
// A client MAY use the "client_id" request parameter to identify itself
// when sending requests to the token endpoint.
// In the "authorization_code" "grant_type" request to the token endpoint,
// an unauthenticated client MUST send its "client_id" to prevent itself
// from inadvertently accepting a code intended for a client with a different "client_id".
// This protects the client from substitution of the authentication code.
OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
if (authorization == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
if (StringUtils.hasText(authorizationRequest.getRedirectUri()) &&
!authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}
String tokenValue = this.accessTokenGenerator.generateKey();
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token lifespan
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
tokenValue, issuedAt, expiresAt, authorizationRequest.getScopes());
authorization = OAuth2Authorization.from(authorization)
.accessToken(accessToken)
.build();
this.authorizationService.save(authorization);
return new OAuth2AccessTokenAuthenticationToken(
clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
} }
@Override @Override

View File

@ -19,12 +19,19 @@ import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.util.Assert;
import java.util.Collections; import java.util.Collections;
/** /**
* An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
*
* @author Joe Grandja * @author Joe Grandja
* @author Madhu Bhat * @author Madhu Bhat
* @since 0.0.1
* @see AbstractAuthenticationToken
* @see OAuth2AuthorizationCodeAuthenticationProvider
* @see OAuth2ClientAuthenticationToken
*/ */
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
@ -33,17 +40,35 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
private String clientId; private String clientId;
private String redirectUri; private String redirectUri;
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
*
* @param code the authorization code
* @param clientPrincipal the authenticated client principal
* @param redirectUri the redirect uri
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code, public OAuth2AuthorizationCodeAuthenticationToken(String code,
Authentication clientPrincipal, @Nullable String redirectUri) { Authentication clientPrincipal, @Nullable String redirectUri) {
super(Collections.emptyList()); super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
this.code = code; this.code = code;
this.clientPrincipal = clientPrincipal; this.clientPrincipal = clientPrincipal;
this.redirectUri = redirectUri; this.redirectUri = redirectUri;
} }
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
*
* @param code the authorization code
* @param clientId the client identifier
* @param redirectUri the redirect uri
*/
public OAuth2AuthorizationCodeAuthenticationToken(String code, public OAuth2AuthorizationCodeAuthenticationToken(String code,
String clientId, @Nullable String redirectUri) { String clientId, @Nullable String redirectUri) {
super(Collections.emptyList()); super(Collections.emptyList());
Assert.hasText(code, "code cannot be empty");
Assert.hasText(clientId, "clientId cannot be empty");
this.code = code; this.code = code;
this.clientId = clientId; this.clientId = clientId;
this.redirectUri = redirectUri; this.redirectUri = redirectUri;
@ -60,20 +85,20 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
} }
/** /**
* Returns the code. * Returns the authorization code.
* *
* @return the code * @return the authorization code
*/ */
public String getCode() { public String getCode() {
return this.code; return this.code;
} }
/** /**
* Returns the redirectUri. * Returns the redirect uri.
* *
* @return the redirectUri * @return the redirect uri
*/ */
public String getRedirectUri() { public @Nullable String getRedirectUri() {
return this.redirectUri; return this.redirectUri;
} }
} }

View File

@ -30,12 +30,13 @@ import static org.assertj.core.data.MapEntry.entry;
* Tests for {@link OAuth2Authorization}. * Tests for {@link OAuth2Authorization}.
* *
* @author Krisztian Toth * @author Krisztian Toth
* @author Joe Grandja
*/ */
public class OAuth2AuthorizationTests { public class OAuth2AuthorizationTests {
private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
private static final String PRINCIPAL_NAME = "principal"; private static final String PRINCIPAL_NAME = "principal";
private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken( private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now()); OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
private static final String AUTHORIZATION_CODE = "code"; private static final String AUTHORIZATION_CODE = "code";
@Test @Test
@ -45,6 +46,28 @@ public class OAuth2AuthorizationTests {
.hasMessage("registeredClient cannot be null"); .hasMessage("registeredClient cannot be null");
} }
@Test
public void fromWhenAuthorizationNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Authorization.from(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorization cannot be null");
}
@Test
public void fromWhenAuthorizationProvidedThenCopied() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.accessToken(ACCESS_TOKEN)
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
.build();
OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
}
@Test @Test
public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() { public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).build()) assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).build())

View File

@ -0,0 +1,46 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import java.time.Instant;
/**
* @author Joe Grandja
*/
public class TestOAuth2Authorizations {
public static OAuth2Authorization.Builder authorization() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://provider.com/oauth2/authorize")
.clientId(registeredClient.getClientId())
.redirectUri("https://client.com/authorized")
.state("state")
.build();
return OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName("principal")
.accessToken(accessToken)
.attribute(OAuth2AuthorizationAttributeNames.CODE, "code")
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);
}
}

View File

@ -0,0 +1,70 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import java.time.Instant;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link OAuth2AccessTokenAuthenticationToken}.
*
* @author Joe Grandja
*/
public class OAuth2AccessTokenAuthenticationTokenTests {
private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
private OAuth2ClientAuthenticationToken clientPrincipal =
new OAuth2ClientAuthenticationToken(this.registeredClient);
private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token", Instant.now(), Instant.now().plusSeconds(300));
@Test
public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(null, this.clientPrincipal, this.accessToken))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("registeredClient cannot be null");
}
@Test
public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, null, this.accessToken))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientPrincipal cannot be null");
}
@Test
public void constructorWhenAccessTokenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, this.clientPrincipal, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("accessToken cannot be null");
}
@Test
public void constructorWhenAllValuesProvidedThenCreated() {
OAuth2AccessTokenAuthenticationToken authentication = new OAuth2AccessTokenAuthenticationToken(
this.registeredClient, this.clientPrincipal, this.accessToken);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getRegisteredClient()).isEqualTo(this.registeredClient);
assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
}
}

View File

@ -0,0 +1,178 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}.
*
* @author Joe Grandja
*/
public class OAuth2AuthorizationCodeAuthenticationProviderTests {
private RegisteredClient registeredClient;
private RegisteredClientRepository registeredClientRepository;
private OAuth2AuthorizationService authorizationService;
private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
@Before
public void setUp() {
this.registeredClient = TestRegisteredClients.registeredClient().build();
this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient);
this.authorizationService = mock(OAuth2AuthorizationService.class);
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
this.registeredClientRepository, this.authorizationService);
}
@Test
public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null, this.authorizationService))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("registeredClientRepository cannot be null");
}
@Test
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationService cannot be null");
}
@Test
public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
}
@Test
public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() {
TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
}
@Test
public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
}
@Test
public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenCodeIssuedToAnotherClientThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
TestRegisteredClients.registeredClient2().build());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid");
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
}
@Test
public void authenticateWhenValidCodeThenReturnAccessToken() {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri());
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(updatedAuthorization.getAccessToken()).isNotNull();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import org.junit.Test;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}.
*
* @author Joe Grandja
*/
public class OAuth2AuthorizationCodeAuthenticationTokenTests {
private String code = "code";
private OAuth2ClientAuthenticationToken clientPrincipal =
new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
private String clientId = "clientId";
private String redirectUri = "redirectUri";
@Test
public void constructorWhenCodeNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("code cannot be empty");
}
@Test
public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientPrincipal cannot be null");
}
@Test
public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientId cannot be empty");
}
@Test
public void constructorWhenClientPrincipalProvidedThenCreated() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientPrincipal, this.redirectUri);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getCode()).isEqualTo(this.code);
assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
}
@Test
public void constructorWhenClientIdProvidedThenCreated() {
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
this.code, this.clientId, this.redirectUri);
assertThat(authentication.getPrincipal()).isEqualTo(this.clientId);
assertThat(authentication.getCredentials().toString()).isEmpty();
assertThat(authentication.getCode()).isEqualTo(this.code);
assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
}
}