diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java index ee294a6..93cf659 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -91,6 +92,10 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } + if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)); + } + Instant refreshTokenExpiresAt = authorization.getTokens().getRefreshToken().getExpiresAt(); if (refreshTokenExpiresAt.isBefore(Instant.now())) { // As per https://tools.ietf.org/html/rfc6749#section-5.2 @@ -125,7 +130,7 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP } authorization = OAuth2Authorization.from(authorization) - .tokens(OAuth2Tokens.builder().accessToken(accessToken).refreshToken(refreshToken).build()) + .tokens(OAuth2Tokens.from(authorization.getTokens()).accessToken(accessToken).refreshToken(refreshToken).build()) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index e212b1d..d3e0f2e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -144,7 +144,7 @@ public class OAuth2AuthorizationCodeGrantTests { } @Test - public void requestWhenTokenRequestValidThenResponseIncludesCacheHeaders() throws Exception { + public void requestWhenTokenRequestValidThenReturnAccessTokenResponse() throws Exception { this.spring.register(AuthorizationServerConfiguration.class).autowire(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -157,13 +157,18 @@ public class OAuth2AuthorizationCodeGrantTests { eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) .params(getTokenRequestParameters(registeredClient, authorization)) .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( registeredClient.getClientId(), registeredClient.getClientSecret()))) .andExpect(status().isOk()) .andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store"))) - .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))); + .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))) + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.token_type").isNotEmpty()) + .andExpect(jsonPath("$.expires_in").isNotEmpty()) + .andExpect(jsonPath("$.refresh_token").isNotEmpty()) + .andExpect(jsonPath("$.scope").isNotEmpty()); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).findByToken( @@ -173,12 +178,13 @@ public class OAuth2AuthorizationCodeGrantTests { } @Test - public void requestWhenPublicClientWithPkceThenReturnAccessToken() throws Exception { + public void requestWhenPublicClientWithPkceThenReturnAccessTokenResponse() throws Exception { this.spring.register(AuthorizationServerConfiguration.class).autowire(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSecret(null) .clientSettings(clientSettings -> clientSettings.requireProofKey(true)) + .tokenSettings(tokenSettings -> tokenSettings.enableRefreshTokens(false)) .build(); when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); @@ -208,7 +214,11 @@ public class OAuth2AuthorizationCodeGrantTests { .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)) .andExpect(status().isOk()) - .andExpect(jsonPath("$.access_token").isNotEmpty()); + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.token_type").isNotEmpty()) + .andExpect(jsonPath("$.expires_in").isNotEmpty()) + .andExpect(jsonPath("$.refresh_token").doesNotExist()) + .andExpect(jsonPath("$.scope").isNotEmpty()); verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService, times(2)).findByToken( @@ -217,46 +227,6 @@ public class OAuth2AuthorizationCodeGrantTests { verify(authorizationService, times(2)).save(any()); } - @Test - public void requestWhenPublicClientWithRefreshThenReturnRefreshToken() throws Exception { - this.spring.register(AuthorizationServerConfiguration.class).autowire(); - - RegisteredClient registeredClient = TestRegisteredClients - .registeredClient() - .clientSecret(null) - .tokenSettings(tokenSettings -> tokenSettings.enableRefreshTokens(true)) - .build(); - when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") - .with(user("user"))) - .andExpect(status().is3xxRedirection()) - .andReturn(); - assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); - - verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); - - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); - verify(authorizationService).save(authorizationCaptor.capture()); - OAuth2Authorization authorization = authorizationCaptor.getValue(); - - when(authorizationService.findByToken( - eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), - eq(TokenType.AUTHORIZATION_CODE))) - .thenReturn(authorization); - - this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) - .params(getTokenRequestParameters(registeredClient, authorization)) - .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) - .param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)) - .andExpect(status().isOk()) - .andExpect(jsonPath("$.refresh_token").isNotEmpty()); - } - private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java index 2eb30a1..d601773 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java @@ -13,27 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; -import java.time.Instant; - import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.http.HttpHeaders; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.crypto.keys.KeyManager; import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -42,27 +37,34 @@ import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** + * Integration tests for the OAuth 2.0 Refresh Token Grant. + * * @author Alexey Nesterov * @since 0.0.3 */ public class OAuth2RefreshTokenGrantTests { - - private static final String TEST_REFRESH_TOKEN = "test-refresh-token"; - private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; @@ -72,8 +74,6 @@ public class OAuth2RefreshTokenGrantTests { @Autowired private MockMvc mvc; - private RegisteredClient registeredClient; - @BeforeClass public static void init() { registeredClientRepository = mock(RegisteredClientRepository.class); @@ -84,33 +84,56 @@ public class OAuth2RefreshTokenGrantTests { public void setup() { reset(registeredClientRepository); reset(authorizationService); - - this.registeredClient = TestRegisteredClients.registeredClient2().build(); - - this.spring.register(OAuth2RefreshTokenGrantTests.AuthorizationServerConfiguration.class).autowire(); } @Test - public void requestWhenRefreshTokenExists() throws Exception { - when(registeredClientRepository.findByClientId(eq(this.registeredClient.getClientId()))) - .thenReturn(this.registeredClient); + public void requestWhenRefreshTokenRequestValidThenReturnAccessTokenResponse() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(this.registeredClient) - .tokens(OAuth2Tokens.builder() - .refreshToken(new OAuth2RefreshToken(TEST_REFRESH_TOKEN, Instant.now(), Instant.now().plusSeconds(60))) - .accessToken(new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(10))) - .build()) - .build(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); - when(authorizationService.findByToken(TEST_REFRESH_TOKEN, TokenType.REFRESH_TOKEN)) + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) .thenReturn(authorization); this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) - .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()) - .param(OAuth2ParameterNames.REFRESH_TOKEN, TEST_REFRESH_TOKEN) - .with(httpBasic(this.registeredClient.getClientId(), this.registeredClient.getClientSecret()))) + .params(getRefreshTokenRequestParameters(authorization)) + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret()))) .andExpect(status().isOk()) - .andExpect(jsonPath("$.access_token").isNotEmpty()); + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store"))) + .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))) + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.token_type").isNotEmpty()) + .andExpect(jsonPath("$.expires_in").isNotEmpty()) + .andExpect(jsonPath("$.refresh_token").isNotEmpty()) + .andExpect(jsonPath("$.scope").isNotEmpty()); + + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); + verify(authorizationService).findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN)); + verify(authorizationService).save(any()); + + } + + private static MultiValueMap getRefreshTokenRequestParameters(OAuth2Authorization authorization) { + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()); + parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, authorization.getTokens().getRefreshToken().getTokenValue()); + return parameters; + } + + private static String encodeBasicAuth(String clientId, String secret) throws Exception { + clientId = URLEncoder.encode(clientId, StandardCharsets.UTF_8.name()); + secret = URLEncoder.encode(secret, StandardCharsets.UTF_8.name()); + String credentialsString = clientId + ":" + secret; + byte[] encodedBytes = Base64.getEncoder().encode(credentialsString.getBytes(StandardCharsets.UTF_8)); + return new String(encodedBytes, StandardCharsets.UTF_8); } @EnableWebSecurity diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index c5b4aa3..83807b2 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -143,17 +143,17 @@ public class InMemoryOAuth2AuthorizationServiceTests { } @Test - public void findByTokenAndTokenTypeWhenTokenTypeRefreshTokenThenFound() { - final String refreshTokenValue = "refresh-token"; - OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) - .principalName(PRINCIPAL_NAME) - .tokens(OAuth2Tokens.builder().refreshToken(new OAuth2RefreshToken(refreshTokenValue, Instant.now().plusSeconds(10))).build()) - .build(); - this.authorizationService.save(expectedAuthorization); + public void findByTokenWhenTokenTypeRefreshTokenThenFound() { + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .principalName(PRINCIPAL_NAME) + .tokens(OAuth2Tokens.builder().refreshToken(refreshToken).build()) + .build(); + this.authorizationService.save(authorization); OAuth2Authorization result = this.authorizationService.findByToken( - refreshTokenValue, TokenType.REFRESH_TOKEN); - assertThat(result).isEqualTo(expectedAuthorization); + refreshToken.getTokenValue(), TokenType.REFRESH_TOKEN); + assertThat(authorization).isEqualTo(result); } @Test diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 805d025..640567f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -15,14 +15,9 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; -import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; - import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; - import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; @@ -38,15 +33,16 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TokenType; -import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.config.TokenSettings; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; @@ -64,9 +60,7 @@ import static org.mockito.Mockito.when; * @author Daniel Garnier-Moiroux */ public class OAuth2AuthorizationCodeAuthenticationProviderTests { - private static final String AUTHORIZATION_CODE = "code"; - private RegisteredClient registeredClient; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; @@ -74,8 +68,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Before public void setUp() { - this.registeredClient = TestRegisteredClients.registeredClient().build(); - this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient); + this.registeredClientRepository = mock(RegisteredClientRepository.class); this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( @@ -110,8 +103,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( - this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + registeredClient.getClientId(), registeredClient.getClientSecret()); OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) @@ -123,8 +117,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( - this.registeredClient.getClientId(), this.registeredClient.getClientSecret(), null); + registeredClient.getClientId(), registeredClient.getClientSecret(), null); OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) @@ -136,7 +131,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() { - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null); assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) @@ -171,11 +167,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() { - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = @@ -189,9 +186,10 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenInvalidatedCodeThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120)); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .tokens(OAuth2Tokens.builder() .token(authorizationCode, OAuth2TokenMetadata.builder().invalidated().build()) .build()) @@ -199,7 +197,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = @@ -214,11 +212,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenValidCodeThenReturnAccessToken() { - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = @@ -244,17 +243,24 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); + assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); + assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class); assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); } @Test - public void authenticateWhenValidCodeThenReturnRefreshToken() { - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + public void authenticateWhenRefreshTokenTimeToLiveConfiguredThenRefreshTokenExpirySet() { + Duration refreshTokenTTL = Duration.ofDays(1); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .tokenSettings(tokenSettings -> tokenSettings.refreshTokenTimeToLive(refreshTokenTTL)) + .build(); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = @@ -269,53 +275,23 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); - assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotNull(); + assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); + Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt().plus(refreshTokenTTL); + assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isBetween( + expectedRefreshTokenExpiresAt.minusSeconds(1), expectedRefreshTokenExpiresAt.plusSeconds(1)); } @Test - public void authenticateWhenTokenSettingsHasTimeToLiveThenRefreshTokenHasExpiration() { - Duration testRefreshTokenTTL = Duration.ofDays(1); - Duration defaultRefreshTokenTTL = new TokenSettings().refreshTokenTimeToLive(); - RegisteredClient clientWithRefreshTokenTTLZero = TestRegisteredClients.registeredClient() - .tokenSettings(tokenSettings -> tokenSettings.refreshTokenTimeToLive(testRefreshTokenTTL)) - .build(); - - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) - .thenReturn(authorization); - - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(clientWithRefreshTokenTTLZero); - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( - OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); - OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); - - when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); - - OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = - (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); - - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); - verify(this.authorizationService).save(authorizationCaptor.capture()); - OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - - assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isAfter(Instant.now().plus(defaultRefreshTokenTTL)); - assertThat(updatedAuthorization.getTokens().getRefreshToken().getExpiresAt()).isAfter(Instant.now().plus(defaultRefreshTokenTTL)); - } - - @Test - public void authenticateWhenRefreshTokenDisabledReturnNullRefreshCode() { - RegisteredClient clientWithRefreshTokenDisabled = TestRegisteredClients - .registeredClient() + public void authenticateWhenRefreshTokenDisabledThenRefreshTokenNull() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .tokenSettings(tokenSettings -> tokenSettings.enableRefreshTokens(false)) .build(); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(clientWithRefreshTokenDisabled); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationCodeAuthenticationToken authentication = diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index b715219..b43c8d8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -13,29 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.authentication; -import java.time.Instant; -import java.util.HashSet; -import java.util.Set; -import java.util.UUID; - -import org.assertj.core.api.Assertions; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; - -import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.jose.JoseHeaderNames; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TokenType; @@ -43,59 +36,46 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.HashSet; +import java.util.Set; + import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** + * Tests for {@link OAuth2RefreshTokenAuthenticationProvider}. + * * @author Alexey Nesterov * @since 0.0.3 */ public class OAuth2RefreshTokenAuthenticationProviderTests { - - private final String NEW_ACCESS_TOKEN_VALUE = UUID.randomUUID().toString(); - private final String REFRESH_TOKEN_VALUE = UUID.randomUUID().toString(); - - private final RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - private final OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); - - private final OAuth2AccessToken existingAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, - "old-test-access-token", - Instant.now(), - Instant.now().plusSeconds(10), - this.registeredClient.getScopes()); - - private final OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(this.registeredClient) - .tokens(OAuth2Tokens.builder() - .accessToken(this.existingAccessToken) - .refreshToken(new OAuth2RefreshToken(REFRESH_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(60))) - .build()) - .build(); - private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; - private OAuth2RefreshTokenAuthenticationProvider provider; + private OAuth2RefreshTokenAuthenticationProvider authenticationProvider; @Before public void setUp() { this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); - this.provider = new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, this.jwtEncoder); - - Jwt jwt = Jwt.withTokenValue(NEW_ACCESS_TOKEN_VALUE) - .issuedAt(Instant.now()) - .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) - .build(); - + Jwt jwt = Jwt.withTokenValue("refreshed-access-token") + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(Instant.now()) + .expiresAt(Instant.now().plus(1, ChronoUnit.HOURS)) + .build(); when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt); + this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider( + this.authorizationService, this.jwtEncoder); } @Test - public void constructorWhenAuthorizationServiceNullThenThrowException() { + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(null, this.jwtEncoder)) .isInstanceOf(IllegalArgumentException.class) .extracting(Throwable::getMessage) @@ -103,7 +83,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { } @Test - public void constructorWhenJwtEncoderNullThenThrowException() { + public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, null)) .isInstanceOf(IllegalArgumentException.class) .extracting(Throwable::getMessage) @@ -112,140 +92,122 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { @Test public void supportsWhenSupportedAuthenticationThenTrue() { - assertThat(this.provider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue(); + assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue(); } @Test public void supportsWhenUnsupportedAuthenticationThenFalse() { - assertThat(this.provider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isFalse(); + assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isFalse(); } @Test - public void authenticateWhenRefreshTokenExistsThenReturnAuthentication() { - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(this.authorization); + public void authenticateWhenValidRefreshTokenThenReturnAccessToken() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); - OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = - (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(token); - - ArgumentCaptor claimsSetArgumentCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class); - verify(this.jwtEncoder).encode(any(), claimsSetArgumentCaptor.capture()); - - assertThat(claimsSetArgumentCaptor.getValue().getSubject()).isEqualTo(this.authorization.getPrincipalName()); - - assertThat(accessTokenAuthentication.getAccessToken()).isNotNull(); - assertThat(accessTokenAuthentication.getAccessToken().getTokenValue()).isEqualTo(NEW_ACCESS_TOKEN_VALUE); - assertThat(accessTokenAuthentication.getAccessToken().getScopes()).containsAll(this.existingAccessToken.getScopes()); - assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(this.clientPrincipal); - assertThat(accessTokenAuthentication.getRegisteredClient()).isEqualTo(this.registeredClient); - } - - @Test - public void authenticateWhenRefreshTokenExistsThenUpdatesAuthorization() { - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(this.authorization); - - OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); - this.provider.authenticate(token); + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); - assertThat(updatedAuthorization.getTokens().getAccessToken().getTokenValue()).isEqualTo(NEW_ACCESS_TOKEN_VALUE); + assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); + assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); + assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotEqualTo(authorization.getTokens().getAccessToken()); + assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); + // By default, refresh token is reused + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isEqualTo(authorization.getTokens().getRefreshToken()); } @Test - public void authenticateWhenClientSetToReuseRefreshTokensThenKeepsRefreshTokenValue() { - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(this.authorization); - - RegisteredClient clientWithReuseTokensTrue = TestRegisteredClients.registeredClient2() - .tokenSettings(tokenSettings -> tokenSettings.reuseRefreshTokens(true)) + public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .tokenSettings(tokenSettings -> tokenSettings.reuseRefreshTokens(false)) .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); - OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, new OAuth2ClientAuthenticationToken(clientWithReuseTokensTrue)); - OAuth2AccessTokenAuthenticationToken authentication = (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(token); - - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); - verify(this.authorizationService).save(authorizationCaptor.capture()); - OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - - assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotNull(); - assertThat(updatedAuthorization.getTokens().getRefreshToken()).isEqualTo(this.authorization.getTokens().getRefreshToken()); - assertThat(authentication.getRefreshToken()).isEqualTo(this.authorization.getTokens().getRefreshToken()); - } - - @Test - public void authenticateWhenClientSetToGenerateNewRefreshTokensThenGenerateNewToken() { - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(this.authorization); - - RegisteredClient clientWithReuseTokensFalse = TestRegisteredClients.registeredClient2() - .tokenSettings(tokenSettings -> tokenSettings.reuseRefreshTokens(false)) - .build(); - - OAuth2RefreshTokenAuthenticationToken token = - new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, new OAuth2ClientAuthenticationToken(clientWithReuseTokensFalse)); - - OAuth2AccessTokenAuthenticationToken authentication = (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(token); - - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); - verify(this.authorizationService).save(authorizationCaptor.capture()); - OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); - - assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotNull(); - assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotEqualTo(this.authorization.getTokens().getRefreshToken()); - assertThat(authentication.getRefreshToken()).isNotEqualTo(this.authorization.getTokens().getRefreshToken()); - } - - @Test - public void authenticateWhenRefreshTokenHasScopesThenIncludeScopes() { - Set requestedScopes = new HashSet<>(); - requestedScopes.add("email"); - requestedScopes.add("openid"); - - OAuth2RefreshTokenAuthenticationToken tokenWithScopes - = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal, requestedScopes); - - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(this.authorization); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = - (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(tokenWithScopes); + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThat(accessTokenAuthentication.getAccessToken()).isNotNull(); - assertThat(accessTokenAuthentication.getAccessToken().getScopes()).containsAll(requestedScopes); + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotEqualTo(authorization.getTokens().getRefreshToken()); } @Test - public void authenticateWhenRefreshTokenHasNotApprovedScopesThenThrowException() { - Set requestedScopes = new HashSet<>(); - requestedScopes.add("email"); - requestedScopes.add("another-scope"); + public void authenticateWhenRequestedScopesAuthorizedThenAccessTokenIncludesScopes() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); - OAuth2RefreshTokenAuthenticationToken tokenWithScopes - = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal, requestedScopes); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); + Set requestedScopes = new HashSet<>(authorizedScopes); + requestedScopes.remove("email"); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, requestedScopes); - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(this.authorization); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThatThrownBy(() -> this.provider.authenticate(tokenWithScopes)) + assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScopes); + } + + @Test + public void authenticateWhenRequestedScopesNotAuthorizedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); + Set requestedScopes = new HashSet<>(authorizedScopes); + requestedScopes.add("unauthorized"); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, requestedScopes); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) - .extracting((Throwable e) -> ((OAuth2AuthenticationException) e).getError()) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_SCOPE); } @Test - public void authenticateWhenRefreshTokenDoesNotExistThenThrowException() { - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(null); + public void authenticateWhenInvalidRefreshTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + "invalid", clientPrincipal); - OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); - assertThatThrownBy(() -> this.provider.authenticate(token)) + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") @@ -253,11 +215,14 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { } @Test - public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { - OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient.getClientId(), null); - OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, clientPrincipal); + public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken( + registeredClient.getClientId(), registeredClient.getClientSecret()); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + "refresh-token", clientPrincipal); - Assertions.assertThatThrownBy(() -> this.provider.authenticate(token)) + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") @@ -265,23 +230,83 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { } @Test - public void authenticateWhenRefreshTokenHasExpiredThenThrowException() { - OAuth2RefreshToken expiredRefreshToken = new OAuth2RefreshToken(REFRESH_TOKEN_VALUE, Instant.now().minusSeconds(120), Instant.now().minusSeconds(60)); - OAuth2Authorization authorizationWithExpiredRefreshToken = - OAuth2Authorization - .from(this.authorization) - .tokens(OAuth2Tokens.from(this.authorization.getTokens()).refreshToken(expiredRefreshToken).build()) - .build(); + public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), registeredClient.getClientSecret(), null); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + "refresh-token", clientPrincipal); - OAuth2RefreshTokenAuthenticationToken token - = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); - - when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) - .thenReturn(authorizationWithExpiredRefreshToken); - - assertThatThrownBy(() -> this.provider.authenticate(token)) + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) .isInstanceOf(OAuth2AuthenticationException.class) - .extracting((Throwable e) -> ((OAuth2AuthenticationException) e).getError()) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenRefreshTokenIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + TestRegisteredClients.registeredClient2().build()); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenClientNotAuthorizedToRefreshTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantTypes(grantTypes -> grantTypes.remove(AuthorizationGrantType.REFRESH_TOKEN)) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT); + } + + @Test + public void authenticateWhenExpiredRefreshTokenThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + OAuth2RefreshToken expiredRefreshToken = new OAuth2RefreshToken( + "expired-refresh-token", Instant.now().minusSeconds(120), Instant.now().minusSeconds(60)); + OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens()).refreshToken(expiredRefreshToken).build(); + authorization = OAuth2Authorization.from(authorization).tokens(tokens).build(); + when(this.authorizationService.findByToken( + eq(authorization.getTokens().getRefreshToken().getTokenValue()), + eq(TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java index c665748..2dafaf9 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java @@ -13,49 +13,60 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.server.authorization.authentication; +import org.junit.Test; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import org.junit.Test; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; /** + * Tests for {@link OAuth2RefreshTokenAuthenticationToken}. + * * @author Alexey Nesterov * @since 0.0.3 */ public class OAuth2RefreshTokenAuthenticationTokenTests { + private final OAuth2ClientAuthenticationToken clientPrincipal = + new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build()); @Test - public void constructorWhenClientPrincipalNullThrowException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("test", null)) + public void constructorWhenRefreshTokenNullOrEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("refreshToken cannot be empty"); + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("refreshToken cannot be empty"); + } + + @Test + public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientPrincipal cannot be null"); } @Test - public void constructorWhenRefreshTokenNullOrEmptyThrowException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, mock(OAuth2ClientAuthenticationToken.class))) + public void constructorWhenScopesNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", this.clientPrincipal, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("refreshToken cannot be empty"); - - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", mock(OAuth2ClientAuthenticationToken.class))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("refreshToken cannot be empty"); + .hasMessage("scopes cannot be null"); } @Test - public void constructorWhenGettingScopesThenReturnRequestedScopes() { + public void constructorWhenScopesProvidedThenCreated() { Set expectedScopes = new HashSet<>(Arrays.asList("scope-a", "scope-b")); - OAuth2RefreshTokenAuthenticationToken token - = new OAuth2RefreshTokenAuthenticationToken("test", mock(OAuth2ClientAuthenticationToken.class), expectedScopes); - - assertThat(token.getScopes()).containsAll(expectedScopes); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + "refresh-token", this.clientPrincipal, expectedScopes); + assertThat(authentication.getRefreshToken()).isEqualTo("refresh-token"); + assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getScopes()).isEqualTo(expectedScopes); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index e6421d4..33e0001 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -28,6 +28,7 @@ public class TestRegisteredClients { .clientId("client-1") .clientSecret("secret") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .redirectUri("https://example.com") .scope("openid") @@ -40,6 +41,7 @@ public class TestRegisteredClients { .clientId("client-2") .clientSecret("secret") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .redirectUri("https://example.com") diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java index 0076cbf..b604ab7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java @@ -35,7 +35,7 @@ public class TokenSettingsTests { assertThat(tokenSettings.settings()).hasSize(4); assertThat(tokenSettings.accessTokenTimeToLive()).isEqualTo(Duration.ofMinutes(5)); assertThat(tokenSettings.enableRefreshTokens()).isTrue(); - assertThat(tokenSettings.reuseRefreshTokens()).isEqualTo(true); + assertThat(tokenSettings.reuseRefreshTokens()).isTrue(); assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(Duration.ofMinutes(60)); } @@ -54,27 +54,44 @@ public class TokenSettingsTests { } @Test - public void enableRefreshTokenWhenFalseThenSet() { + public void accessTokenTimeToLiveWhenNullOrZeroOrNegativeThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new TokenSettings().accessTokenTimeToLive(null)) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("accessTokenTimeToLive cannot be null"); + + assertThatThrownBy(() -> new TokenSettings().accessTokenTimeToLive(Duration.ZERO)) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("accessTokenTimeToLive must be greater than Duration.ZERO"); + + assertThatThrownBy(() -> new TokenSettings().accessTokenTimeToLive(Duration.ofSeconds(-10))) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("accessTokenTimeToLive must be greater than Duration.ZERO"); + } + + @Test + public void enableRefreshTokensWhenFalseThenSet() { TokenSettings tokenSettings = new TokenSettings().enableRefreshTokens(false); assertThat(tokenSettings.enableRefreshTokens()).isFalse(); } @Test - public void reuseRefreshTokensWhenProvidedThenSet() { - boolean reuseRefreshTokens = true; - TokenSettings tokenSettings = new TokenSettings().reuseRefreshTokens(reuseRefreshTokens); - assertThat(tokenSettings.reuseRefreshTokens()).isEqualTo(reuseRefreshTokens); + public void reuseRefreshTokensWhenFalseThenSet() { + TokenSettings tokenSettings = new TokenSettings().reuseRefreshTokens(false); + assertThat(tokenSettings.reuseRefreshTokens()).isFalse(); } @Test public void refreshTokenTimeToLiveWhenProvidedThenSet() { - Duration refresTokenTimeToLive = Duration.ofDays(10); - TokenSettings tokenSettings = new TokenSettings().refreshTokenTimeToLive(refresTokenTimeToLive); - assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(refresTokenTimeToLive); + Duration refreshTokenTimeToLive = Duration.ofDays(10); + TokenSettings tokenSettings = new TokenSettings().refreshTokenTimeToLive(refreshTokenTimeToLive); + assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(refreshTokenTimeToLive); } @Test - public void refreshTokenTimeToLiveWhenZeroOrNegativeThenThrowException() { + public void refreshTokenTimeToLiveWhenNullOrZeroOrNegativeThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new TokenSettings().refreshTokenTimeToLive(null)) .isInstanceOf(IllegalArgumentException.class) .extracting(Throwable::getMessage) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 2f880ee..9235c0a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -250,15 +250,9 @@ public class OAuth2TokenEndpointFilterTests { } @Test - public void doFilterWhenTokenRequestMultipleScopeThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); - - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(clientPrincipal); - SecurityContextHolder.setContext(securityContext); - - MockHttpServletRequest request = createClientCredentialsTokenRequest(registeredClient); + public void doFilterWhenClientCredentialsTokenRequestMultipleScopeThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createClientCredentialsTokenRequest( + TestRegisteredClients.registeredClient2().build()); request.addParameter(OAuth2ParameterNames.SCOPE, "profile"); doFilterWhenTokenRequestInvalidParameterThenError( @@ -313,16 +307,45 @@ public class OAuth2TokenEndpointFilterTests { assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); } + @Test + public void doFilterWhenRefreshTokenRequestMissingRefreshTokenThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createRefreshTokenTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.removeParameter(OAuth2ParameterNames.REFRESH_TOKEN); + + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.REFRESH_TOKEN, OAuth2ErrorCodes.INVALID_REQUEST, request); + } + + @Test + public void doFilterWhenRefreshTokenRequestMultipleRefreshTokenThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createRefreshTokenTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, "refresh-token-2"); + + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.REFRESH_TOKEN, OAuth2ErrorCodes.INVALID_REQUEST, request); + } + + @Test + public void doFilterWhenRefreshTokenRequestMultipleScopeThenInvalidRequestError() throws Exception { + MockHttpServletRequest request = createRefreshTokenTokenRequest( + TestRegisteredClients.registeredClient().build()); + request.addParameter(OAuth2ParameterNames.SCOPE, "profile"); + + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.SCOPE, OAuth2ErrorCodes.INVALID_REQUEST, request); + } + @Test public void doFilterWhenRefreshTokenRequestValidThenAccessTokenResponse() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), Instant.now().plus(Duration.ofHours(1)), new HashSet<>(Arrays.asList("scope1", "scope2"))); - String refreshTokenValue = "refresh-token"; - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(refreshTokenValue, Instant.now()); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = new OAuth2AccessTokenAuthenticationToken( registeredClient, clientPrincipal, accessToken, refreshToken); @@ -333,7 +356,7 @@ public class OAuth2TokenEndpointFilterTests { securityContext.setAuthentication(clientPrincipal); SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest request = createRefreshTokenTokenRequest(registeredClient, refreshTokenValue, null); + MockHttpServletRequest request = createRefreshTokenTokenRequest(registeredClient); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -341,12 +364,13 @@ public class OAuth2TokenEndpointFilterTests { verifyNoInteractions(filterChain); - ArgumentCaptor argumentCaptor = + ArgumentCaptor refreshTokenAuthenticationCaptor = ArgumentCaptor.forClass(OAuth2RefreshTokenAuthenticationToken.class); - verify(this.authenticationManager).authenticate(argumentCaptor.capture()); + verify(this.authenticationManager).authenticate(refreshTokenAuthenticationCaptor.capture()); OAuth2RefreshTokenAuthenticationToken refreshTokenAuthenticationToken = - argumentCaptor.getValue(); + refreshTokenAuthenticationCaptor.getValue(); + assertThat(refreshTokenAuthenticationToken.getRefreshToken()).isEqualTo(refreshToken.getTokenValue()); assertThat(refreshTokenAuthenticationToken.getPrincipal()).isEqualTo(clientPrincipal); assertThat(refreshTokenAuthenticationToken.getScopes()).isEqualTo(registeredClient.getScopes()); @@ -361,6 +385,9 @@ public class OAuth2TokenEndpointFilterTests { assertThat(accessTokenResult.getExpiresAt()).isBetween( accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1)); assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); + + OAuth2RefreshToken refreshTokenResult = accessTokenResponse.getRefreshToken(); + assertThat(refreshTokenResult.getTokenValue()).isEqualTo(refreshToken.getTokenValue()); } private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode, @@ -419,19 +446,15 @@ public class OAuth2TokenEndpointFilterTests { return request; } - private static MockHttpServletRequest createRefreshTokenTokenRequest(RegisteredClient registeredClient, String refreshToken, String scope) { + private static MockHttpServletRequest createRefreshTokenTokenRequest(RegisteredClient registeredClient) { String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()); - request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, refreshToken); - if (scope == null) { - request.addParameter(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - } else { - request.addParameter(OAuth2ParameterNames.SCOPE, scope); - } + request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, "refresh-token"); + request.addParameter(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); return request; }