From adf96b4e25808fcb51303309648bb03db08ebb8f Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 18 Jan 2021 09:31:06 -0500 Subject: [PATCH] Add OAuth2TokenCustomizer Closes gh-199 --- .../OAuth2AuthorizationServerConfigurer.java | 35 +++++ .../security/oauth2/core/context/Context.java | 47 +++++++ .../oauth2/core/context/DefaultContext.java | 50 +++++++ .../security/oauth2/jwt/NimbusJwsEncoder.java | 51 ++------ .../OAuth2AuthorizationAttributeNames.java | 7 +- .../JwtEncodingContextUtils.java | 123 ++++++++++++++++++ ...thorizationCodeAuthenticationProvider.java | 63 ++++++--- ...ientCredentialsAuthenticationProvider.java | 41 ++++-- ...th2RefreshTokenAuthenticationProvider.java | 53 ++++++-- .../authentication/OAuth2TokenIssuerUtil.java | 97 -------------- .../token/JwtEncodingContext.java | 87 +++++++++++++ .../token/OAuth2TokenContext.java | 119 +++++++++++++++++ .../token/OAuth2TokenCustomizer.java | 28 ++++ .../OAuth2AuthorizationEndpointFilter.java | 34 ++--- .../OAuth2AuthorizationCodeGrantTests.java | 67 ++++++++-- .../OAuth2ClientCredentialsGrantTests.java | 13 ++ .../OAuth2RefreshTokenGrantTests.java | 56 +++++++- .../server/authorization/OidcTests.java | 63 ++++++++- .../oauth2/jwt/NimbusJwsEncoderTests.java | 24 ---- .../TestOAuth2Authorizations.java | 15 ++- ...zationCodeAuthenticationProviderTests.java | 61 ++++++++- ...redentialsAuthenticationProviderTests.java | 69 +++++++--- ...freshTokenAuthenticationProviderTests.java | 59 +++++++-- .../token/JwtEncodingContextTests.java | 118 +++++++++++++++++ ...Auth2AuthorizationEndpointFilterTests.java | 25 ++-- 25 files changed, 1126 insertions(+), 279 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/Context.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/DefaultContext.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtEncodingContextUtils.java delete mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContext.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenContext.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenCustomizer.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContextTests.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index dedb682..9627d33 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -33,7 +33,9 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -152,23 +154,33 @@ public final class OAuth2AuthorizationServerConfigurer jwtCustomizer = getJwtCustomizer(builder); OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( getAuthorizationService(builder), jwtEncoder); + if (jwtCustomizer != null) { + authorizationCodeAuthenticationProvider.setJwtCustomizer(jwtCustomizer); + } builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider)); OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider = new OAuth2RefreshTokenAuthenticationProvider( getAuthorizationService(builder), jwtEncoder); + if (jwtCustomizer != null) { + refreshTokenAuthenticationProvider.setJwtCustomizer(jwtCustomizer); + } builder.authenticationProvider(postProcess(refreshTokenAuthenticationProvider)); OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( getAuthorizationService(builder), jwtEncoder); + if (jwtCustomizer != null) { + clientCredentialsAuthenticationProvider.setJwtCustomizer(jwtCustomizer); + } builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider)); OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider = @@ -314,6 +326,19 @@ public final class OAuth2AuthorizationServerConfigurer> OAuth2TokenCustomizer getJwtCustomizer(B builder) { + OAuth2TokenCustomizer jwtCustomizer = builder.getSharedObject(OAuth2TokenCustomizer.class); + if (jwtCustomizer == null) { + ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class); + jwtCustomizer = getOptionalBean(builder, type); + if (jwtCustomizer != null) { + builder.setSharedObject(OAuth2TokenCustomizer.class, jwtCustomizer); + } + } + return jwtCustomizer; + } + private static > ProviderSettings getProviderSettings(B builder) { ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class); if (providerSettings == null) { @@ -353,4 +378,14 @@ public final class OAuth2AuthorizationServerConfigurer, T> T getOptionalBean(B builder, ResolvableType type) { + ApplicationContext context = builder.getSharedObject(ApplicationContext.class); + String[] names = context.getBeanNamesForType(type); + if (names.length > 1) { + throw new NoUniqueBeanDefinitionException(type, names); + } + return names.length == 1 ? (T) context.getBean(names[0]) : null; + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/Context.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/Context.java new file mode 100644 index 0000000..9ae5e82 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/Context.java @@ -0,0 +1,47 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.core.context; + +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A facility for holding information associated to a specific context. + * + * @author Joe Grandja + * @since 0.1.0 + */ +public interface Context { + + @Nullable + V get(Object key); + + @Nullable + default V get(Class key) { + Assert.notNull(key, "key cannot be null"); + V value = get((Object) key); + return key.isInstance(value) ? value : null; + } + + boolean hasKey(Object key); + + static Context of(Map context) { + return new DefaultContext(context); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/DefaultContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/DefaultContext.java new file mode 100644 index 0000000..7a69a5e --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/DefaultContext.java @@ -0,0 +1,50 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.core.context; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * @author Joe Grandja + * @since 0.1.0 + */ +final class DefaultContext implements Context { + private final Map context; + + DefaultContext(Map context) { + Assert.notNull(context, "context cannot be null"); + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Override + @Nullable + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java index 9374c9d..1a66063 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java @@ -23,8 +23,6 @@ import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; import java.util.stream.Collectors; import com.nimbusds.jose.JOSEException; @@ -46,7 +44,6 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import org.springframework.core.convert.converter.Converter; -import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -88,9 +85,6 @@ public final class NimbusJwsEncoder implements JwtEncoder { private final JWKSource jwkSource; - private BiConsumer jwtCustomizer = (headers, claims) -> { - }; - /** * Constructs a {@code NimbusJwsEncoder} using the provided parameters. * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} @@ -100,32 +94,12 @@ public final class NimbusJwsEncoder implements JwtEncoder { this.jwkSource = jwkSource; } - /** - * Sets the {@link Jwt} customizer to be provided the {@link JoseHeader.Builder} and - * {@link JwtClaimsSet.Builder} allowing for further customizations. - * @param jwtCustomizer the {@link Jwt} customizer to be provided the - * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder} - */ - public void setJwtCustomizer(BiConsumer jwtCustomizer) { - Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); - this.jwtCustomizer = jwtCustomizer; - } - @Override public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { Assert.notNull(headers, "headers cannot be null"); Assert.notNull(claims, "claims cannot be null"); - // @formatter:off - JoseHeader.Builder headersBuilder = JoseHeader.from(headers) - .type(JOSEObjectType.JWT.getType()); - JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims) - .id(UUID.randomUUID().toString()); - // @formatter:on - - this.jwtCustomizer.accept(headersBuilder, claimsBuilder); - - JWK jwk = selectJwk(headersBuilder); + JWK jwk = selectJwk(headers); if (jwk == null) { throw new JwtEncodingException( String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); @@ -135,8 +109,15 @@ public final class NimbusJwsEncoder implements JwtEncoder { "The \"kid\" (key ID) from the selected JWK cannot be empty")); } - headers = headersBuilder.keyId(jwk.getKeyID()).build(); - claims = claimsBuilder.build(); + // @formatter:off + headers = JoseHeader.from(headers) + .type(JOSEObjectType.JWT.getType()) + .keyId(jwk.getKeyID()) + .build(); + claims = JwtClaimsSet.from(claims) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers); JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); @@ -164,13 +145,9 @@ public final class NimbusJwsEncoder implements JwtEncoder { return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims()); } - private JWK selectJwk(JoseHeader.Builder headersBuilder) { - final AtomicReference jwsAlgorithm = new AtomicReference<>(); - headersBuilder.headers((h) -> { - JwsAlgorithm jwsAlg = (JwsAlgorithm) h.get(JoseHeaderNames.ALG); - jwsAlgorithm.set(JWSAlgorithm.parse(jwsAlg.getName())); - }); - JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm.get()); + private JWK selectJwk(JoseHeader headers) { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getJwsAlgorithm().getName()); + JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm); JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); List jwks; @@ -184,7 +161,7 @@ public final class NimbusJwsEncoder implements JwtEncoder { if (jwks.size() > 1) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.get().getName() + "'")); + "Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.getName() + "'")); } return !jwks.isEmpty() ? jwks.get(0) : null; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java index 364070c..4f661ef 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,4 +56,9 @@ public interface OAuth2AuthorizationAttributeNames { */ String ACCESS_TOKEN_ATTRIBUTES = OAuth2Authorization.class.getName().concat(".ACCESS_TOKEN_ATTRIBUTES"); + /** + * The name of the attribute used for the resource owner {@code Principal}. + */ + String PRINCIPAL = OAuth2Authorization.class.getName().concat(".PRINCIPAL"); + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtEncodingContextUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtEncodingContextUtils.java new file mode 100644 index 0000000..4ba5ebc --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtEncodingContextUtils.java @@ -0,0 +1,123 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Set; + +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * @author Joe Grandja + * @since 0.1.0 + */ +final class JwtEncodingContextUtils { + + private JwtEncodingContextUtils() { + } + + static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization) { + // @formatter:off + return accessTokenContext(registeredClient, authorization, + authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES)); + // @formatter:on + } + + static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization, + Set authorizedScopes) { + // @formatter:off + return accessTokenContext(registeredClient, authorization.getPrincipalName(), authorizedScopes) + .authorization(authorization); + // @formatter:on + } + + static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, + String principalName, Set authorizedScopes) { + + JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256); + + String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().accessTokenTimeToLive()); + + // @formatter:off + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder() + .issuer(issuer) + .subject(principalName) + .audience(Collections.singletonList(registeredClient.getClientId())) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .notBefore(issuedAt); + if (!CollectionUtils.isEmpty(authorizedScopes)) { + claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes); + } + // @formatter:on + + // @formatter:off + return JwtEncodingContext.with(headersBuilder, claimsBuilder) + .registeredClient(registeredClient) + .tokenType(TokenType.ACCESS_TOKEN); + // @formatter:on + } + + static JwtEncodingContext.Builder idTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization) { + JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256); + + String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); // TODO Allow configuration for id token time-to-live + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE); + + // @formatter:off + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder() + .issuer(issuer) + .subject(authorization.getPrincipalName()) + .audience(Collections.singletonList(registeredClient.getClientId())) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .claim(IdTokenClaimNames.AZP, registeredClient.getClientId()); + if (StringUtils.hasText(nonce)) { + claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce); + } + // TODO Add 'auth_time' claim + // @formatter:on + + // @formatter:off + return JwtEncodingContext.with(headersBuilder, claimsBuilder) + .registeredClient(registeredClient) + .authorization(authorization) + .tokenType(new TokenType(OidcParameterNames.ID_TOKEN)); + // @formatter:on + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 1204df2..e4ceafe 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,10 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -25,27 +29,27 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; /** @@ -58,12 +62,15 @@ import static org.springframework.security.oauth2.server.authorization.authentic * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationService * @see JwtEncoder + * @see OAuth2TokenCustomizer + * @see JwtEncodingContext * @see Section 4.1 Authorization Code Grant * @see Section 4.1.3 Access Token Request */ public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { private final OAuth2AuthorizationService authorizationService; private final JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer = (context) -> {}; /** * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. @@ -78,6 +85,11 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica this.jwtEncoder = jwtEncoder; } + public final void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = @@ -116,27 +128,46 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); - Jwt jwt = OAuth2TokenIssuerUtil.issueJwtAccessToken( - this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), - authorizedScopes, registeredClient.getTokenSettings().accessTokenTimeToLive()); - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), authorizedScopes); + // @formatter:off + JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, authorization) + .principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrant(authorizationCodeAuthentication) + .build(); + // @formatter:on + this.jwtCustomizer.customize(context); + JoseHeader headers = context.getHeaders().build(); + JwtClaimsSet claims = context.getClaims().build(); + Jwt jwt = this.jwtEncoder.encode(headers, claims); + + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens()) .accessToken(accessToken); OAuth2RefreshToken refreshToken = null; if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) { - refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(registeredClient.getTokenSettings().refreshTokenTimeToLive()); + refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken( + registeredClient.getTokenSettings().refreshTokenTimeToLive()); tokensBuilder.refreshToken(refreshToken); } OidcIdToken idToken = null; if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { - Jwt jwtIdToken = OAuth2TokenIssuerUtil.issueIdToken( - this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), - (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE)); + // @formatter:off + context = JwtEncodingContextUtils.idTokenContext(registeredClient, authorization) + .principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrant(authorizationCodeAuthentication) + .build(); + // @formatter:on + this.jwtCustomizer.customize(context); + + headers = context.getHeaders().build(); + claims = context.getClaims().build(); + Jwt jwtIdToken = this.jwtEncoder.encode(headers, claims); + idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(), jwtIdToken.getExpiresAt(), jwtIdToken.getClaims()); tokensBuilder.token(idToken); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index bf948db..7891923 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,10 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import java.util.LinkedHashSet; +import java.util.Set; +import java.util.stream.Collectors; + import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -23,37 +27,42 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.util.LinkedHashSet; -import java.util.Set; -import java.util.stream.Collectors; - import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; /** * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Client Credentials Grant. * * @author Alexey Nesterov + * @author Joe Grandja * @since 0.0.1 * @see OAuth2ClientCredentialsAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationService * @see JwtEncoder + * @see OAuth2TokenCustomizer + * @see JwtEncodingContext * @see Section 4.4 Client Credentials Grant * @see Section 4.4.2 Access Token Request */ public class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider { private final OAuth2AuthorizationService authorizationService; private final JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer = (context) -> {}; /** * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters. @@ -69,6 +78,11 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica this.jwtEncoder = jwtEncoder; } + public final void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication = @@ -93,10 +107,21 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes()); } - Jwt jwt = OAuth2TokenIssuerUtil - .issueJwtAccessToken(this.jwtEncoder, clientPrincipal.getName(), registeredClient.getClientId(), scopes, registeredClient.getTokenSettings().accessTokenTimeToLive()); + // @formatter:off + JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, clientPrincipal.getName(), scopes) + .principal(clientPrincipal) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .authorizationGrant(clientCredentialsAuthentication) + .build(); + // @formatter:on + this.jwtCustomizer.customize(context); + + JoseHeader headers = context.getHeaders().build(); + JwtClaimsSet claims = context.getClaims().build(); + Jwt jwt = this.jwtEncoder.encode(headers, claims); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes); + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(clientPrincipal.getName()) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java index bdccece..ece2389 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,47 +15,62 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.Set; + import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken2; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.config.TokenSettings; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; -import java.time.Instant; -import java.util.Set; - import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; /** * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant. * * @author Alexey Nesterov + * @author Joe Grandja * @since 0.0.3 * @see OAuth2RefreshTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationService * @see JwtEncoder + * @see OAuth2TokenCustomizer + * @see JwtEncodingContext * @see Section 1.5 Refresh Token Grant * @see Section 6 Refreshing an Access Token */ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider { + private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); private final OAuth2AuthorizationService authorizationService; private final JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer = (context) -> {}; /** * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters. @@ -71,6 +86,11 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP this.jwtEncoder = jwtEncoder; } + public final void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication = @@ -121,15 +141,26 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - Jwt jwt = OAuth2TokenIssuerUtil - .issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), scopes, registeredClient.getTokenSettings().accessTokenTimeToLive()); + // @formatter:off + JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, authorization, scopes) + .principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) + .authorizationGrant(refreshTokenAuthentication) + .build(); + // @formatter:on + this.jwtCustomizer.customize(context); + + JoseHeader headers = context.getHeaders().build(); + JwtClaimsSet claims = context.getClaims().build(); + Jwt jwt = this.jwtEncoder.encode(headers, claims); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes); + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); TokenSettings tokenSettings = registeredClient.getTokenSettings(); if (!tokenSettings.reuseRefreshTokens()) { - refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(tokenSettings.refreshTokenTimeToLive()); + refreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive()); } authorization = OAuth2Authorization.from(authorization) @@ -146,4 +177,10 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP public boolean supports(Class authentication) { return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication); } + + static OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(tokenTimeToLive); + return new OAuth2RefreshToken2(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt); + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java deleted file mode 100644 index a5a12ce..0000000 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2020-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.authentication; - -import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; -import org.springframework.security.crypto.keygen.StringKeyGenerator; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.OAuth2RefreshToken2; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.jwt.JoseHeader; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; -import org.springframework.security.oauth2.jwt.JwtEncoder; -import org.springframework.util.StringUtils; - -import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Base64; -import java.util.Collections; -import java.util.Set; - -/** - * @author Alexey Nesterov - * @since 0.0.3 - */ -class OAuth2TokenIssuerUtil { - - private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); - - static Jwt issueJwtAccessToken(JwtEncoder jwtEncoder, String subject, String audience, Set scopes, Duration tokenTimeToLive) { - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); - - String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(tokenTimeToLive); - - JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder() - .issuer(issuer) - .subject(subject) - .audience(Collections.singletonList(audience)) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .notBefore(issuedAt) - .claim(OAuth2ParameterNames.SCOPE, scopes) - .build(); - - return jwtEncoder.encode(joseHeader, jwtClaimsSet); - } - - static Jwt issueIdToken(JwtEncoder jwtEncoder, String subject, String audience, String nonce) { - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); - - String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); // TODO Allow configuration for id token time-to-live - - JwtClaimsSet.Builder builder = JwtClaimsSet.builder() - .issuer(issuer) - .subject(subject) - .audience(Collections.singletonList(audience)) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .claim(IdTokenClaimNames.AZP, audience); - if (StringUtils.hasText(nonce)) { - builder.claim(IdTokenClaimNames.NONCE, nonce); - } - - // TODO Add 'auth_time' claim - - JwtClaimsSet jwtClaimsSet = builder.build(); - - return jwtEncoder.encode(joseHeader, jwtClaimsSet); - } - - static OAuth2RefreshToken issueRefreshToken(Duration tokenTimeToLive) { - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(tokenTimeToLive); - - return new OAuth2RefreshToken2(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt); - } -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContext.java new file mode 100644 index 0000000..78df287 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContext.java @@ -0,0 +1,87 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.context.Context; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.util.Assert; + +/** + * @author Joe Grandja + * @since 0.1.0 + * @see OAuth2TokenContext + * @see JoseHeader.Builder + * @see JwtClaimsSet.Builder + */ +public final class JwtEncodingContext implements OAuth2TokenContext { + private final Context context; + + private JwtEncodingContext(Map context) { + this.context = Context.of(context); + } + + @Nullable + @Override + public V get(Object key) { + return this.context.get(key); + } + + @Override + public boolean hasKey(Object key) { + return this.context.hasKey(key); + } + + public JoseHeader.Builder getHeaders() { + return get(JoseHeader.Builder.class); + } + + public JwtClaimsSet.Builder getClaims() { + return get(JwtClaimsSet.Builder.class); + } + + public static Builder with(JoseHeader.Builder headersBuilder, JwtClaimsSet.Builder claimsBuilder) { + return new Builder(headersBuilder, claimsBuilder); + } + + public static final class Builder extends AbstractBuilder { + + private Builder(JoseHeader.Builder headersBuilder, JwtClaimsSet.Builder claimsBuilder) { + Assert.notNull(headersBuilder, "headersBuilder cannot be null"); + Assert.notNull(claimsBuilder, "claimsBuilder cannot be null"); + put(JoseHeader.Builder.class, headersBuilder); + put(JwtClaimsSet.Builder.class, claimsBuilder); + } + + public Builder headers(Consumer headersConsumer) { + headersConsumer.accept(get(JoseHeader.Builder.class)); + return this; + } + + public Builder claims(Consumer claimsConsumer) { + claimsConsumer.accept(get(JwtClaimsSet.Builder.class)); + return this; + } + + public JwtEncodingContext build() { + return new JwtEncodingContext(this.context); + } + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenContext.java new file mode 100644 index 0000000..c57448d --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenContext.java @@ -0,0 +1,119 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.context.Context; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + +/** + * @author Joe Grandja + * @since 0.1.0 + * @see Context + */ +public interface OAuth2TokenContext extends Context { + + default RegisteredClient getRegisteredClient() { + return get(RegisteredClient.class); + } + + default T getPrincipal() { + return get(AbstractBuilder.PRINCIPAL_AUTHENTICATION_KEY); + } + + @Nullable + default OAuth2Authorization getAuthorization() { + return get(OAuth2Authorization.class); + } + + default TokenType getTokenType() { + return get(TokenType.class); + } + + default AuthorizationGrantType getAuthorizationGrantType() { + return get(AuthorizationGrantType.class); + } + + default T getAuthorizationGrant() { + return get(AbstractBuilder.AUTHORIZATION_GRANT_AUTHENTICATION_KEY); + } + + abstract class AbstractBuilder> { + private static final String PRINCIPAL_AUTHENTICATION_KEY = + Authentication.class.getName().concat(".PRINCIPAL"); + private static final String AUTHORIZATION_GRANT_AUTHENTICATION_KEY = + Authentication.class.getName().concat(".AUTHORIZATION_GRANT"); + protected final Map context = new HashMap<>(); + + public B registeredClient(RegisteredClient registeredClient) { + return put(RegisteredClient.class, registeredClient); + } + + public B principal(Authentication principal) { + return put(PRINCIPAL_AUTHENTICATION_KEY, principal); + } + + public B authorization(OAuth2Authorization authorization) { + return put(OAuth2Authorization.class, authorization); + } + + public B tokenType(TokenType tokenType) { + return put(TokenType.class, tokenType); + } + + public B authorizationGrantType(AuthorizationGrantType authorizationGrantType) { + return put(AuthorizationGrantType.class, authorizationGrantType); + } + + public B authorizationGrant(Authentication authorizationGrant) { + return put(AUTHORIZATION_GRANT_AUTHENTICATION_KEY, authorizationGrant); + } + + public B put(Object key, Object value) { + Assert.notNull(key, "key cannot be null"); + Assert.notNull(value, "value cannot be null"); + this.context.put(key, value); + return getThis(); + } + + public B context(Consumer> contextConsumer) { + contextConsumer.accept(this.context); + return getThis(); + } + + @SuppressWarnings("unchecked") + protected V get(Object key) { + return (V) this.context.get(key); + } + + @SuppressWarnings("unchecked") + protected B getThis() { + return (B) this; + } + + public abstract T build(); + + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenCustomizer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenCustomizer.java new file mode 100644 index 0000000..0b3ffe4 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenCustomizer.java @@ -0,0 +1,28 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +/** + * @author Joe Grandja + * @since 0.1.0 + * @see OAuth2TokenContext + */ +@FunctionalInterface +public interface OAuth2TokenCustomizer { + + void customize(C context); + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 27a31f2..1efb01b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,22 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -53,21 +69,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Base64; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - /** * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, * which handles the processing of the OAuth 2.0 Authorization Request. @@ -193,6 +194,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest(); OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(principal.getName()) + .attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, principal) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest); if (registeredClient.getClientSettings().requireUserConsent()) { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 274ef99..8117a90 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -18,7 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.function.BiConsumer; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; @@ -33,19 +35,29 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.jose.TestJwks; -import org.springframework.security.oauth2.jwt.JoseHeader; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TokenType; @@ -53,7 +65,9 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; @@ -90,13 +104,16 @@ public class OAuth2AuthorizationCodeGrantTests { // https://tools.ietf.org/html/rfc7636#appendix-B private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + private static final String AUTHORITIES_CLAIM = "authorities"; private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; private static NimbusJwsEncoder jwtEncoder; - private static BiConsumer jwtCustomizer; + private static NimbusJwtDecoder jwtDecoder; private static ProviderSettings providerSettings; + private static HttpMessageConverter accessTokenHttpResponseConverter = + new OAuth2AccessTokenResponseHttpMessageConverter(); @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -111,8 +128,7 @@ public class OAuth2AuthorizationCodeGrantTests { JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); jwtEncoder = new NimbusJwsEncoder(jwkSource); - jwtCustomizer = mock(BiConsumer.class); - jwtEncoder.setJwtCustomizer(jwtCustomizer); + jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build(); providerSettings = new ProviderSettings() .authorizationEndpoint("/test/authorize") .tokenEndpoint("/test/token"); @@ -186,8 +202,17 @@ public class OAuth2AuthorizationCodeGrantTests { eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - assertTokenRequestReturnsAccessTokenResponse( + OAuth2AccessTokenResponse accessTokenResponse = assertTokenRequestReturnsAccessTokenResponse( registeredClient, authorization, OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI); + + // Assert user authorities was propagated as claim in JWT + Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue()); + List authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM); + Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL); + Set userAuthorities = principal.getAuthorities().stream() + .map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet()); + assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities); } @Test @@ -208,10 +233,10 @@ public class OAuth2AuthorizationCodeGrantTests { registeredClient, authorization, providerSettings.tokenEndpoint()); } - private void assertTokenRequestReturnsAccessTokenResponse(RegisteredClient registeredClient, + private OAuth2AccessTokenResponse assertTokenRequestReturnsAccessTokenResponse(RegisteredClient registeredClient, OAuth2Authorization authorization, String tokenEndpointUri) throws Exception { - this.mvc.perform(post(tokenEndpointUri) + MvcResult mvcResult = this.mvc.perform(post(tokenEndpointUri) .params(getTokenRequestParameters(registeredClient, authorization)) .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( registeredClient.getClientId(), registeredClient.getClientSecret()))) @@ -222,13 +247,19 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(jsonPath("$.token_type").isNotEmpty()) .andExpect(jsonPath("$.expires_in").isNotEmpty()) .andExpect(jsonPath("$.refresh_token").isNotEmpty()) - .andExpect(jsonPath("$.scope").isNotEmpty()); + .andExpect(jsonPath("$.scope").isNotEmpty()) + .andReturn(); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).findByToken( eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(TokenType.AUTHORIZATION_CODE)); verify(authorizationService).save(any()); + + MockHttpServletResponse servletResponse = mvcResult.getResponse(); + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); + return accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); } @Test @@ -295,8 +326,6 @@ public class OAuth2AuthorizationCodeGrantTests { .params(getTokenRequestParameters(registeredClient, authorization)) .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( registeredClient.getClientId(), registeredClient.getClientSecret()))); - - verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class)); } private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { @@ -345,6 +374,20 @@ public class OAuth2AuthorizationCodeGrantTests { JWKSource jwkSource() { return jwkSource; } + + @Bean + OAuth2TokenCustomizer jwtCustomizer() { + return context -> { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) && + TokenType.ACCESS_TOKEN.equals(context.getTokenType())) { + Authentication principal = context.getPrincipal(); + Set authorities = principal.getAuthorities().stream() + .map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet()); + context.getClaims().claim(AUTHORITIES_CLAIM, authorities); + } + }; + } } @EnableWebSecurity diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java index cdea86d..2b159ea 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java @@ -41,6 +41,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; @@ -65,6 +67,7 @@ public class OAuth2ClientCredentialsGrantTests { private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; + private static OAuth2TokenCustomizer jwtCustomizer; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -78,10 +81,13 @@ public class OAuth2ClientCredentialsGrantTests { authorizationService = mock(OAuth2AuthorizationService.class); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); + jwtCustomizer = mock(OAuth2TokenCustomizer.class); } + @SuppressWarnings("unchecked") @Before public void setup() { + reset(jwtCustomizer); reset(registeredClientRepository); reset(authorizationService); } @@ -115,6 +121,7 @@ public class OAuth2ClientCredentialsGrantTests { .andExpect(jsonPath("$.access_token").isNotEmpty()) .andExpect(jsonPath("$.scope").value("scope1 scope2")); + verify(jwtCustomizer).customize(any()); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).save(any()); } @@ -136,6 +143,7 @@ public class OAuth2ClientCredentialsGrantTests { .andExpect(jsonPath("$.access_token").isNotEmpty()) .andExpect(jsonPath("$.scope").value("scope1 scope2")); + verify(jwtCustomizer).customize(any()); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).save(any()); } @@ -166,5 +174,10 @@ public class OAuth2ClientCredentialsGrantTests { JWKSource jwkSource() { return jwkSource; } + + @Bean + OAuth2TokenCustomizer jwtCustomizer() { + return jwtCustomizer; + } } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java index 84994e1..f586e57 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java @@ -18,6 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; @@ -31,24 +34,40 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -68,9 +87,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @since 0.0.3 */ public class OAuth2RefreshTokenGrantTests { + private static final String AUTHORITIES_CLAIM = "authorities"; private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; + private static NimbusJwtDecoder jwtDecoder; + private static HttpMessageConverter accessTokenHttpResponseConverter = + new OAuth2AccessTokenResponseHttpMessageConverter(); @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -84,6 +107,7 @@ public class OAuth2RefreshTokenGrantTests { authorizationService = mock(OAuth2AuthorizationService.class); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); + jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build(); } @Before @@ -106,7 +130,7 @@ public class OAuth2RefreshTokenGrantTests { eq(TokenType.REFRESH_TOKEN))) .thenReturn(authorization); - this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + MvcResult mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) .params(getRefreshTokenRequestParameters(authorization)) .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( registeredClient.getClientId(), registeredClient.getClientSecret()))) @@ -117,7 +141,8 @@ public class OAuth2RefreshTokenGrantTests { .andExpect(jsonPath("$.token_type").isNotEmpty()) .andExpect(jsonPath("$.expires_in").isNotEmpty()) .andExpect(jsonPath("$.refresh_token").isNotEmpty()) - .andExpect(jsonPath("$.scope").isNotEmpty()); + .andExpect(jsonPath("$.scope").isNotEmpty()) + .andReturn(); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).findByToken( @@ -125,6 +150,20 @@ public class OAuth2RefreshTokenGrantTests { eq(TokenType.REFRESH_TOKEN)); verify(authorizationService).save(any()); + MockHttpServletResponse servletResponse = mvcResult.getResponse(); + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); + OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read( + OAuth2AccessTokenResponse.class, httpResponse); + + // Assert user authorities was propagated as claim in JWT + Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue()); + List authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM); + Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL); + Set userAuthorities = principal.getAuthorities().stream() + .map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet()); + assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities); } private static MultiValueMap getRefreshTokenRequestParameters(OAuth2Authorization authorization) { @@ -160,5 +199,18 @@ public class OAuth2RefreshTokenGrantTests { JWKSource jwkSource() { return jwkSource; } + + @Bean + OAuth2TokenCustomizer jwtCustomizer() { + return context -> { + if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) { + Authentication principal = context.getPrincipal(); + Set authorities = principal.getAuthorities().stream() + .map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet()); + context.getClaims().claim(AUTHORITIES_CLAIM, authorities); + } + }; + } } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java index d081b03..0b91dc3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java @@ -18,6 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; @@ -32,15 +35,28 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -48,7 +64,9 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcProviderConfigurationEndpointFilter; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; @@ -80,10 +98,14 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Daniel Garnier-Moiroux */ public class OidcTests { - private static final String issuerUrl = "https://example.com/issuer1"; + private static final String ISSUER_URL = "https://example.com/issuer1"; + private static final String AUTHORITIES_CLAIM = "authorities"; private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; + private static NimbusJwtDecoder jwtDecoder; + private static HttpMessageConverter accessTokenHttpResponseConverter = + new OAuth2AccessTokenResponseHttpMessageConverter(); @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -97,6 +119,7 @@ public class OidcTests { authorizationService = mock(OAuth2AuthorizationService.class); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); + jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build(); } @Before @@ -111,7 +134,7 @@ public class OidcTests { this.mvc.perform(get(OidcProviderConfigurationEndpointFilter.DEFAULT_OIDC_PROVIDER_CONFIGURATION_ENDPOINT_URI)) .andExpect(status().is2xxSuccessful()) - .andExpect(jsonPath("issuer").value(issuerUrl)); + .andExpect(jsonPath("issuer").value(ISSUER_URL)); } @Test @@ -148,7 +171,7 @@ public class OidcTests { MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) .params(getAuthorizationRequestParameters(registeredClient)) - .with(user("user"))) + .with(user("user").roles("A", "B"))) .andExpect(status().is3xxRedirection()) .andReturn(); assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); @@ -164,7 +187,7 @@ public class OidcTests { eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); - this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) .params(getTokenRequestParameters(registeredClient, authorization)) .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( registeredClient.getClientId(), registeredClient.getClientSecret()))) @@ -176,13 +199,28 @@ public class OidcTests { .andExpect(jsonPath("$.expires_in").isNotEmpty()) .andExpect(jsonPath("$.refresh_token").isNotEmpty()) .andExpect(jsonPath("$.scope").isNotEmpty()) - .andExpect(jsonPath("$.id_token").isNotEmpty()); + .andExpect(jsonPath("$.id_token").isNotEmpty()) + .andReturn(); verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(authorizationService).findByToken( eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(TokenType.AUTHORIZATION_CODE)); verify(authorizationService, times(2)).save(any()); + + MockHttpServletResponse servletResponse = mvcResult.getResponse(); + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); + OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); + + // Assert user authorities was propagated as claim in ID Token + Jwt idToken = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); + List authoritiesClaim = idToken.getClaim(AUTHORITIES_CLAIM); + Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL); + Set userAuthorities = principal.getAuthorities().stream() + .map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet()); + assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities); } private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { @@ -231,6 +269,19 @@ public class OidcTests { JWKSource jwkSource() { return jwkSource; } + + @Bean + OAuth2TokenCustomizer jwtCustomizer() { + return context -> { + if (context.getTokenType().getValue().equals(OidcParameterNames.ID_TOKEN)) { + Authentication principal = context.getPrincipal(); + Set authorities = principal.getAuthorities().stream() + .map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet()); + context.getClaims().claim(AUTHORITIES_CLAIM, authorities); + } + }; + } } @EnableWebSecurity @@ -239,7 +290,7 @@ public class OidcTests { @Bean ProviderSettings providerSettings() { - return new ProviderSettings().issuer(issuerUrl); + return new ProviderSettings().issuer(ISSUER_URL); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java index b00f904..e916171 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java @@ -21,7 +21,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.function.BiConsumer; import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.jwk.ECKey; @@ -49,7 +48,6 @@ import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; /** * Tests for {@link NimbusJwsEncoder}. @@ -77,12 +75,6 @@ public class NimbusJwsEncoderTests { .withMessage("jwkSource cannot be null"); } - @Test - public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.setJwtCustomizer(null)) - .withMessage("jwtCustomizer cannot be null"); - } - @Test public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); @@ -99,22 +91,6 @@ public class NimbusJwsEncoderTests { .withMessage("claims cannot be null"); } - @Test - public void encodeWhenCustomizerSetThenCalled() { - RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; - this.jwkList.add(rsaJwk); - - BiConsumer jwtCustomizer = mock(BiConsumer.class); - this.jwsEncoder.setJwtCustomizer(jwtCustomizer); - - JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); - JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - - this.jwsEncoder.encode(joseHeader, jwtClaimsSet); - - verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class)); - } - @Test public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception { this.jwkSource = mock(JWKSource.class); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 260a6b0..f0fa05d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,12 @@ */ package org.springframework.security.oauth2.server.authorization; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Map; + +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; @@ -24,11 +30,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Collections; -import java.util.Map; - /** * @author Joe Grandja * @author Daniel Garnier-Moiroux @@ -63,6 +64,8 @@ public class TestOAuth2Authorizations { .principalName("principal") .tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build()) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) + .attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, + new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B")) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 5c7d79f..e781610 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -15,10 +15,17 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Set; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -28,11 +35,12 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -41,14 +49,10 @@ import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; -import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.entry; @@ -69,6 +73,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { private static final String AUTHORIZATION_CODE = "code"; private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer; private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; @Before @@ -77,6 +82,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { this.jwtEncoder = mock(JwtEncoder.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( this.authorizationService, this.jwtEncoder); + this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); + this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer); } @Test @@ -93,6 +100,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { .hasMessage("jwtEncoder cannot be null"); } + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtCustomizer cannot be null"); + } + @Test public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue(); @@ -225,6 +239,18 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture()); + JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue(); + assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(jwtEncodingContext.getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)); + assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(authorization); + assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN); + assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(jwtEncodingContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(jwtEncodingContext.getHeaders()).isNotNull(); + assertThat(jwtEncodingContext.getClaims()).isNotNull(); + ArgumentCaptor jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class); verify(this.jwtEncoder).encode(any(), jwtClaimsSetCaptor.capture()); JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue(); @@ -264,6 +290,29 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture()); + // Access Token context + JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0); + assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(accessTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)); + assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(accessTokenContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN); + assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(accessTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(accessTokenContext.getHeaders()).isNotNull(); + assertThat(accessTokenContext.getClaims()).isNotNull(); + // ID Token context + JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); + assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)); + assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN); + assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(idTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(idTokenContext.getHeaders()).isNotNull(); + assertThat(idTokenContext.getClaims()).isNotNull(); + verify(this.jwtEncoder, times(2)).encode(any(), any()); // Access token and ID Token ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index 4a686fb..69a2133 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -15,28 +15,34 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.jwt.JoseHeaderNames; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtEncoder; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; - import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.Set; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; @@ -53,6 +59,7 @@ import static org.mockito.Mockito.when; public class OAuth2ClientCredentialsAuthenticationProviderTests { private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer; private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider; @Before @@ -61,6 +68,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { this.jwtEncoder = mock(JwtEncoder.class); this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( this.authorizationService, this.jwtEncoder); + this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); + this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer); } @Test @@ -77,6 +86,13 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { .hasMessage("jwtEncoder cannot be null"); } + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtCustomizer cannot be null"); + } + @Test public void supportsWhenSupportedAuthenticationThenTrue() { assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isTrue(); @@ -152,7 +168,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); - when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(requestedScope)); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -165,11 +181,23 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); - when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes())); OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture()); + JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue(); + assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(jwtEncodingContext.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(jwtEncodingContext.getAuthorization()).isNull(); + assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN); + assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); + assertThat(jwtEncodingContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(jwtEncodingContext.getHeaders()).isNotNull(); + assertThat(jwtEncodingContext.getClaims()).isNotNull(); + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization authorization = authorizationCaptor.getValue(); @@ -182,13 +210,14 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); } - private static Jwt createJwt() { + private static Jwt createJwt(Set scope) { Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); return Jwt.withTokenValue("token") .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) .issuedAt(issuedAt) .expiresAt(expiresAt) + .claim(OAuth2ParameterNames.SCOPE, scope) .build(); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index 8339b12..bb8e017 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -15,20 +15,30 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; -import org.springframework.security.oauth2.jwt.JoseHeaderNames; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -36,14 +46,10 @@ import org.springframework.security.oauth2.server.authorization.TestOAuth2Author import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.HashSet; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -56,25 +62,24 @@ import static org.mockito.Mockito.when; * Tests for {@link OAuth2RefreshTokenAuthenticationProvider}. * * @author Alexey Nesterov + * @author Joe Grandja * @since 0.0.3 */ public class OAuth2RefreshTokenAuthenticationProviderTests { private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer; private OAuth2RefreshTokenAuthenticationProvider authenticationProvider; @Before public void setUp() { this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); - Jwt jwt = Jwt.withTokenValue("refreshed-access-token") - .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) - .issuedAt(Instant.now()) - .expiresAt(Instant.now().plus(1, ChronoUnit.HOURS)) - .build(); - when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(Collections.singleton("scope1"))); this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider( this.authorizationService, this.jwtEncoder); + this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); + this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer); } @Test @@ -93,6 +98,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { .isEqualTo("jwtEncoder cannot be null"); } + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtCustomizer cannot be null"); + } + @Test public void supportsWhenSupportedAuthenticationThenTrue() { assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue(); @@ -119,6 +131,18 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture()); + JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue(); + assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(jwtEncodingContext.getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)); + assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(authorization); + assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN); + assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(jwtEncodingContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(jwtEncodingContext.getHeaders()).isNotNull(); + assertThat(jwtEncodingContext.getClaims()).isNotNull(); + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); @@ -340,4 +364,15 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } + + private static Jwt createJwt(Set scope) { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); + return Jwt.withTokenValue("refreshed-access-token") + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .claim(OAuth2ParameterNames.SCOPE, scope) + .build(); + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContextTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContextTests.java new file mode 100644 index 0000000..88f4ea7 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContextTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import org.junit.Test; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.TestJoseHeaders; +import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link JwtEncodingContext}. + * + * @author Joe Grandja + */ +public class JwtEncodingContextTests { + + @Test + public void withWhenHeadersNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JwtEncodingContext.with(null, TestJwtClaimsSets.jwtClaimsSet())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("headersBuilder cannot be null"); + } + + @Test + public void withWhenClaimsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JwtEncodingContext.with(TestJoseHeaders.joseHeader(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("claimsBuilder cannot be null"); + } + + @Test + public void setWhenValueNullThenThrowIllegalArgumentException() { + JwtEncodingContext.Builder builder = JwtEncodingContext + .with(TestJoseHeaders.joseHeader(), TestJwtClaimsSets.jwtClaimsSet()); + assertThatThrownBy(() -> builder.registeredClient(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.principal(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.authorization(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.tokenType(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.authorizationGrantType(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.authorizationGrant(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.put(null, "")) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAllValuesProvidedThenAllValuesAreSet() { + JoseHeader.Builder headers = TestJoseHeaders.joseHeader(); + JwtClaimsSet.Builder claims = TestJwtClaimsSets.jwtClaimsSet(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "password"); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); + OAuth2AuthorizationCodeAuthenticationToken authorizationGrant = + new OAuth2AuthorizationCodeAuthenticationToken( + "code", clientPrincipal, authorizationRequest.getRedirectUri(), null); + + JwtEncodingContext context = JwtEncodingContext.with(headers, claims) + .registeredClient(registeredClient) + .principal(principal) + .authorization(authorization) + .tokenType(TokenType.ACCESS_TOKEN) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrant(authorizationGrant) + .put("custom-key-1", "custom-value-1") + .context(ctx -> ctx.put("custom-key-2", "custom-value-2")) + .build(); + + assertThat(context.getHeaders()).isEqualTo(headers); + assertThat(context.getClaims()).isEqualTo(claims); + assertThat(context.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(context.getPrincipal()).isEqualTo(principal); + assertThat(context.getAuthorization()).isEqualTo(authorization); + assertThat(context.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN); + assertThat(context.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(context.getAuthorizationGrant()).isEqualTo(authorizationGrant); + assertThat(context.get("custom-key-1")).isEqualTo("custom-value-1"); + assertThat(context.get("custom-key-2")).isEqualTo("custom-value-2"); + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 61b5dd0..c7a53a5 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,15 +15,25 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.function.Consumer; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -44,13 +54,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.util.StringUtils; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.nio.charset.StandardCharsets; -import java.util.Set; -import java.util.function.Consumer; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; @@ -464,6 +467,8 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) + .isEqualTo(this.authentication); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); assertThat(authorizationCode).isNotNull(); @@ -511,6 +516,8 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) + .isEqualTo(this.authentication); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); assertThat(authorizationCode).isNotNull(); @@ -556,6 +563,8 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = authorizationCaptor.getValue(); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) + .isEqualTo(this.authentication); String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE); assertThat(state).isNotNull();