diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index 80fe069..564baf0 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -31,6 +31,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.web.JwkSetEndpointFilter; @@ -149,6 +150,10 @@ public final class OAuth2AuthorizationServerConfigurer authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); - - JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() - .issuer(issuer) - .subject(authorization.getPrincipalName()) - .audience(Collections.singletonList(registeredClient.getClientId())) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .notBefore(issuedAt) - .claim(OAuth2ParameterNames.SCOPE, authorizedScopes) - .build(); - - Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); - + Jwt jwt = OAuth2TokenIssuerUtil + .issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), authorizedScopes); OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), authorizedScopes); + OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens()) + .accessToken(accessToken); + + OAuth2RefreshToken refreshToken = null; + if (registeredClient.getTokenSettings().enableRefreshTokens()) { + refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(registeredClient.getTokenSettings().refreshTokenTimeToLive()); + tokensBuilder.refreshToken(refreshToken); + } + + OAuth2Tokens tokens = tokensBuilder.build(); authorization = OAuth2Authorization.from(authorization) - .tokens(OAuth2Tokens.from(authorization.getTokens()) - .accessToken(accessToken) - .build()) + .tokens(tokens) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .build(); @@ -167,7 +145,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica this.authorizationService.save(authorization); - return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken); + return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken, refreshToken); } @Override diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index f07cb46..9fe31a3 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -22,11 +22,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.jose.JoseHeader; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; @@ -36,12 +32,6 @@ import org.springframework.security.oauth2.server.authorization.token.OAuth2Toke import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.net.MalformedURLException; -import java.net.URI; -import java.net.URL; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Collections; import java.util.LinkedHashSet; import java.util.Set; import java.util.stream.Collectors; @@ -101,29 +91,8 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes()); } - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); - - // TODO Allow configuration for issuer claim - URL issuer = null; - try { - issuer = URI.create("https://oauth2.provider.com").toURL(); - } catch (MalformedURLException e) { } - - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live - - JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() - .issuer(issuer) - .subject(clientPrincipal.getName()) - .audience(Collections.singletonList(registeredClient.getClientId())) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .notBefore(issuedAt) - .claim(OAuth2ParameterNames.SCOPE, scopes) - .build(); - - Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); - + Jwt jwt = OAuth2TokenIssuerUtil + .issueJwtAccessToken(this.jwtEncoder, clientPrincipal.getName(), registeredClient.getClientId(), scopes); OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java new file mode 100644 index 0000000..7e34663 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -0,0 +1,136 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.time.Instant; +import java.util.Set; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.config.TokenSettings; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant. + * + * @author Alexey Nesterov + * @since 0.0.3 + * @see OAuth2RefreshTokenAuthenticationToken + * @see OAuth2AccessTokenAuthenticationToken + * @see OAuth2AuthorizationService + * @see JwtEncoder + * @see Section 1.5 Refresh Token + * @see Section 6 Refreshing an Access Token + */ + +public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider { + + private final OAuth2AuthorizationService authorizationService; + private final JwtEncoder jwtEncoder; + + public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); + + this.authorizationService = authorizationService; + this.jwtEncoder = jwtEncoder; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication = + (OAuth2RefreshTokenAuthenticationToken) authentication; + + OAuth2ClientAuthenticationToken clientPrincipal = null; + if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(refreshTokenAuthentication.getPrincipal().getClass())) { + clientPrincipal = (OAuth2ClientAuthenticationToken) refreshTokenAuthentication.getPrincipal(); + } + + if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + OAuth2Authorization authorization = this.authorizationService.findByToken(refreshTokenAuthentication.getRefreshToken(), TokenType.REFRESH_TOKEN); + if (authorization == null) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + Instant refreshTokenExpiration = authorization.getTokens().getRefreshToken().getExpiresAt(); + if (refreshTokenExpiration.isBefore(Instant.now())) { + // as per https://tools.ietf.org/html/rfc6749#section-5.2 + // invalid_grant: The provided authorization grant (e.g., authorization + // code, resource owner credentials) or refresh token is invalid, expired, revoked [...]. + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); + } + + RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); + + // https://tools.ietf.org/html/rfc6749#section-6 + // The requested scope MUST NOT include any scope not originally granted by the resource owner, + // and if omitted is treated as equal to the scope originally granted by the resource owner. + Set refreshTokenScopes = refreshTokenAuthentication.getScopes(); + Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); + if (!authorizedScopes.containsAll(refreshTokenScopes)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE)); + } + + if (refreshTokenScopes.isEmpty()) { + refreshTokenScopes = authorizedScopes; + } + + Jwt jwt = OAuth2TokenIssuerUtil + .issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), refreshTokenScopes); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), refreshTokenScopes); + + TokenSettings tokenSettings = registeredClient.getTokenSettings(); + OAuth2RefreshToken refreshToken; + if (tokenSettings.reuseRefreshTokens()) { + refreshToken = authorization.getTokens().getRefreshToken(); + } else { + refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(tokenSettings.refreshTokenTimeToLive()); + } + + authorization = OAuth2Authorization.from(authorization) + .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) + .tokens(OAuth2Tokens.builder().accessToken(accessToken).refreshToken(refreshToken).build()) + .build(); + + this.authorizationService.save(authorization); + + return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken, refreshToken); + } + + @Override + public boolean supports(Class authentication) { + return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication); + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java new file mode 100644 index 0000000..30770e3 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Collections; +import java.util.Set; + +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; + +/** + * An {@link Authentication} implementation used for the OAuth 2.0 Refresh Token Grant. + * + * @author Alexey Nesterov + * @since 0.0.3 + * @see AbstractAuthenticationToken + * @see OAuth2RefreshTokenAuthenticationProvider + * @see OAuth2ClientAuthenticationToken + */ +public class OAuth2RefreshTokenAuthenticationToken extends AbstractAuthenticationToken { + + private final Authentication clientPrincipal; + private final String refreshToken; + private final Set scopes; + + /** + * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters. + * + * @param refreshToken refresh token value + * @param clientPrincipal the authenticated client principal + */ + public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal) { + this(clientPrincipal, refreshToken, Collections.emptySet()); + } + + /** + * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters. + * + * @param clientPrincipal the authenticated client principal + * @param refreshToken refresh token value + * @param requestedScopes scopes requested by refresh token + */ + public OAuth2RefreshTokenAuthenticationToken(Authentication clientPrincipal, String refreshToken, Set requestedScopes) { + super(Collections.emptySet()); + + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); + Assert.hasText(refreshToken, "refreshToken cannot be null or empty"); + + this.clientPrincipal = clientPrincipal; + this.refreshToken = refreshToken; + this.scopes = requestedScopes; + } + + @Override + public Object getCredentials() { + return ""; + } + + @Override + public Object getPrincipal() { + return this.clientPrincipal; + } + + public String getRefreshToken() { + return this.refreshToken; + } + + /** + * Returns the requested scope(s). + * + * @return the requested scope(s), or an empty {@code Set} if not available + */ + public Set getScopes() { + return this.scopes; + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java new file mode 100644 index 0000000..b473c14 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.Collections; +import java.util.Set; + +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.JoseHeader; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; + +/** + * @author Alexey Nesterov + * @since 0.0.3 + */ +class OAuth2TokenIssuerUtil { + + private static final StringKeyGenerator CODE_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + + static Jwt issueJwtAccessToken(JwtEncoder jwtEncoder, String subject, String audience, Set scopes) { + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + + // TODO Allow configuration for issuer claim + URL issuer = null; + try { + issuer = URI.create("https://oauth2.provider.com").toURL(); + } catch (MalformedURLException e) { } + + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live + + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() + .issuer(issuer) + .subject(subject) + .audience(Collections.singletonList(audience)) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .notBefore(issuedAt) + .claim(OAuth2ParameterNames.SCOPE, scopes) + .build(); + + return jwtEncoder.encode(joseHeader, jwtClaimsSet); + } + + static OAuth2RefreshToken issueRefreshToken(Duration refreshTokenTimeToLive) { + Instant issuedAt = Instant.now(); + Instant refreshTokenExpiresAt = issuedAt.plus(refreshTokenTimeToLive); + + return new OAuth2RefreshToken(CODE_GENERATOR.generateKey(), issuedAt, refreshTokenExpiresAt); + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java index c2fdf43..ad69a73 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java @@ -19,6 +19,8 @@ import java.time.Duration; import java.util.HashMap; import java.util.Map; +import org.springframework.util.Assert; + /** * A facility for token configuration settings. * @@ -29,6 +31,9 @@ import java.util.Map; public class TokenSettings extends Settings { private static final String TOKEN_SETTING_BASE = "spring.security.oauth2.authorization-server.token."; public static final String ACCESS_TOKEN_TIME_TO_LIVE = TOKEN_SETTING_BASE.concat("access-token-time-to-live"); + public static final String ENABLE_REFRESH_TOKENS = TOKEN_SETTING_BASE.concat("enable-refresh-tokens"); + public static final String REUSE_REFRESH_TOKENS = TOKEN_SETTING_BASE.concat("reuse-refresh-tokens"); + public static final String REFRESH_TOKEN_TIME_TO_LIVE = TOKEN_SETTING_BASE.concat("refresh-token-time-to-live"); /** * Constructs a {@code TokenSettings}. @@ -66,9 +71,75 @@ public class TokenSettings extends Settings { return this; } + /** + * Returns {@code true} if refresh tokens support is enabled. + * This include generation of refresh token as a part of Authorization Code Grant flow and support of Refresh Token + * Grant flow. The default is {@code true}. + * + * @return {@code true} if the client support refresh token, {@code false} otherwise + */ + public boolean enableRefreshTokens() { + return setting(ENABLE_REFRESH_TOKENS); + } + + /** + * Set to {@code true} to enable refresh tokens support. + * This include generation of refresh token as a part of Authorization Code Grant flow and support of Refresh Token + * Grant flow. + * + * @param enableRefreshTokens {@code true} to enable refresh token grant support, {@code false} otherwise + * @return the {@link TokenSettings} + */ + public TokenSettings enableRefreshTokens(boolean enableRefreshTokens) { + setting(ENABLE_REFRESH_TOKENS, enableRefreshTokens); + return this; + } + + /** + * Returns {@code true} if existing refresh token is re-used when a new access token is requested via Refresh Token grant, + * or {@code false} if a new refresh token is generated. + * The default is {@code false}. + */ + public boolean reuseRefreshTokens() { + return setting(REUSE_REFRESH_TOKENS); + } + + /** + * Set to {@code true} to re-use existing refresh token when new access token is requested via Refresh Token grant, + * or to {@code false} to generate a new refresh token. + * @param reuseRefreshTokens {@code true} to re-use existing refresh token, {@code false} to generate a new one + */ + public TokenSettings reuseRefreshTokens(boolean reuseRefreshTokens) { + setting(REUSE_REFRESH_TOKENS, reuseRefreshTokens); + return this; + } + + /** + * Returns refresh token time-to-live. The default is 60 minutes. Always greater than {@code Duration.ZERO}. + * @return refresh token time-to-live + */ + public Duration refreshTokenTimeToLive() { + return setting(REFRESH_TOKEN_TIME_TO_LIVE); + } + + /** + * Sets refresh token time-to-live. + * @param refreshTokenTimeToLive refresh token time-to-live. Has to be greater than {@code Duration.ZERO}. + */ + public TokenSettings refreshTokenTimeToLive(Duration refreshTokenTimeToLive) { + Assert.notNull(refreshTokenTimeToLive, "refreshTokenTimeToLive cannot be null"); + Assert.isTrue(refreshTokenTimeToLive.getSeconds() > 0, "refreshTokenTimeToLive has to be greater than Duration.ZERO"); + + setting(REFRESH_TOKEN_TIME_TO_LIVE, refreshTokenTimeToLive); + return this; + } + protected static Map defaultSettings() { Map settings = new HashMap<>(); settings.put(ACCESS_TOKEN_TIME_TO_LIVE, Duration.ofMinutes(5)); + settings.put(ENABLE_REFRESH_TOKENS, true); + settings.put(REUSE_REFRESH_TOKENS, false); + settings.put(REFRESH_TOKEN_TIME_TO_LIVE, Duration.ofMinutes(60)); return settings; } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 568991f..ed9d8ee 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; @@ -36,6 +37,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -129,6 +131,7 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { Map> converters = new HashMap<>(); converters.put(AuthorizationGrantType.AUTHORIZATION_CODE, new AuthorizationCodeAuthenticationConverter()); converters.put(AuthorizationGrantType.CLIENT_CREDENTIALS, new ClientCredentialsAuthenticationConverter()); + converters.put(AuthorizationGrantType.REFRESH_TOKEN, new RefreshTokenAuthenticationConverter()); this.authorizationGrantAuthenticationConverter = new DelegatingAuthorizationGrantAuthenticationConverter(converters); } @@ -154,7 +157,7 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication); - sendAccessTokenResponse(response, accessTokenAuthentication.getAccessToken()); + sendAccessTokenResponse(response, accessTokenAuthentication.getAccessToken(), accessTokenAuthentication.getRefreshToken()); } catch (OAuth2AuthenticationException ex) { SecurityContextHolder.clearContext(); @@ -162,7 +165,7 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { } } - private void sendAccessTokenResponse(HttpServletResponse response, OAuth2AccessToken accessToken) throws IOException { + private void sendAccessTokenResponse(HttpServletResponse response, OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken) throws IOException { OAuth2AccessTokenResponse.Builder builder = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) .tokenType(accessToken.getTokenType()) @@ -170,6 +173,9 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) { builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt())); } + if (refreshToken != null) { + builder.refreshToken(refreshToken.getTokenValue()); + } OAuth2AccessTokenResponse accessTokenResponse = builder.build(); ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse); @@ -258,4 +264,41 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); } } + + private static class RefreshTokenAuthenticationConverter implements Converter { + + @Override + public Authentication convert(HttpServletRequest request) { + // grant_type (REQUIRED) + String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) { + return null; + } + + Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + + // refresh token (REQUIRED) + String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN); + if (StringUtils.hasText(refreshToken) && + parameters.get(OAuth2ParameterNames.REFRESH_TOKEN).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REFRESH_TOKEN); + } + + // scope (OPTIONAL) + String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope) && + parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE); + } + if (StringUtils.hasText(scope)) { + Set requestedScopes = new HashSet<>( + Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); + return new OAuth2RefreshTokenAuthenticationToken(clientPrincipal, refreshToken, requestedScopes); + } + + return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal); + } + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 19711a8..e212b1d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -217,6 +217,46 @@ public class OAuth2AuthorizationCodeGrantTests { verify(authorizationService, times(2)).save(any()); } + @Test + public void requestWhenPublicClientWithRefreshThenReturnRefreshToken() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients + .registeredClient() + .clientSecret(null) + .tokenSettings(tokenSettings -> tokenSettings.enableRefreshTokens(true)) + .build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .params(getAuthorizationRequestParameters(registeredClient)) + .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .with(user("user"))) + .andExpect(status().is3xxRedirection()) + .andReturn(); + assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); + + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization authorization = authorizationCaptor.getValue(); + + when(authorizationService.findByToken( + eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), + eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorization)) + .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) + .param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.refresh_token").isNotEmpty()); + } + private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java new file mode 100644 index 0000000..2eb30a1 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; + +import java.time.Instant; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; +import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; +import org.springframework.test.web.servlet.MockMvc; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * @author Alexey Nesterov + * @since 0.0.3 + */ +public class OAuth2RefreshTokenGrantTests { + + private static final String TEST_REFRESH_TOKEN = "test-refresh-token"; + + private static RegisteredClientRepository registeredClientRepository; + private static OAuth2AuthorizationService authorizationService; + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Autowired + private MockMvc mvc; + + private RegisteredClient registeredClient; + + @BeforeClass + public static void init() { + registeredClientRepository = mock(RegisteredClientRepository.class); + authorizationService = mock(OAuth2AuthorizationService.class); + } + + @Before + public void setup() { + reset(registeredClientRepository); + reset(authorizationService); + + this.registeredClient = TestRegisteredClients.registeredClient2().build(); + + this.spring.register(OAuth2RefreshTokenGrantTests.AuthorizationServerConfiguration.class).autowire(); + } + + @Test + public void requestWhenRefreshTokenExists() throws Exception { + when(registeredClientRepository.findByClientId(eq(this.registeredClient.getClientId()))) + .thenReturn(this.registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(this.registeredClient) + .tokens(OAuth2Tokens.builder() + .refreshToken(new OAuth2RefreshToken(TEST_REFRESH_TOKEN, Instant.now(), Instant.now().plusSeconds(60))) + .accessToken(new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(10))) + .build()) + .build(); + + when(authorizationService.findByToken(TEST_REFRESH_TOKEN, TokenType.REFRESH_TOKEN)) + .thenReturn(authorization); + + this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()) + .param(OAuth2ParameterNames.REFRESH_TOKEN, TEST_REFRESH_TOKEN) + .with(httpBasic(this.registeredClient.getClientId(), this.registeredClient.getClientSecret()))) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.access_token").isNotEmpty()); + } + + @EnableWebSecurity + @Import(OAuth2AuthorizationServerConfiguration.class) + static class AuthorizationServerConfiguration { + + @Bean + RegisteredClientRepository registeredClientRepository() { + return registeredClientRepository; + } + + @Bean + OAuth2AuthorizationService authorizationService() { + return authorizationService; + } + + @Bean + KeyManager keyManager() { return new StaticKeyGeneratingKeyManager(); } + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index 0984450..c5b4aa3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization; import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; @@ -141,6 +142,20 @@ public class InMemoryOAuth2AuthorizationServiceTests { assertThat(authorization).isEqualTo(result); } + @Test + public void findByTokenAndTokenTypeWhenTokenTypeRefreshTokenThenFound() { + final String refreshTokenValue = "refresh-token"; + OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .principalName(PRINCIPAL_NAME) + .tokens(OAuth2Tokens.builder().refreshToken(new OAuth2RefreshToken(refreshTokenValue, Instant.now().plusSeconds(10))).build()) + .build(); + this.authorizationService.save(expectedAuthorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + refreshTokenValue, TokenType.REFRESH_TOKEN); + assertThat(result).isEqualTo(expectedAuthorization); + } + @Test public void findByTokenWhenTokenDoesNotExistThenNull() { OAuth2Authorization result = this.authorizationService.findByToken( diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index 717dec8..ff1733c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization; import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; @@ -39,6 +40,7 @@ public class OAuth2AuthorizationTests { private static final String PRINCIPAL_NAME = "principal"; private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); + private static final OAuth2RefreshToken REFRESH_TOKEN = new OAuth2RefreshToken("refresh-token", Instant.now()); private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode( "code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES)); @@ -101,12 +103,13 @@ public class OAuth2AuthorizationTests { public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build()) + .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).refreshToken(REFRESH_TOKEN).build()) .build(); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); assertThat(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isEqualTo(AUTHORIZATION_CODE); assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN); + assertThat(authorization.getTokens().getRefreshToken()).isEqualTo(REFRESH_TOKEN); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index c9cd9ce..805d025 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -15,9 +15,14 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; @@ -37,12 +42,11 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.TokenSettings; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; -import java.time.Instant; -import java.time.temporal.ChronoUnit; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; @@ -60,6 +64,7 @@ import static org.mockito.Mockito.when; * @author Daniel Garnier-Moiroux */ public class OAuth2AuthorizationCodeAuthenticationProviderTests { + private static final String AUTHORIZATION_CODE = "code"; private RegisteredClient registeredClient; private RegisteredClientRepository registeredClientRepository; @@ -230,6 +235,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { Set scopes = jwtClaimsSet.getClaim(OAuth2ParameterNames.SCOPE); assertThat(scopes).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES)); + assertThat(jwtClaimsSet.getSubject()).isEqualTo(authorization.getPrincipalName()); ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture()); @@ -242,6 +248,87 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); } + @Test + public void authenticateWhenValidCodeThenReturnRefreshToken() { + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); + + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotNull(); + } + + @Test + public void authenticateWhenTokenSettingsHasTimeToLiveThenRefreshTokenHasExpiration() { + Duration testRefreshTokenTTL = Duration.ofDays(1); + Duration defaultRefreshTokenTTL = new TokenSettings().refreshTokenTimeToLive(); + RegisteredClient clientWithRefreshTokenTTLZero = TestRegisteredClients.registeredClient() + .tokenSettings(tokenSettings -> tokenSettings.refreshTokenTimeToLive(testRefreshTokenTTL)) + .build(); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(clientWithRefreshTokenTTLZero); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); + + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isAfter(Instant.now().plus(defaultRefreshTokenTTL)); + assertThat(updatedAuthorization.getTokens().getRefreshToken().getExpiresAt()).isAfter(Instant.now().plus(defaultRefreshTokenTTL)); + } + + @Test + public void authenticateWhenRefreshTokenDisabledReturnNullRefreshCode() { + RegisteredClient clientWithRefreshTokenDisabled = TestRegisteredClients + .registeredClient() + .tokenSettings(tokenSettings -> tokenSettings.enableRefreshTokens(false)) + .build(); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(clientWithRefreshTokenDisabled); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); + + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + assertThat(accessTokenAuthentication.getRefreshToken()).isNull(); + } + private static Jwt createJwt() { Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java new file mode 100644 index 0000000..9b1b8fc --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -0,0 +1,288 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.time.Instant; +import java.util.HashSet; +import java.util.Set; +import java.util.UUID; + +import org.assertj.core.api.Assertions; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.jose.JoseHeaderNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author Alexey Nesterov + * @since 0.0.3 + */ +public class OAuth2RefreshTokenAuthenticationProviderTests { + + private final String NEW_ACCESS_TOKEN_VALUE = UUID.randomUUID().toString(); + private final String REFRESH_TOKEN_VALUE = UUID.randomUUID().toString(); + + private final RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + private final OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); + + private final OAuth2AccessToken existingAccessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, + "old-test-access-token", + Instant.now(), + Instant.now().plusSeconds(10), + this.registeredClient.getScopes()); + + private final OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(this.registeredClient) + .tokens(OAuth2Tokens.builder() + .accessToken(this.existingAccessToken) + .refreshToken(new OAuth2RefreshToken(REFRESH_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(60))) + .build()) + .build(); + + private OAuth2AuthorizationService authorizationService; + private JwtEncoder jwtEncoder; + private OAuth2RefreshTokenAuthenticationProvider provider; + + @Before + public void setUp() { + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.jwtEncoder = mock(JwtEncoder.class); + this.provider = new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, this.jwtEncoder); + + Jwt jwt = Jwt.withTokenValue(NEW_ACCESS_TOKEN_VALUE) + .issuedAt(Instant.now()) + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .build(); + + when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(null, this.jwtEncoder)) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("authorizationService cannot be null"); + } + + @Test + public void constructorWhenJwtEncoderNullThenThrowException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, null)) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("jwtEncoder cannot be null"); + } + + @Test + public void supportsWhenSupportedAuthenticationThenTrue() { + assertThat(this.provider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue(); + } + + @Test + public void supportsWhenUnsupportedAuthenticationThenFalse() { + assertThat(this.provider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isFalse(); + } + + @Test + public void authenticateWhenRefreshTokenExistsThenReturnAuthentication() { + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(this.authorization); + + OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(token); + + ArgumentCaptor claimsSetArgumentCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class); + verify(this.jwtEncoder).encode(any(), claimsSetArgumentCaptor.capture()); + + assertThat(claimsSetArgumentCaptor.getValue().getSubject()).isEqualTo(this.authorization.getPrincipalName()); + + assertThat(accessTokenAuthentication.getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken().getTokenValue()).isEqualTo(NEW_ACCESS_TOKEN_VALUE); + assertThat(accessTokenAuthentication.getAccessToken().getScopes()).containsAll(this.existingAccessToken.getScopes()); + assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(this.clientPrincipal); + assertThat(accessTokenAuthentication.getRegisteredClient()).isEqualTo(this.registeredClient); + } + + @Test + public void authenticateWhenRefreshTokenExistsThenUpdatesAuthorization() { + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(this.authorization); + + OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); + this.provider.authenticate(token); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); + assertThat(updatedAuthorization.getTokens().getAccessToken().getTokenValue()).isEqualTo(NEW_ACCESS_TOKEN_VALUE); + } + + @Test + public void authenticateWhenClientSetToReuseRefreshTokensThenKeepsRefreshTokenValue() { + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(this.authorization); + + RegisteredClient clientWithReuseTokensTrue = TestRegisteredClients.registeredClient() + .tokenSettings(tokenSettings -> tokenSettings.reuseRefreshTokens(true)) + .build(); + + OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, new OAuth2ClientAuthenticationToken(clientWithReuseTokensTrue)); + OAuth2AccessTokenAuthenticationToken authentication = (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(token); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotNull(); + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isEqualTo(this.authorization.getTokens().getRefreshToken()); + assertThat(authentication.getRefreshToken()).isEqualTo(this.authorization.getTokens().getRefreshToken()); + } + + @Test + public void authenticateWhenClientSetToGenerateNewRefreshTokensThenGenerateNewToken() { + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(this.authorization); + + RegisteredClient clientWithReuseTokensFalse = TestRegisteredClients.registeredClient() + .tokenSettings(tokenSettings -> tokenSettings.reuseRefreshTokens(false)) + .build(); + + OAuth2RefreshTokenAuthenticationToken token = + new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, new OAuth2ClientAuthenticationToken(clientWithReuseTokensFalse)); + + OAuth2AccessTokenAuthenticationToken authentication = (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(token); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotNull(); + assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotEqualTo(this.authorization.getTokens().getRefreshToken()); + assertThat(authentication.getRefreshToken()).isNotEqualTo(this.authorization.getTokens().getRefreshToken()); + } + + @Test + public void authenticateWhenRefreshTokenHasScopesThenIncludeScopes() { + Set requestedScopes = new HashSet<>(); + requestedScopes.add("email"); + requestedScopes.add("openid"); + + OAuth2RefreshTokenAuthenticationToken tokenWithScopes + = new OAuth2RefreshTokenAuthenticationToken(this.clientPrincipal, REFRESH_TOKEN_VALUE, requestedScopes); + + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(this.authorization); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.provider.authenticate(tokenWithScopes); + + assertThat(accessTokenAuthentication.getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken().getScopes()).containsAll(requestedScopes); + } + + @Test + public void authenticateWhenRefreshTokenHasNotApprovedScopesThenThrowException() { + Set requestedScopes = new HashSet<>(); + requestedScopes.add("email"); + requestedScopes.add("another-scope"); + + OAuth2RefreshTokenAuthenticationToken tokenWithScopes + = new OAuth2RefreshTokenAuthenticationToken(this.clientPrincipal, REFRESH_TOKEN_VALUE, requestedScopes); + + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(this.authorization); + + assertThatThrownBy(() -> this.provider.authenticate(tokenWithScopes)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting((Throwable e) -> ((OAuth2AuthenticationException) e).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_SCOPE); + } + + @Test + public void authenticateWhenRefreshTokenDoesNotExistThenThrowException() { + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(null); + + OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); + assertThatThrownBy(() -> this.provider.authenticate(token)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } + + @Test + public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient.getClientId(), null); + OAuth2RefreshTokenAuthenticationToken token = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, clientPrincipal); + + Assertions.assertThatThrownBy(() -> this.provider.authenticate(token)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void authenticateWhenRefreshTokenHasExpiredThenThrowException() { + OAuth2RefreshToken expiredRefreshToken = new OAuth2RefreshToken(REFRESH_TOKEN_VALUE, Instant.now().minusSeconds(120), Instant.now().minusSeconds(60)); + OAuth2Authorization authorizationWithExpiredRefreshToken = + OAuth2Authorization + .from(this.authorization) + .tokens(OAuth2Tokens.from(this.authorization.getTokens()).refreshToken(expiredRefreshToken).build()) + .build(); + + OAuth2RefreshTokenAuthenticationToken token + = new OAuth2RefreshTokenAuthenticationToken(REFRESH_TOKEN_VALUE, this.clientPrincipal); + + when(this.authorizationService.findByToken(REFRESH_TOKEN_VALUE, TokenType.REFRESH_TOKEN)) + .thenReturn(authorizationWithExpiredRefreshToken); + + assertThatThrownBy(() -> this.provider.authenticate(token)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting((Throwable e) -> ((OAuth2AuthenticationException) e).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java new file mode 100644 index 0000000..e8912b4 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * @author Alexey Nesterov + * @since 0.0.3 + */ +public class OAuth2RefreshTokenAuthenticationTokenTests { + + @Test + public void constructorWhenClientPrincipalNullThrowException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientPrincipal cannot be null"); + } + + @Test + public void constructorWhenRefreshTokenNullOrEmptyThrowException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, mock(OAuth2ClientAuthenticationToken.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("refreshToken cannot be null or empty"); + + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", mock(OAuth2ClientAuthenticationToken.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("refreshToken cannot be null or empty"); + } + + @Test + public void constructorWhenGettingScopesThenReturnRequestedScopes() { + Set expectedScopes = new HashSet<>(Arrays.asList("scope-a", "scope-b")); + OAuth2RefreshTokenAuthenticationToken token + = new OAuth2RefreshTokenAuthenticationToken(mock(OAuth2ClientAuthenticationToken.class), "test", expectedScopes); + + assertThat(token.getScopes()).containsAll(expectedScopes); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java index cce5a3b..5528b29 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java @@ -32,8 +32,11 @@ public class TokenSettingsTests { @Test public void constructorWhenDefaultThenDefaultsAreSet() { TokenSettings tokenSettings = new TokenSettings(); - assertThat(tokenSettings.settings()).hasSize(1); + assertThat(tokenSettings.settings()).hasSize(4); assertThat(tokenSettings.accessTokenTimeToLive()).isEqualTo(Duration.ofMinutes(5)); + assertThat(tokenSettings.enableRefreshTokens()).isTrue(); + assertThat(tokenSettings.reuseRefreshTokens()).isEqualTo(false); + assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(Duration.ofMinutes(60)); } @Test @@ -50,6 +53,44 @@ public class TokenSettingsTests { assertThat(tokenSettings.accessTokenTimeToLive()).isEqualTo(accessTokenTimeToLive); } + @Test + public void enableRefreshTokenWhenFalseThenSet() { + TokenSettings tokenSettings = new TokenSettings().enableRefreshTokens(false); + assertThat(tokenSettings.enableRefreshTokens()).isFalse(); + } + + @Test + public void reuseRefreshTokensWhenProvidedThenSet() { + boolean reuseRefreshTokens = true; + TokenSettings tokenSettings = new TokenSettings().reuseRefreshTokens(reuseRefreshTokens); + assertThat(tokenSettings.reuseRefreshTokens()).isEqualTo(reuseRefreshTokens); + } + + @Test + public void refreshTokenTimeToLiveWhenProvidedThenSet() { + Duration refresTokenTimeToLive = Duration.ofDays(10); + TokenSettings tokenSettings = new TokenSettings().refreshTokenTimeToLive(refresTokenTimeToLive); + assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(refresTokenTimeToLive); + } + + @Test + public void refreshTokenTimeToLiveWhenZeroOrNegativeThenThrowException() { + assertThatThrownBy(() -> new TokenSettings().refreshTokenTimeToLive(null)) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("refreshTokenTimeToLive cannot be null"); + + assertThatThrownBy(() -> new TokenSettings().refreshTokenTimeToLive(Duration.ZERO)) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("refreshTokenTimeToLive has to be greater than Duration.ZERO"); + + assertThatThrownBy(() -> new TokenSettings().refreshTokenTimeToLive(Duration.ofSeconds(-10))) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("refreshTokenTimeToLive has to be greater than Duration.ZERO"); + } + @Test public void settingWhenCalledThenReturnTokenSettings() { Duration accessTokenTimeToLive = Duration.ofMinutes(10); @@ -57,8 +98,11 @@ public class TokenSettingsTests { .setting("name1", "value1") .accessTokenTimeToLive(accessTokenTimeToLive) .settings(settings -> settings.put("name2", "value2")); - assertThat(tokenSettings.settings()).hasSize(3); + assertThat(tokenSettings.settings()).hasSize(6); assertThat(tokenSettings.accessTokenTimeToLive()).isEqualTo(accessTokenTimeToLive); + assertThat(tokenSettings.enableRefreshTokens()).isTrue(); + assertThat(tokenSettings.reuseRefreshTokens()).isFalse(); + assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(Duration.ofMinutes(60)); assertThat(tokenSettings.setting("name1")).isEqualTo("value1"); assertThat(tokenSettings.setting("name2")).isEqualTo("value2"); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 1a8c26e..2f880ee 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -32,6 +32,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; @@ -41,6 +42,7 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.util.StringUtils; @@ -311,6 +313,56 @@ public class OAuth2TokenEndpointFilterTests { assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); } + @Test + public void doFilterWhenRefreshTokenRequestValidThenAccessTokenResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofHours(1)), + new HashSet<>(Arrays.asList("scope1", "scope2"))); + String refreshTokenValue = "refresh-token"; + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(refreshTokenValue, Instant.now()); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + new OAuth2AccessTokenAuthenticationToken( + registeredClient, clientPrincipal, accessToken, refreshToken); + + when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletRequest request = createRefreshTokenTokenRequest(registeredClient, refreshTokenValue, null); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + ArgumentCaptor argumentCaptor = + ArgumentCaptor.forClass(OAuth2RefreshTokenAuthenticationToken.class); + verify(this.authenticationManager).authenticate(argumentCaptor.capture()); + + OAuth2RefreshTokenAuthenticationToken refreshTokenAuthenticationToken = + argumentCaptor.getValue(); + assertThat(refreshTokenAuthenticationToken.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(refreshTokenAuthenticationToken.getScopes()).isEqualTo(registeredClient.getScopes()); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); + + OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken(); + assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType()); + assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue()); + assertThat(accessTokenResult.getIssuedAt()).isBetween( + accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1)); + assertThat(accessTokenResult.getExpiresAt()).isBetween( + accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1)); + assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); + } + private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode, MockHttpServletRequest request) throws Exception { @@ -366,4 +418,21 @@ public class OAuth2TokenEndpointFilterTests { return request; } + + private static MockHttpServletRequest createRefreshTokenTokenRequest(RegisteredClient registeredClient, String refreshToken, String scope) { + String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()); + request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, refreshToken); + if (scope == null) { + request.addParameter(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + } else { + request.addParameter(OAuth2ParameterNames.SCOPE, scope); + } + + return request; + } } diff --git a/samples/boot/oauth2-integration/client/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2-integration/client/src/main/java/sample/config/WebClientConfig.java index 06f86aa..443986f 100644 --- a/samples/boot/oauth2-integration/client/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2-integration/client/src/main/java/sample/config/WebClientConfig.java @@ -51,6 +51,7 @@ public class WebClientConfig { OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .clientCredentials() + .refreshToken() .build(); DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( clientRegistrationRepository, authorizedClientRepository);