diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
index f3dad82..f1c0abf 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
@@ -127,6 +127,20 @@ public class OAuth2Authorization implements Serializable {
return new Builder(registeredClient.getId());
}
+ /**
+ * Returns a new {@link Builder}, initialized with the values from the provided {@code authorization}.
+ *
+ * @param authorization the authorization used for initializing the {@link Builder}
+ * @return the {@link Builder}
+ */
+ public static Builder from(OAuth2Authorization authorization) {
+ Assert.notNull(authorization, "authorization cannot be null");
+ return new Builder(authorization.getRegisteredClientId())
+ .principalName(authorization.getPrincipalName())
+ .accessToken(authorization.getAccessToken())
+ .attributes(attrs -> attrs.putAll(authorization.getAttributes()));
+ }
+
/**
* A builder for {@link OAuth2Authorization}.
*/
diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java
index c9abd28..01b4eb3 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java
@@ -20,35 +20,63 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
import java.util.Collections;
/**
+ * An {@link Authentication} implementation used when issuing an OAuth 2.0 Access Token.
+ *
* @author Joe Grandja
* @author Madhu Bhat
+ * @since 0.0.1
+ * @see AbstractAuthenticationToken
+ * @see OAuth2AuthorizationCodeAuthenticationProvider
+ * @see RegisteredClient
+ * @see OAuth2AccessToken
+ * @see OAuth2ClientAuthenticationToken
*/
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
- private RegisteredClient registeredClient;
- private Authentication clientPrincipal;
- private OAuth2AccessToken accessToken;
+ private final RegisteredClient registeredClient;
+ private final Authentication clientPrincipal;
+ private final OAuth2AccessToken accessToken;
+ /**
+ * Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
+ *
+ * @param registeredClient the registered client
+ * @param clientPrincipal the authenticated client principal
+ * @param accessToken the access token
+ */
public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient,
Authentication clientPrincipal, OAuth2AccessToken accessToken) {
super(Collections.emptyList());
+ Assert.notNull(registeredClient, "registeredClient cannot be null");
+ Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
+ Assert.notNull(accessToken, "accessToken cannot be null");
this.registeredClient = registeredClient;
this.clientPrincipal = clientPrincipal;
this.accessToken = accessToken;
}
@Override
- public Object getCredentials() {
- return null;
+ public Object getPrincipal() {
+ return this.clientPrincipal;
}
@Override
- public Object getPrincipal() {
- return null;
+ public Object getCredentials() {
+ return "";
+ }
+
+ /**
+ * Returns the {@link RegisteredClient registered client}.
+ *
+ * @return the {@link RegisteredClient}
+ */
+ public RegisteredClient getRegisteredClient() {
+ return this.registeredClient;
}
/**
diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
index 0dd3382..0afd01b 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
@@ -18,21 +18,106 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Base64;
/**
+ * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant.
+ *
* @author Joe Grandja
+ * @since 0.0.1
+ * @see OAuth2AuthorizationCodeAuthenticationToken
+ * @see OAuth2AccessTokenAuthenticationToken
+ * @see RegisteredClientRepository
+ * @see OAuth2AuthorizationService
+ * @see Section 4.1 Authorization Code Grant
+ * @see Section 4.1.3 Access Token Request
*/
public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
- private RegisteredClientRepository registeredClientRepository;
- private OAuth2AuthorizationService authorizationService;
- private StringKeyGenerator accessTokenGenerator;
+ private final RegisteredClientRepository registeredClientRepository;
+ private final OAuth2AuthorizationService authorizationService;
+ private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+
+ /**
+ * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
+ *
+ * @param registeredClientRepository the repository of registered clients
+ * @param authorizationService the authorization service
+ */
+ public OAuth2AuthorizationCodeAuthenticationProvider(RegisteredClientRepository registeredClientRepository,
+ OAuth2AuthorizationService authorizationService) {
+ Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+ Assert.notNull(authorizationService, "authorizationService cannot be null");
+ this.registeredClientRepository = registeredClientRepository;
+ this.authorizationService = authorizationService;
+ }
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
- return authentication;
+ OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
+ (OAuth2AuthorizationCodeAuthenticationToken) authentication;
+
+ OAuth2ClientAuthenticationToken clientPrincipal = null;
+ if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
+ clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
+ }
+ if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
+ throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
+ }
+
+ // TODO Authenticate public client
+ // A client MAY use the "client_id" request parameter to identify itself
+ // when sending requests to the token endpoint.
+ // In the "authorization_code" "grant_type" request to the token endpoint,
+ // an unauthenticated client MUST send its "client_id" to prevent itself
+ // from inadvertently accepting a code intended for a client with a different "client_id".
+ // This protects the client from substitution of the authentication code.
+
+ OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(
+ authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
+ if (authorization == null) {
+ throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+ }
+ if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
+ throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+ }
+
+ OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+ OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+ if (StringUtils.hasText(authorizationRequest.getRedirectUri()) &&
+ !authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) {
+ throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+ }
+
+ String tokenValue = this.accessTokenGenerator.generateKey();
+ Instant issuedAt = Instant.now();
+ Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token lifespan
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+ tokenValue, issuedAt, expiresAt, authorizationRequest.getScopes());
+
+ authorization = OAuth2Authorization.from(authorization)
+ .accessToken(accessToken)
+ .build();
+ this.authorizationService.save(authorization);
+
+ return new OAuth2AccessTokenAuthenticationToken(
+ clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
}
@Override
diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java
index 65e1609..28d2f2f 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java
@@ -19,12 +19,19 @@ import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.Version;
+import org.springframework.util.Assert;
import java.util.Collections;
/**
+ * An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
+ *
* @author Joe Grandja
* @author Madhu Bhat
+ * @since 0.0.1
+ * @see AbstractAuthenticationToken
+ * @see OAuth2AuthorizationCodeAuthenticationProvider
+ * @see OAuth2ClientAuthenticationToken
*/
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
@@ -33,17 +40,35 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
private String clientId;
private String redirectUri;
+ /**
+ * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
+ *
+ * @param code the authorization code
+ * @param clientPrincipal the authenticated client principal
+ * @param redirectUri the redirect uri
+ */
public OAuth2AuthorizationCodeAuthenticationToken(String code,
Authentication clientPrincipal, @Nullable String redirectUri) {
super(Collections.emptyList());
+ Assert.hasText(code, "code cannot be empty");
+ Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
this.code = code;
this.clientPrincipal = clientPrincipal;
this.redirectUri = redirectUri;
}
+ /**
+ * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
+ *
+ * @param code the authorization code
+ * @param clientId the client identifier
+ * @param redirectUri the redirect uri
+ */
public OAuth2AuthorizationCodeAuthenticationToken(String code,
String clientId, @Nullable String redirectUri) {
super(Collections.emptyList());
+ Assert.hasText(code, "code cannot be empty");
+ Assert.hasText(clientId, "clientId cannot be empty");
this.code = code;
this.clientId = clientId;
this.redirectUri = redirectUri;
@@ -60,20 +85,20 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
}
/**
- * Returns the code.
+ * Returns the authorization code.
*
- * @return the code
+ * @return the authorization code
*/
public String getCode() {
return this.code;
}
/**
- * Returns the redirectUri.
+ * Returns the redirect uri.
*
- * @return the redirectUri
+ * @return the redirect uri
*/
- public String getRedirectUri() {
+ public @Nullable String getRedirectUri() {
return this.redirectUri;
}
}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
index 498c275..d3daef0 100644
--- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
@@ -30,12 +30,13 @@ import static org.assertj.core.data.MapEntry.entry;
* Tests for {@link OAuth2Authorization}.
*
* @author Krisztian Toth
+ * @author Joe Grandja
*/
public class OAuth2AuthorizationTests {
private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
private static final String PRINCIPAL_NAME = "principal";
private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
- OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now());
+ OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
private static final String AUTHORIZATION_CODE = "code";
@Test
@@ -45,6 +46,28 @@ public class OAuth2AuthorizationTests {
.hasMessage("registeredClient cannot be null");
}
+ @Test
+ public void fromWhenAuthorizationNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> OAuth2Authorization.from(null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorization cannot be null");
+ }
+
+ @Test
+ public void fromWhenAuthorizationProvidedThenCopied() {
+ OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+ .principalName(PRINCIPAL_NAME)
+ .accessToken(ACCESS_TOKEN)
+ .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+ .build();
+ OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
+
+ assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
+ assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
+ assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
+ assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
+ }
+
@Test
public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).build())
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
new file mode 100644
index 0000000..7a17002
--- /dev/null
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.time.Instant;
+
+/**
+ * @author Joe Grandja
+ */
+public class TestOAuth2Authorizations {
+
+ public static OAuth2Authorization.Builder authorization() {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(
+ OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
+ OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+ .authorizationUri("https://provider.com/oauth2/authorize")
+ .clientId(registeredClient.getClientId())
+ .redirectUri("https://client.com/authorized")
+ .state("state")
+ .build();
+ return OAuth2Authorization.withRegisteredClient(registeredClient)
+ .principalName("principal")
+ .accessToken(accessToken)
+ .attribute(OAuth2AuthorizationAttributeNames.CODE, "code")
+ .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);
+ }
+}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java
new file mode 100644
index 0000000..ca237e7
--- /dev/null
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Test;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.time.Instant;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2AccessTokenAuthenticationToken}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AccessTokenAuthenticationTokenTests {
+ private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ private OAuth2ClientAuthenticationToken clientPrincipal =
+ new OAuth2ClientAuthenticationToken(this.registeredClient);
+ private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+ "access-token", Instant.now(), Instant.now().plusSeconds(300));
+
+ @Test
+ public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(null, this.clientPrincipal, this.accessToken))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("registeredClient cannot be null");
+ }
+
+ @Test
+ public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, null, this.accessToken))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("clientPrincipal cannot be null");
+ }
+
+ @Test
+ public void constructorWhenAccessTokenNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, this.clientPrincipal, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("accessToken cannot be null");
+ }
+
+ @Test
+ public void constructorWhenAllValuesProvidedThenCreated() {
+ OAuth2AccessTokenAuthenticationToken authentication = new OAuth2AccessTokenAuthenticationToken(
+ this.registeredClient, this.clientPrincipal, this.accessToken);
+ assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
+ assertThat(authentication.getCredentials().toString()).isEmpty();
+ assertThat(authentication.getRegisteredClient()).isEqualTo(this.registeredClient);
+ assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
+ }
+}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
new file mode 100644
index 0000000..8187353
--- /dev/null
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AuthorizationCodeAuthenticationProviderTests {
+ private RegisteredClient registeredClient;
+ private RegisteredClientRepository registeredClientRepository;
+ private OAuth2AuthorizationService authorizationService;
+ private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
+
+ @Before
+ public void setUp() {
+ this.registeredClient = TestRegisteredClients.registeredClient().build();
+ this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient);
+ this.authorizationService = mock(OAuth2AuthorizationService.class);
+ this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
+ this.registeredClientRepository, this.authorizationService);
+ }
+
+ @Test
+ public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null, this.authorizationService))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("registeredClientRepository cannot be null");
+ }
+
+ @Test
+ public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorizationService cannot be null");
+ }
+
+ @Test
+ public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
+ assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
+ }
+
+ @Test
+ public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() {
+ TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
+ this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
+ OAuth2AuthorizationCodeAuthenticationToken authentication =
+ new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+ assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+ .isInstanceOf(OAuth2AuthenticationException.class)
+ .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+ .extracting("errorCode")
+ .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+ }
+
+ @Test
+ public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() {
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+ this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
+ OAuth2AuthorizationCodeAuthenticationToken authentication =
+ new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+ assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+ .isInstanceOf(OAuth2AuthenticationException.class)
+ .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+ .extracting("errorCode")
+ .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+ }
+
+ @Test
+ public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() {
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+ OAuth2AuthorizationCodeAuthenticationToken authentication =
+ new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+ assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+ .isInstanceOf(OAuth2AuthenticationException.class)
+ .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+ .extracting("errorCode")
+ .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+ }
+
+ @Test
+ public void authenticateWhenCodeIssuedToAnotherClientThenThrowOAuth2AuthenticationException() {
+ OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+ when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
+ .thenReturn(authorization);
+
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+ TestRegisteredClients.registeredClient2().build());
+ OAuth2AuthorizationCodeAuthenticationToken authentication =
+ new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+ assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+ .isInstanceOf(OAuth2AuthenticationException.class)
+ .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+ .extracting("errorCode")
+ .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+ }
+
+ @Test
+ public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() {
+ OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+ when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
+ .thenReturn(authorization);
+
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+ OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+ OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+ OAuth2AuthorizationCodeAuthenticationToken authentication =
+ new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid");
+ assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+ .isInstanceOf(OAuth2AuthenticationException.class)
+ .extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+ .extracting("errorCode")
+ .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+ }
+
+ @Test
+ public void authenticateWhenValidCodeThenReturnAccessToken() {
+ OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+ when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
+ .thenReturn(authorization);
+
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+ OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+ OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+ OAuth2AuthorizationCodeAuthenticationToken authentication =
+ new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri());
+
+ OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+ (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+ ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+ verify(this.authorizationService).save(authorizationCaptor.capture());
+ OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+ assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
+ assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
+ assertThat(updatedAuthorization.getAccessToken()).isNotNull();
+ assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
+ }
+}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java
new file mode 100644
index 0000000..e2977a3
--- /dev/null
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Test;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AuthorizationCodeAuthenticationTokenTests {
+ private String code = "code";
+ private OAuth2ClientAuthenticationToken clientPrincipal =
+ new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
+ private String clientId = "clientId";
+ private String redirectUri = "redirectUri";
+
+ @Test
+ public void constructorWhenCodeNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("code cannot be empty");
+ }
+
+ @Test
+ public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("clientPrincipal cannot be null");
+ }
+
+ @Test
+ public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("clientId cannot be empty");
+ }
+
+ @Test
+ public void constructorWhenClientPrincipalProvidedThenCreated() {
+ OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+ this.code, this.clientPrincipal, this.redirectUri);
+ assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
+ assertThat(authentication.getCredentials().toString()).isEmpty();
+ assertThat(authentication.getCode()).isEqualTo(this.code);
+ assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
+ }
+
+ @Test
+ public void constructorWhenClientIdProvidedThenCreated() {
+ OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+ this.code, this.clientId, this.redirectUri);
+ assertThat(authentication.getPrincipal()).isEqualTo(this.clientId);
+ assertThat(authentication.getCredentials().toString()).isEmpty();
+ assertThat(authentication.getCode()).isEqualTo(this.code);
+ assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
+ }
+}