diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index f3dad82..f1c0abf 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -127,6 +127,20 @@ public class OAuth2Authorization implements Serializable { return new Builder(registeredClient.getId()); } + /** + * Returns a new {@link Builder}, initialized with the values from the provided {@code authorization}. + * + * @param authorization the authorization used for initializing the {@link Builder} + * @return the {@link Builder} + */ + public static Builder from(OAuth2Authorization authorization) { + Assert.notNull(authorization, "authorization cannot be null"); + return new Builder(authorization.getRegisteredClientId()) + .principalName(authorization.getPrincipalName()) + .accessToken(authorization.getAccessToken()) + .attributes(attrs -> attrs.putAll(authorization.getAttributes())); + } + /** * A builder for {@link OAuth2Authorization}. */ diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java index c9abd28..01b4eb3 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java @@ -20,35 +20,63 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; import java.util.Collections; /** + * An {@link Authentication} implementation used when issuing an OAuth 2.0 Access Token. + * * @author Joe Grandja * @author Madhu Bhat + * @since 0.0.1 + * @see AbstractAuthenticationToken + * @see OAuth2AuthorizationCodeAuthenticationProvider + * @see RegisteredClient + * @see OAuth2AccessToken + * @see OAuth2ClientAuthenticationToken */ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; - private RegisteredClient registeredClient; - private Authentication clientPrincipal; - private OAuth2AccessToken accessToken; + private final RegisteredClient registeredClient; + private final Authentication clientPrincipal; + private final OAuth2AccessToken accessToken; + /** + * Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters. + * + * @param registeredClient the registered client + * @param clientPrincipal the authenticated client principal + * @param accessToken the access token + */ public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient, Authentication clientPrincipal, OAuth2AccessToken accessToken) { super(Collections.emptyList()); + Assert.notNull(registeredClient, "registeredClient cannot be null"); + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); + Assert.notNull(accessToken, "accessToken cannot be null"); this.registeredClient = registeredClient; this.clientPrincipal = clientPrincipal; this.accessToken = accessToken; } @Override - public Object getCredentials() { - return null; + public Object getPrincipal() { + return this.clientPrincipal; } @Override - public Object getPrincipal() { - return null; + public Object getCredentials() { + return ""; + } + + /** + * Returns the {@link RegisteredClient registered client}. + * + * @return the {@link RegisteredClient} + */ + public RegisteredClient getRegisteredClient() { + return this.registeredClient; } /** diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 0dd3382..0afd01b 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -18,21 +18,106 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; /** + * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant. + * * @author Joe Grandja + * @since 0.0.1 + * @see OAuth2AuthorizationCodeAuthenticationToken + * @see OAuth2AccessTokenAuthenticationToken + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + * @see Section 4.1 Authorization Code Grant + * @see Section 4.1.3 Access Token Request */ public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { - private RegisteredClientRepository registeredClientRepository; - private OAuth2AuthorizationService authorizationService; - private StringKeyGenerator accessTokenGenerator; + private final RegisteredClientRepository registeredClientRepository; + private final OAuth2AuthorizationService authorizationService; + private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + + /** + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ + public OAuth2AuthorizationCodeAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + } @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - return authentication; + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = + (OAuth2AuthorizationCodeAuthenticationToken) authentication; + + OAuth2ClientAuthenticationToken clientPrincipal = null; + if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) { + clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal(); + } + if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + // TODO Authenticate public client + // A client MAY use the "client_id" request parameter to identify itself + // when sending requests to the token endpoint. + // In the "authorization_code" "grant_type" request to the token endpoint, + // an unauthenticated client MUST send its "client_id" to prevent itself + // from inadvertently accepting a code intended for a client with a different "client_id". + // This protects the client from substitution of the authentication code. + + OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType( + authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE); + if (authorization == null) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + if (StringUtils.hasText(authorizationRequest.getRedirectUri()) && + !authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + String tokenValue = this.accessTokenGenerator.generateKey(); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token lifespan + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + tokenValue, issuedAt, expiresAt, authorizationRequest.getScopes()); + + authorization = OAuth2Authorization.from(authorization) + .accessToken(accessToken) + .build(); + this.authorizationService.save(authorization); + + return new OAuth2AccessTokenAuthenticationToken( + clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken); } @Override diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java index 65e1609..28d2f2f 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -19,12 +19,19 @@ import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.server.authorization.Version; +import org.springframework.util.Assert; import java.util.Collections; /** + * An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant. + * * @author Joe Grandja * @author Madhu Bhat + * @since 0.0.1 + * @see AbstractAuthenticationToken + * @see OAuth2AuthorizationCodeAuthenticationProvider + * @see OAuth2ClientAuthenticationToken */ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; @@ -33,17 +40,35 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti private String clientId; private String redirectUri; + /** + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters. + * + * @param code the authorization code + * @param clientPrincipal the authenticated client principal + * @param redirectUri the redirect uri + */ public OAuth2AuthorizationCodeAuthenticationToken(String code, Authentication clientPrincipal, @Nullable String redirectUri) { super(Collections.emptyList()); + Assert.hasText(code, "code cannot be empty"); + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); this.code = code; this.clientPrincipal = clientPrincipal; this.redirectUri = redirectUri; } + /** + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters. + * + * @param code the authorization code + * @param clientId the client identifier + * @param redirectUri the redirect uri + */ public OAuth2AuthorizationCodeAuthenticationToken(String code, String clientId, @Nullable String redirectUri) { super(Collections.emptyList()); + Assert.hasText(code, "code cannot be empty"); + Assert.hasText(clientId, "clientId cannot be empty"); this.code = code; this.clientId = clientId; this.redirectUri = redirectUri; @@ -60,20 +85,20 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti } /** - * Returns the code. + * Returns the authorization code. * - * @return the code + * @return the authorization code */ public String getCode() { return this.code; } /** - * Returns the redirectUri. + * Returns the redirect uri. * - * @return the redirectUri + * @return the redirect uri */ - public String getRedirectUri() { + public @Nullable String getRedirectUri() { return this.redirectUri; } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index 498c275..d3daef0 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -30,12 +30,13 @@ import static org.assertj.core.data.MapEntry.entry; * Tests for {@link OAuth2Authorization}. * * @author Krisztian Toth + * @author Joe Grandja */ public class OAuth2AuthorizationTests { private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now()); + OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); private static final String AUTHORIZATION_CODE = "code"; @Test @@ -45,6 +46,28 @@ public class OAuth2AuthorizationTests { .hasMessage("registeredClient cannot be null"); } + @Test + public void fromWhenAuthorizationNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Authorization.from(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorization cannot be null"); + } + + @Test + public void fromWhenAuthorizationProvidedThenCopied() { + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .principalName(PRINCIPAL_NAME) + .accessToken(ACCESS_TOKEN) + .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) + .build(); + OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build(); + + assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); + assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); + assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken()); + assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes()); + } + @Test public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() { assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).build()) diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java new file mode 100644 index 0000000..7a17002 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -0,0 +1,46 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import java.time.Instant; + +/** + * @author Joe Grandja + */ +public class TestOAuth2Authorizations { + + public static OAuth2Authorization.Builder authorization() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://provider.com/oauth2/authorize") + .clientId(registeredClient.getClientId()) + .redirectUri("https://client.com/authorized") + .state("state") + .build(); + return OAuth2Authorization.withRegisteredClient(registeredClient) + .principalName("principal") + .accessToken(accessToken) + .attribute(OAuth2AuthorizationAttributeNames.CODE, "code") + .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java new file mode 100644 index 0000000..ca237e7 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Test; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2AccessTokenAuthenticationToken}. + * + * @author Joe Grandja + */ +public class OAuth2AccessTokenAuthenticationTokenTests { + private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + private OAuth2ClientAuthenticationToken clientPrincipal = + new OAuth2ClientAuthenticationToken(this.registeredClient); + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token", Instant.now(), Instant.now().plusSeconds(300)); + + @Test + public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(null, this.clientPrincipal, this.accessToken)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClient cannot be null"); + } + + @Test + public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, null, this.accessToken)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientPrincipal cannot be null"); + } + + @Test + public void constructorWhenAccessTokenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, this.clientPrincipal, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessToken cannot be null"); + } + + @Test + public void constructorWhenAllValuesProvidedThenCreated() { + OAuth2AccessTokenAuthenticationToken authentication = new OAuth2AccessTokenAuthenticationToken( + this.registeredClient, this.clientPrincipal, this.accessToken); + assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getRegisteredClient()).isEqualTo(this.registeredClient); + assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java new file mode 100644 index 0000000..8187353 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizationCodeAuthenticationProviderTests { + private RegisteredClient registeredClient; + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; + + @Before + public void setUp() { + this.registeredClient = TestRegisteredClients.registeredClient().build(); + this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( + this.registeredClientRepository, this.authorizationService); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenCodeIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + TestRegisteredClients.registeredClient2().build()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid"); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenValidCodeThenReturnAccessToken() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri()); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); + assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(updatedAuthorization.getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java new file mode 100644 index 0000000..e2977a3 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Test; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizationCodeAuthenticationTokenTests { + private String code = "code"; + private OAuth2ClientAuthenticationToken clientPrincipal = + new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); + private String clientId = "clientId"; + private String redirectUri = "redirectUri"; + + @Test + public void constructorWhenCodeNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("code cannot be empty"); + } + + @Test + public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientPrincipal cannot be null"); + } + + @Test + public void constructorWhenClientIdNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientId cannot be empty"); + } + + @Test + public void constructorWhenClientPrincipalProvidedThenCreated() { + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.code, this.clientPrincipal, this.redirectUri); + assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getCode()).isEqualTo(this.code); + assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri); + } + + @Test + public void constructorWhenClientIdProvidedThenCreated() { + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.code, this.clientId, this.redirectUri); + assertThat(authentication.getPrincipal()).isEqualTo(this.clientId); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getCode()).isEqualTo(this.code); + assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri); + } +}