Add OAuth2TokenCustomizer

Closes gh-199
This commit is contained in:
Joe Grandja 2021-01-18 09:31:06 -05:00
parent 3f310eec00
commit adf96b4e25
25 changed files with 1126 additions and 279 deletions

View File

@ -33,7 +33,9 @@ import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer; import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -152,23 +154,33 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
builder.authenticationProvider(postProcess(clientAuthenticationProvider)); builder.authenticationProvider(postProcess(clientAuthenticationProvider));
JwtEncoder jwtEncoder = getJwtEncoder(builder); JwtEncoder jwtEncoder = getJwtEncoder(builder);
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = getJwtCustomizer(builder);
OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
new OAuth2AuthorizationCodeAuthenticationProvider( new OAuth2AuthorizationCodeAuthenticationProvider(
getAuthorizationService(builder), getAuthorizationService(builder),
jwtEncoder); jwtEncoder);
if (jwtCustomizer != null) {
authorizationCodeAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
}
builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider)); builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider));
OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider = OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider =
new OAuth2RefreshTokenAuthenticationProvider( new OAuth2RefreshTokenAuthenticationProvider(
getAuthorizationService(builder), getAuthorizationService(builder),
jwtEncoder); jwtEncoder);
if (jwtCustomizer != null) {
refreshTokenAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
}
builder.authenticationProvider(postProcess(refreshTokenAuthenticationProvider)); builder.authenticationProvider(postProcess(refreshTokenAuthenticationProvider));
OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider = OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider =
new OAuth2ClientCredentialsAuthenticationProvider( new OAuth2ClientCredentialsAuthenticationProvider(
getAuthorizationService(builder), getAuthorizationService(builder),
jwtEncoder); jwtEncoder);
if (jwtCustomizer != null) {
clientCredentialsAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
}
builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider)); builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider));
OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider = OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider =
@ -314,6 +326,19 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
return jwkSource; return jwkSource;
} }
@SuppressWarnings("unchecked")
private static <B extends HttpSecurityBuilder<B>> OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(B builder) {
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = builder.getSharedObject(OAuth2TokenCustomizer.class);
if (jwtCustomizer == null) {
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class);
jwtCustomizer = getOptionalBean(builder, type);
if (jwtCustomizer != null) {
builder.setSharedObject(OAuth2TokenCustomizer.class, jwtCustomizer);
}
}
return jwtCustomizer;
}
private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettings(B builder) { private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettings(B builder) {
ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class); ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class);
if (providerSettings == null) { if (providerSettings == null) {
@ -353,4 +378,14 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
} }
return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null); return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null);
} }
@SuppressWarnings("unchecked")
private static <B extends HttpSecurityBuilder<B>, T> T getOptionalBean(B builder, ResolvableType type) {
ApplicationContext context = builder.getSharedObject(ApplicationContext.class);
String[] names = context.getBeanNamesForType(type);
if (names.length > 1) {
throw new NoUniqueBeanDefinitionException(type, names);
}
return names.length == 1 ? (T) context.getBean(names[0]) : null;
}
} }

View File

@ -0,0 +1,47 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.core.context;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* A facility for holding information associated to a specific context.
*
* @author Joe Grandja
* @since 0.1.0
*/
public interface Context {
@Nullable
<V> V get(Object key);
@Nullable
default <V> V get(Class<V> key) {
Assert.notNull(key, "key cannot be null");
V value = get((Object) key);
return key.isInstance(value) ? value : null;
}
boolean hasKey(Object key);
static Context of(Map<Object, Object> context) {
return new DefaultContext(context);
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.core.context;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* @author Joe Grandja
* @since 0.1.0
*/
final class DefaultContext implements Context {
private final Map<Object, Object> context;
DefaultContext(Map<Object, Object> context) {
Assert.notNull(context, "context cannot be null");
this.context = Collections.unmodifiableMap(new HashMap<>(context));
}
@SuppressWarnings("unchecked")
@Override
@Nullable
public <V> V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@Override
public boolean hasKey(Object key) {
Assert.notNull(key, "key cannot be null");
return this.context.containsKey(key);
}
}

View File

@ -23,8 +23,6 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEException;
@ -46,7 +44,6 @@ import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.SignedJWT;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -88,9 +85,6 @@ public final class NimbusJwsEncoder implements JwtEncoder {
private final JWKSource<SecurityContext> jwkSource; private final JWKSource<SecurityContext> jwkSource;
private BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = (headers, claims) -> {
};
/** /**
* Constructs a {@code NimbusJwsEncoder} using the provided parameters. * Constructs a {@code NimbusJwsEncoder} using the provided parameters.
* @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
@ -100,32 +94,12 @@ public final class NimbusJwsEncoder implements JwtEncoder {
this.jwkSource = jwkSource; this.jwkSource = jwkSource;
} }
/**
* Sets the {@link Jwt} customizer to be provided the {@link JoseHeader.Builder} and
* {@link JwtClaimsSet.Builder} allowing for further customizations.
* @param jwtCustomizer the {@link Jwt} customizer to be provided the
* {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
*/
public void setJwtCustomizer(BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
this.jwtCustomizer = jwtCustomizer;
}
@Override @Override
public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
Assert.notNull(headers, "headers cannot be null"); Assert.notNull(headers, "headers cannot be null");
Assert.notNull(claims, "claims cannot be null"); Assert.notNull(claims, "claims cannot be null");
// @formatter:off JWK jwk = selectJwk(headers);
JoseHeader.Builder headersBuilder = JoseHeader.from(headers)
.type(JOSEObjectType.JWT.getType());
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims)
.id(UUID.randomUUID().toString());
// @formatter:on
this.jwtCustomizer.accept(headersBuilder, claimsBuilder);
JWK jwk = selectJwk(headersBuilder);
if (jwk == null) { if (jwk == null) {
throw new JwtEncodingException( throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
@ -135,8 +109,15 @@ public final class NimbusJwsEncoder implements JwtEncoder {
"The \"kid\" (key ID) from the selected JWK cannot be empty")); "The \"kid\" (key ID) from the selected JWK cannot be empty"));
} }
headers = headersBuilder.keyId(jwk.getKeyID()).build(); // @formatter:off
claims = claimsBuilder.build(); headers = JoseHeader.from(headers)
.type(JOSEObjectType.JWT.getType())
.keyId(jwk.getKeyID())
.build();
claims = JwtClaimsSet.from(claims)
.id(UUID.randomUUID().toString())
.build();
// @formatter:on
JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers); JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers);
JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims);
@ -164,13 +145,9 @@ public final class NimbusJwsEncoder implements JwtEncoder {
return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims()); return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims());
} }
private JWK selectJwk(JoseHeader.Builder headersBuilder) { private JWK selectJwk(JoseHeader headers) {
final AtomicReference<JWSAlgorithm> jwsAlgorithm = new AtomicReference<>(); JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getJwsAlgorithm().getName());
headersBuilder.headers((h) -> { JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm);
JwsAlgorithm jwsAlg = (JwsAlgorithm) h.get(JoseHeaderNames.ALG);
jwsAlgorithm.set(JWSAlgorithm.parse(jwsAlg.getName()));
});
JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm.get());
JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
List<JWK> jwks; List<JWK> jwks;
@ -184,7 +161,7 @@ public final class NimbusJwsEncoder implements JwtEncoder {
if (jwks.size() > 1) { if (jwks.size() > 1) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.get().getName() + "'")); "Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.getName() + "'"));
} }
return !jwks.isEmpty() ? jwks.get(0) : null; return !jwks.isEmpty() ? jwks.get(0) : null;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -56,4 +56,9 @@ public interface OAuth2AuthorizationAttributeNames {
*/ */
String ACCESS_TOKEN_ATTRIBUTES = OAuth2Authorization.class.getName().concat(".ACCESS_TOKEN_ATTRIBUTES"); String ACCESS_TOKEN_ATTRIBUTES = OAuth2Authorization.class.getName().concat(".ACCESS_TOKEN_ATTRIBUTES");
/**
* The name of the attribute used for the resource owner {@code Principal}.
*/
String PRINCIPAL = OAuth2Authorization.class.getName().concat(".PRINCIPAL");
} }

View File

@ -0,0 +1,123 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Set;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
* @author Joe Grandja
* @since 0.1.0
*/
final class JwtEncodingContextUtils {
private JwtEncodingContextUtils() {
}
static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization) {
// @formatter:off
return accessTokenContext(registeredClient, authorization,
authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES));
// @formatter:on
}
static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization,
Set<String> authorizedScopes) {
// @formatter:off
return accessTokenContext(registeredClient, authorization.getPrincipalName(), authorizedScopes)
.authorization(authorization);
// @formatter:on
}
static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient,
String principalName, Set<String> authorizedScopes) {
JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().accessTokenTimeToLive());
// @formatter:off
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder()
.issuer(issuer)
.subject(principalName)
.audience(Collections.singletonList(registeredClient.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.notBefore(issuedAt);
if (!CollectionUtils.isEmpty(authorizedScopes)) {
claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes);
}
// @formatter:on
// @formatter:off
return JwtEncodingContext.with(headersBuilder, claimsBuilder)
.registeredClient(registeredClient)
.tokenType(TokenType.ACCESS_TOKEN);
// @formatter:on
}
static JwtEncodingContext.Builder idTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization) {
JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); // TODO Allow configuration for id token time-to-live
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
// @formatter:off
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder()
.issuer(issuer)
.subject(authorization.getPrincipalName())
.audience(Collections.singletonList(registeredClient.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.claim(IdTokenClaimNames.AZP, registeredClient.getClientId());
if (StringUtils.hasText(nonce)) {
claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce);
}
// TODO Add 'auth_time' claim
// @formatter:on
// @formatter:off
return JwtEncodingContext.with(headersBuilder, claimsBuilder)
.registeredClient(registeredClient)
.authorization(authorization)
.tokenType(new TokenType(OidcParameterNames.ID_TOKEN));
// @formatter:on
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,6 +15,10 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
@ -25,27 +29,27 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
/** /**
@ -58,12 +62,15 @@ import static org.springframework.security.oauth2.server.authorization.authentic
* @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken
* @see OAuth2AuthorizationService * @see OAuth2AuthorizationService
* @see JwtEncoder * @see JwtEncoder
* @see OAuth2TokenCustomizer
* @see JwtEncodingContext
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
*/ */
public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
private final OAuth2AuthorizationService authorizationService; private final OAuth2AuthorizationService authorizationService;
private final JwtEncoder jwtEncoder; private final JwtEncoder jwtEncoder;
private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
/** /**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
@ -78,6 +85,11 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
this.jwtEncoder = jwtEncoder; this.jwtEncoder = jwtEncoder;
} }
public final void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
this.jwtCustomizer = jwtCustomizer;
}
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
@ -116,27 +128,46 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); // @formatter:off
Jwt jwt = OAuth2TokenIssuerUtil.issueJwtAccessToken( JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, authorization)
this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), .principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
authorizedScopes, registeredClient.getTokenSettings().accessTokenTimeToLive()); .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, .authorizationGrant(authorizationCodeAuthentication)
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), authorizedScopes); .build();
// @formatter:on
this.jwtCustomizer.customize(context);
JoseHeader headers = context.getHeaders().build();
JwtClaimsSet claims = context.getClaims().build();
Jwt jwt = this.jwtEncoder.encode(headers, claims);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens()) OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens())
.accessToken(accessToken); .accessToken(accessToken);
OAuth2RefreshToken refreshToken = null; OAuth2RefreshToken refreshToken = null;
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) { if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(registeredClient.getTokenSettings().refreshTokenTimeToLive()); refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken(
registeredClient.getTokenSettings().refreshTokenTimeToLive());
tokensBuilder.refreshToken(refreshToken); tokensBuilder.refreshToken(refreshToken);
} }
OidcIdToken idToken = null; OidcIdToken idToken = null;
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) {
Jwt jwtIdToken = OAuth2TokenIssuerUtil.issueIdToken( // @formatter:off
this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), context = JwtEncodingContextUtils.idTokenContext(registeredClient, authorization)
(String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE)); .principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.authorizationGrant(authorizationCodeAuthentication)
.build();
// @formatter:on
this.jwtCustomizer.customize(context);
headers = context.getHeaders().build();
claims = context.getClaims().build();
Jwt jwtIdToken = this.jwtEncoder.encode(headers, claims);
idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(), idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
jwtIdToken.getExpiresAt(), jwtIdToken.getClaims()); jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
tokensBuilder.token(idToken); tokensBuilder.token(idToken);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,6 +15,10 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
@ -23,37 +27,42 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.stream.Collectors;
import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
/** /**
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Client Credentials Grant. * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Client Credentials Grant.
* *
* @author Alexey Nesterov * @author Alexey Nesterov
* @author Joe Grandja
* @since 0.0.1 * @since 0.0.1
* @see OAuth2ClientCredentialsAuthenticationToken * @see OAuth2ClientCredentialsAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken
* @see OAuth2AuthorizationService * @see OAuth2AuthorizationService
* @see JwtEncoder * @see JwtEncoder
* @see OAuth2TokenCustomizer
* @see JwtEncodingContext
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4">Section 4.4 Client Credentials Grant</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4">Section 4.4 Client Credentials Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request</a>
*/ */
public class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider { public class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider {
private final OAuth2AuthorizationService authorizationService; private final OAuth2AuthorizationService authorizationService;
private final JwtEncoder jwtEncoder; private final JwtEncoder jwtEncoder;
private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
/** /**
* Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters. * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters.
@ -69,6 +78,11 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
this.jwtEncoder = jwtEncoder; this.jwtEncoder = jwtEncoder;
} }
public final void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
this.jwtCustomizer = jwtCustomizer;
}
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication = OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication =
@ -93,10 +107,21 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes()); scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes());
} }
Jwt jwt = OAuth2TokenIssuerUtil // @formatter:off
.issueJwtAccessToken(this.jwtEncoder, clientPrincipal.getName(), registeredClient.getClientId(), scopes, registeredClient.getTokenSettings().accessTokenTimeToLive()); JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, clientPrincipal.getName(), scopes)
.principal(clientPrincipal)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.authorizationGrant(clientCredentialsAuthentication)
.build();
// @formatter:on
this.jwtCustomizer.customize(context);
JoseHeader headers = context.getHeaders().build();
JwtClaimsSet claims = context.getClaims().build();
Jwt jwt = this.jwtEncoder.encode(headers, claims);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes); jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName(clientPrincipal.getName()) .principalName(clientPrincipal.getName())

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,47 +15,62 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Set;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings; import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.time.Instant;
import java.util.Set;
import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
/** /**
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant. * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant.
* *
* @author Alexey Nesterov * @author Alexey Nesterov
* @author Joe Grandja
* @since 0.0.3 * @since 0.0.3
* @see OAuth2RefreshTokenAuthenticationToken * @see OAuth2RefreshTokenAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken
* @see OAuth2AuthorizationService * @see OAuth2AuthorizationService
* @see JwtEncoder * @see JwtEncoder
* @see OAuth2TokenCustomizer
* @see JwtEncodingContext
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.5">Section 1.5 Refresh Token Grant</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.5">Section 1.5 Refresh Token Grant</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a> * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
*/ */
public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider { public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private final OAuth2AuthorizationService authorizationService; private final OAuth2AuthorizationService authorizationService;
private final JwtEncoder jwtEncoder; private final JwtEncoder jwtEncoder;
private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
/** /**
* Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters. * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters.
@ -71,6 +86,11 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
this.jwtEncoder = jwtEncoder; this.jwtEncoder = jwtEncoder;
} }
public final void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
this.jwtCustomizer = jwtCustomizer;
}
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication = OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication =
@ -121,15 +141,26 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
Jwt jwt = OAuth2TokenIssuerUtil // @formatter:off
.issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), scopes, registeredClient.getTokenSettings().accessTokenTimeToLive()); JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, authorization, scopes)
.principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.authorizationGrant(refreshTokenAuthentication)
.build();
// @formatter:on
this.jwtCustomizer.customize(context);
JoseHeader headers = context.getHeaders().build();
JwtClaimsSet claims = context.getClaims().build();
Jwt jwt = this.jwtEncoder.encode(headers, claims);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes); jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
TokenSettings tokenSettings = registeredClient.getTokenSettings(); TokenSettings tokenSettings = registeredClient.getTokenSettings();
if (!tokenSettings.reuseRefreshTokens()) { if (!tokenSettings.reuseRefreshTokens()) {
refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(tokenSettings.refreshTokenTimeToLive()); refreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
} }
authorization = OAuth2Authorization.from(authorization) authorization = OAuth2Authorization.from(authorization)
@ -146,4 +177,10 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
public boolean supports(Class<?> authentication) { public boolean supports(Class<?> authentication) {
return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication); return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
} }
static OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(tokenTimeToLive);
return new OAuth2RefreshToken2(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt);
}
} }

View File

@ -1,97 +0,0 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.util.StringUtils;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Collections;
import java.util.Set;
/**
* @author Alexey Nesterov
* @since 0.0.3
*/
class OAuth2TokenIssuerUtil {
private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
static Jwt issueJwtAccessToken(JwtEncoder jwtEncoder, String subject, String audience, Set<String> scopes, Duration tokenTimeToLive) {
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(tokenTimeToLive);
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder()
.issuer(issuer)
.subject(subject)
.audience(Collections.singletonList(audience))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.notBefore(issuedAt)
.claim(OAuth2ParameterNames.SCOPE, scopes)
.build();
return jwtEncoder.encode(joseHeader, jwtClaimsSet);
}
static Jwt issueIdToken(JwtEncoder jwtEncoder, String subject, String audience, String nonce) {
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
String issuer = "http://auth-server:9000"; // TODO Allow configuration for issuer claim
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); // TODO Allow configuration for id token time-to-live
JwtClaimsSet.Builder builder = JwtClaimsSet.builder()
.issuer(issuer)
.subject(subject)
.audience(Collections.singletonList(audience))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.claim(IdTokenClaimNames.AZP, audience);
if (StringUtils.hasText(nonce)) {
builder.claim(IdTokenClaimNames.NONCE, nonce);
}
// TODO Add 'auth_time' claim
JwtClaimsSet jwtClaimsSet = builder.build();
return jwtEncoder.encode(joseHeader, jwtClaimsSet);
}
static OAuth2RefreshToken issueRefreshToken(Duration tokenTimeToLive) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(tokenTimeToLive);
return new OAuth2RefreshToken2(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt);
}
}

View File

@ -0,0 +1,87 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.context.Context;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.util.Assert;
/**
* @author Joe Grandja
* @since 0.1.0
* @see OAuth2TokenContext
* @see JoseHeader.Builder
* @see JwtClaimsSet.Builder
*/
public final class JwtEncodingContext implements OAuth2TokenContext {
private final Context context;
private JwtEncodingContext(Map<Object, Object> context) {
this.context = Context.of(context);
}
@Nullable
@Override
public <V> V get(Object key) {
return this.context.get(key);
}
@Override
public boolean hasKey(Object key) {
return this.context.hasKey(key);
}
public JoseHeader.Builder getHeaders() {
return get(JoseHeader.Builder.class);
}
public JwtClaimsSet.Builder getClaims() {
return get(JwtClaimsSet.Builder.class);
}
public static Builder with(JoseHeader.Builder headersBuilder, JwtClaimsSet.Builder claimsBuilder) {
return new Builder(headersBuilder, claimsBuilder);
}
public static final class Builder extends AbstractBuilder<JwtEncodingContext, Builder> {
private Builder(JoseHeader.Builder headersBuilder, JwtClaimsSet.Builder claimsBuilder) {
Assert.notNull(headersBuilder, "headersBuilder cannot be null");
Assert.notNull(claimsBuilder, "claimsBuilder cannot be null");
put(JoseHeader.Builder.class, headersBuilder);
put(JwtClaimsSet.Builder.class, claimsBuilder);
}
public Builder headers(Consumer<JoseHeader.Builder> headersConsumer) {
headersConsumer.accept(get(JoseHeader.Builder.class));
return this;
}
public Builder claims(Consumer<JwtClaimsSet.Builder> claimsConsumer) {
claimsConsumer.accept(get(JwtClaimsSet.Builder.class));
return this;
}
public JwtEncodingContext build() {
return new JwtEncodingContext(this.context);
}
}
}

View File

@ -0,0 +1,119 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.context.Context;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
/**
* @author Joe Grandja
* @since 0.1.0
* @see Context
*/
public interface OAuth2TokenContext extends Context {
default RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
}
default <T extends Authentication> T getPrincipal() {
return get(AbstractBuilder.PRINCIPAL_AUTHENTICATION_KEY);
}
@Nullable
default OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class);
}
default TokenType getTokenType() {
return get(TokenType.class);
}
default AuthorizationGrantType getAuthorizationGrantType() {
return get(AuthorizationGrantType.class);
}
default <T extends Authentication> T getAuthorizationGrant() {
return get(AbstractBuilder.AUTHORIZATION_GRANT_AUTHENTICATION_KEY);
}
abstract class AbstractBuilder<T extends OAuth2TokenContext, B extends AbstractBuilder<T, B>> {
private static final String PRINCIPAL_AUTHENTICATION_KEY =
Authentication.class.getName().concat(".PRINCIPAL");
private static final String AUTHORIZATION_GRANT_AUTHENTICATION_KEY =
Authentication.class.getName().concat(".AUTHORIZATION_GRANT");
protected final Map<Object, Object> context = new HashMap<>();
public B registeredClient(RegisteredClient registeredClient) {
return put(RegisteredClient.class, registeredClient);
}
public B principal(Authentication principal) {
return put(PRINCIPAL_AUTHENTICATION_KEY, principal);
}
public B authorization(OAuth2Authorization authorization) {
return put(OAuth2Authorization.class, authorization);
}
public B tokenType(TokenType tokenType) {
return put(TokenType.class, tokenType);
}
public B authorizationGrantType(AuthorizationGrantType authorizationGrantType) {
return put(AuthorizationGrantType.class, authorizationGrantType);
}
public B authorizationGrant(Authentication authorizationGrant) {
return put(AUTHORIZATION_GRANT_AUTHENTICATION_KEY, authorizationGrant);
}
public B put(Object key, Object value) {
Assert.notNull(key, "key cannot be null");
Assert.notNull(value, "value cannot be null");
this.context.put(key, value);
return getThis();
}
public B context(Consumer<Map<Object, Object>> contextConsumer) {
contextConsumer.accept(this.context);
return getThis();
}
@SuppressWarnings("unchecked")
protected <V> V get(Object key) {
return (V) this.context.get(key);
}
@SuppressWarnings("unchecked")
protected B getThis() {
return (B) this;
}
public abstract T build();
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
/**
* @author Joe Grandja
* @since 0.1.0
* @see OAuth2TokenContext
*/
@FunctionalInterface
public interface OAuth2TokenCustomizer<C extends OAuth2TokenContext> {
void customize(C context);
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,6 +15,22 @@
*/ */
package org.springframework.security.oauth2.server.authorization.web; package org.springframework.security.oauth2.server.authorization.web;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@ -53,21 +69,6 @@ import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/** /**
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant, * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
* which handles the processing of the OAuth 2.0 Authorization Request. * which handles the processing of the OAuth 2.0 Authorization Request.
@ -193,6 +194,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest(); OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest();
OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName(principal.getName()) .principalName(principal.getName())
.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, principal)
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest); .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);
if (registeredClient.getClientSettings().requireUserConsent()) { if (registeredClient.getClientSettings().requireUserConsent()) {

View File

@ -18,7 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
import java.net.URLEncoder; import java.net.URLEncoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.function.BiConsumer; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
@ -33,19 +35,29 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
@ -53,7 +65,9 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -90,13 +104,16 @@ public class OAuth2AuthorizationCodeGrantTests {
// https://tools.ietf.org/html/rfc7636#appendix-B // https://tools.ietf.org/html/rfc7636#appendix-B
private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
private static final String AUTHORITIES_CLAIM = "authorities";
private static RegisteredClientRepository registeredClientRepository; private static RegisteredClientRepository registeredClientRepository;
private static OAuth2AuthorizationService authorizationService; private static OAuth2AuthorizationService authorizationService;
private static JWKSource<SecurityContext> jwkSource; private static JWKSource<SecurityContext> jwkSource;
private static NimbusJwsEncoder jwtEncoder; private static NimbusJwsEncoder jwtEncoder;
private static BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer; private static NimbusJwtDecoder jwtDecoder;
private static ProviderSettings providerSettings; private static ProviderSettings providerSettings;
private static HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
new OAuth2AccessTokenResponseHttpMessageConverter();
@Rule @Rule
public final SpringTestRule spring = new SpringTestRule(); public final SpringTestRule spring = new SpringTestRule();
@ -111,8 +128,7 @@ public class OAuth2AuthorizationCodeGrantTests {
JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
jwtEncoder = new NimbusJwsEncoder(jwkSource); jwtEncoder = new NimbusJwsEncoder(jwkSource);
jwtCustomizer = mock(BiConsumer.class); jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build();
jwtEncoder.setJwtCustomizer(jwtCustomizer);
providerSettings = new ProviderSettings() providerSettings = new ProviderSettings()
.authorizationEndpoint("/test/authorize") .authorizationEndpoint("/test/authorize")
.tokenEndpoint("/test/token"); .tokenEndpoint("/test/token");
@ -186,8 +202,17 @@ public class OAuth2AuthorizationCodeGrantTests {
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
assertTokenRequestReturnsAccessTokenResponse( OAuth2AccessTokenResponse accessTokenResponse = assertTokenRequestReturnsAccessTokenResponse(
registeredClient, authorization, OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI); registeredClient, authorization, OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI);
// Assert user authorities was propagated as claim in JWT
Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue());
List<String> authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM);
Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL);
Set<String> userAuthorities = principal.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
} }
@Test @Test
@ -208,10 +233,10 @@ public class OAuth2AuthorizationCodeGrantTests {
registeredClient, authorization, providerSettings.tokenEndpoint()); registeredClient, authorization, providerSettings.tokenEndpoint());
} }
private void assertTokenRequestReturnsAccessTokenResponse(RegisteredClient registeredClient, private OAuth2AccessTokenResponse assertTokenRequestReturnsAccessTokenResponse(RegisteredClient registeredClient,
OAuth2Authorization authorization, String tokenEndpointUri) throws Exception { OAuth2Authorization authorization, String tokenEndpointUri) throws Exception {
this.mvc.perform(post(tokenEndpointUri) MvcResult mvcResult = this.mvc.perform(post(tokenEndpointUri)
.params(getTokenRequestParameters(registeredClient, authorization)) .params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
registeredClient.getClientId(), registeredClient.getClientSecret()))) registeredClient.getClientId(), registeredClient.getClientSecret())))
@ -222,13 +247,19 @@ public class OAuth2AuthorizationCodeGrantTests {
.andExpect(jsonPath("$.token_type").isNotEmpty()) .andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty()) .andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").isNotEmpty()) .andExpect(jsonPath("$.refresh_token").isNotEmpty())
.andExpect(jsonPath("$.scope").isNotEmpty()); .andExpect(jsonPath("$.scope").isNotEmpty())
.andReturn();
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
MockHttpServletResponse servletResponse = mvcResult.getResponse();
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
return accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
} }
@Test @Test
@ -295,8 +326,6 @@ public class OAuth2AuthorizationCodeGrantTests {
.params(getTokenRequestParameters(registeredClient, authorization)) .params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
registeredClient.getClientId(), registeredClient.getClientSecret()))); registeredClient.getClientId(), registeredClient.getClientSecret())));
verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
} }
private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) { private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
@ -345,6 +374,20 @@ public class OAuth2AuthorizationCodeGrantTests {
JWKSource<SecurityContext> jwkSource() { JWKSource<SecurityContext> jwkSource() {
return jwkSource; return jwkSource;
} }
@Bean
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
return context -> {
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
Authentication principal = context.getPrincipal();
Set<String> authorities = principal.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
}
};
}
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -41,6 +41,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
@ -65,6 +67,7 @@ public class OAuth2ClientCredentialsGrantTests {
private static RegisteredClientRepository registeredClientRepository; private static RegisteredClientRepository registeredClientRepository;
private static OAuth2AuthorizationService authorizationService; private static OAuth2AuthorizationService authorizationService;
private static JWKSource<SecurityContext> jwkSource; private static JWKSource<SecurityContext> jwkSource;
private static OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
@Rule @Rule
public final SpringTestRule spring = new SpringTestRule(); public final SpringTestRule spring = new SpringTestRule();
@ -78,10 +81,13 @@ public class OAuth2ClientCredentialsGrantTests {
authorizationService = mock(OAuth2AuthorizationService.class); authorizationService = mock(OAuth2AuthorizationService.class);
JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
jwtCustomizer = mock(OAuth2TokenCustomizer.class);
} }
@SuppressWarnings("unchecked")
@Before @Before
public void setup() { public void setup() {
reset(jwtCustomizer);
reset(registeredClientRepository); reset(registeredClientRepository);
reset(authorizationService); reset(authorizationService);
} }
@ -115,6 +121,7 @@ public class OAuth2ClientCredentialsGrantTests {
.andExpect(jsonPath("$.access_token").isNotEmpty()) .andExpect(jsonPath("$.access_token").isNotEmpty())
.andExpect(jsonPath("$.scope").value("scope1 scope2")); .andExpect(jsonPath("$.scope").value("scope1 scope2"));
verify(jwtCustomizer).customize(any());
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
} }
@ -136,6 +143,7 @@ public class OAuth2ClientCredentialsGrantTests {
.andExpect(jsonPath("$.access_token").isNotEmpty()) .andExpect(jsonPath("$.access_token").isNotEmpty())
.andExpect(jsonPath("$.scope").value("scope1 scope2")); .andExpect(jsonPath("$.scope").value("scope1 scope2"));
verify(jwtCustomizer).customize(any());
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
} }
@ -166,5 +174,10 @@ public class OAuth2ClientCredentialsGrantTests {
JWKSource<SecurityContext> jwkSource() { JWKSource<SecurityContext> jwkSource() {
return jwkSource; return jwkSource;
} }
@Bean
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
return jwtCustomizer;
}
} }
} }

View File

@ -18,6 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
import java.net.URLEncoder; import java.net.URLEncoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
@ -31,24 +34,40 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@ -68,9 +87,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
* @since 0.0.3 * @since 0.0.3
*/ */
public class OAuth2RefreshTokenGrantTests { public class OAuth2RefreshTokenGrantTests {
private static final String AUTHORITIES_CLAIM = "authorities";
private static RegisteredClientRepository registeredClientRepository; private static RegisteredClientRepository registeredClientRepository;
private static OAuth2AuthorizationService authorizationService; private static OAuth2AuthorizationService authorizationService;
private static JWKSource<SecurityContext> jwkSource; private static JWKSource<SecurityContext> jwkSource;
private static NimbusJwtDecoder jwtDecoder;
private static HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
new OAuth2AccessTokenResponseHttpMessageConverter();
@Rule @Rule
public final SpringTestRule spring = new SpringTestRule(); public final SpringTestRule spring = new SpringTestRule();
@ -84,6 +107,7 @@ public class OAuth2RefreshTokenGrantTests {
authorizationService = mock(OAuth2AuthorizationService.class); authorizationService = mock(OAuth2AuthorizationService.class);
JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build();
} }
@Before @Before
@ -106,7 +130,7 @@ public class OAuth2RefreshTokenGrantTests {
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
.params(getRefreshTokenRequestParameters(authorization)) .params(getRefreshTokenRequestParameters(authorization))
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
registeredClient.getClientId(), registeredClient.getClientSecret()))) registeredClient.getClientId(), registeredClient.getClientSecret())))
@ -117,7 +141,8 @@ public class OAuth2RefreshTokenGrantTests {
.andExpect(jsonPath("$.token_type").isNotEmpty()) .andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty()) .andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").isNotEmpty()) .andExpect(jsonPath("$.refresh_token").isNotEmpty())
.andExpect(jsonPath("$.scope").isNotEmpty()); .andExpect(jsonPath("$.scope").isNotEmpty())
.andReturn();
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
@ -125,6 +150,20 @@ public class OAuth2RefreshTokenGrantTests {
eq(TokenType.REFRESH_TOKEN)); eq(TokenType.REFRESH_TOKEN));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
MockHttpServletResponse servletResponse = mvcResult.getResponse();
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(
OAuth2AccessTokenResponse.class, httpResponse);
// Assert user authorities was propagated as claim in JWT
Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue());
List<String> authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM);
Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL);
Set<String> userAuthorities = principal.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
} }
private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) { private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
@ -160,5 +199,18 @@ public class OAuth2RefreshTokenGrantTests {
JWKSource<SecurityContext> jwkSource() { JWKSource<SecurityContext> jwkSource() {
return jwkSource; return jwkSource;
} }
@Bean
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
return context -> {
if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) {
Authentication principal = context.getPrincipal();
Set<String> authorities = principal.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
}
};
}
} }
} }

View File

@ -18,6 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
import java.net.URLEncoder; import java.net.URLEncoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
@ -32,15 +35,28 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@ -48,7 +64,9 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
import org.springframework.security.oauth2.server.authorization.oidc.web.OidcProviderConfigurationEndpointFilter; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcProviderConfigurationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
@ -80,10 +98,14 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
* @author Daniel Garnier-Moiroux * @author Daniel Garnier-Moiroux
*/ */
public class OidcTests { public class OidcTests {
private static final String issuerUrl = "https://example.com/issuer1"; private static final String ISSUER_URL = "https://example.com/issuer1";
private static final String AUTHORITIES_CLAIM = "authorities";
private static RegisteredClientRepository registeredClientRepository; private static RegisteredClientRepository registeredClientRepository;
private static OAuth2AuthorizationService authorizationService; private static OAuth2AuthorizationService authorizationService;
private static JWKSource<SecurityContext> jwkSource; private static JWKSource<SecurityContext> jwkSource;
private static NimbusJwtDecoder jwtDecoder;
private static HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
new OAuth2AccessTokenResponseHttpMessageConverter();
@Rule @Rule
public final SpringTestRule spring = new SpringTestRule(); public final SpringTestRule spring = new SpringTestRule();
@ -97,6 +119,7 @@ public class OidcTests {
authorizationService = mock(OAuth2AuthorizationService.class); authorizationService = mock(OAuth2AuthorizationService.class);
JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build();
} }
@Before @Before
@ -111,7 +134,7 @@ public class OidcTests {
this.mvc.perform(get(OidcProviderConfigurationEndpointFilter.DEFAULT_OIDC_PROVIDER_CONFIGURATION_ENDPOINT_URI)) this.mvc.perform(get(OidcProviderConfigurationEndpointFilter.DEFAULT_OIDC_PROVIDER_CONFIGURATION_ENDPOINT_URI))
.andExpect(status().is2xxSuccessful()) .andExpect(status().is2xxSuccessful())
.andExpect(jsonPath("issuer").value(issuerUrl)); .andExpect(jsonPath("issuer").value(ISSUER_URL));
} }
@Test @Test
@ -148,7 +171,7 @@ public class OidcTests {
MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient)) .params(getAuthorizationRequestParameters(registeredClient))
.with(user("user"))) .with(user("user").roles("A", "B")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
@ -164,7 +187,7 @@ public class OidcTests {
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorization)) .params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
registeredClient.getClientId(), registeredClient.getClientSecret()))) registeredClient.getClientId(), registeredClient.getClientSecret())))
@ -176,13 +199,28 @@ public class OidcTests {
.andExpect(jsonPath("$.expires_in").isNotEmpty()) .andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").isNotEmpty()) .andExpect(jsonPath("$.refresh_token").isNotEmpty())
.andExpect(jsonPath("$.scope").isNotEmpty()) .andExpect(jsonPath("$.scope").isNotEmpty())
.andExpect(jsonPath("$.id_token").isNotEmpty()); .andExpect(jsonPath("$.id_token").isNotEmpty())
.andReturn();
verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService, times(2)).save(any()); verify(authorizationService, times(2)).save(any());
MockHttpServletResponse servletResponse = mvcResult.getResponse();
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
// Assert user authorities was propagated as claim in ID Token
Jwt idToken = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
List<String> authoritiesClaim = idToken.getClaim(AUTHORITIES_CLAIM);
Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL);
Set<String> userAuthorities = principal.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
} }
private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) { private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
@ -231,6 +269,19 @@ public class OidcTests {
JWKSource<SecurityContext> jwkSource() { JWKSource<SecurityContext> jwkSource() {
return jwkSource; return jwkSource;
} }
@Bean
OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
return context -> {
if (context.getTokenType().getValue().equals(OidcParameterNames.ID_TOKEN)) {
Authentication principal = context.getPrincipal();
Set<String> authorities = principal.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
}
};
}
} }
@EnableWebSecurity @EnableWebSecurity
@ -239,7 +290,7 @@ public class OidcTests {
@Bean @Bean
ProviderSettings providerSettings() { ProviderSettings providerSettings() {
return new ProviderSettings().issuer(issuerUrl); return new ProviderSettings().issuer(ISSUER_URL);
} }
} }

View File

@ -21,7 +21,6 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer;
import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.ECKey;
@ -49,7 +48,6 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link NimbusJwsEncoder}. * Tests for {@link NimbusJwsEncoder}.
@ -77,12 +75,6 @@ public class NimbusJwsEncoderTests {
.withMessage("jwkSource cannot be null"); .withMessage("jwkSource cannot be null");
} }
@Test
public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.setJwtCustomizer(null))
.withMessage("jwtCustomizer cannot be null");
}
@Test @Test
public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { public void encodeWhenHeadersNullThenThrowIllegalArgumentException() {
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
@ -99,22 +91,6 @@ public class NimbusJwsEncoderTests {
.withMessage("claims cannot be null"); .withMessage("claims cannot be null");
} }
@Test
public void encodeWhenCustomizerSetThenCalled() {
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
this.jwkList.add(rsaJwk);
BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = mock(BiConsumer.class);
this.jwsEncoder.setJwtCustomizer(jwtCustomizer);
JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
this.jwsEncoder.encode(joseHeader, jwtClaimsSet);
verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
}
@Test @Test
public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception { public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception {
this.jwkSource = mock(JWKSource.class); this.jwkSource = mock(JWKSource.class);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,6 +15,12 @@
*/ */
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Map;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2; import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
@ -24,11 +30,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Map;
/** /**
* @author Joe Grandja * @author Joe Grandja
* @author Daniel Garnier-Moiroux * @author Daniel Garnier-Moiroux
@ -63,6 +64,8 @@ public class TestOAuth2Authorizations {
.principalName("principal") .principalName("principal")
.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build()) .tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build())
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL,
new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"))
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes());
} }
} }

View File

@ -15,10 +15,17 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Set;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@ -28,11 +35,12 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -41,14 +49,10 @@ import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.Assertions.entry;
@ -69,6 +73,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
private static final String AUTHORIZATION_CODE = "code"; private static final String AUTHORIZATION_CODE = "code";
private OAuth2AuthorizationService authorizationService; private OAuth2AuthorizationService authorizationService;
private JwtEncoder jwtEncoder; private JwtEncoder jwtEncoder;
private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
@Before @Before
@ -77,6 +82,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
this.jwtEncoder = mock(JwtEncoder.class); this.jwtEncoder = mock(JwtEncoder.class);
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
this.authorizationService, this.jwtEncoder); this.authorizationService, this.jwtEncoder);
this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
} }
@Test @Test
@ -93,6 +100,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
.hasMessage("jwtEncoder cannot be null"); .hasMessage("jwtEncoder cannot be null");
} }
@Test
public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("jwtCustomizer cannot be null");
}
@Test @Test
public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() { public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue(); assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
@ -225,6 +239,18 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(authorization);
assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(jwtEncodingContext.getHeaders()).isNotNull();
assertThat(jwtEncodingContext.getClaims()).isNotNull();
ArgumentCaptor<JwtClaimsSet> jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class); ArgumentCaptor<JwtClaimsSet> jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class);
verify(this.jwtEncoder).encode(any(), jwtClaimsSetCaptor.capture()); verify(this.jwtEncoder).encode(any(), jwtClaimsSetCaptor.capture());
JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue(); JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue();
@ -264,6 +290,29 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
// Access Token context
JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
assertThat(accessTokenContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(accessTokenContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(accessTokenContext.getHeaders()).isNotNull();
assertThat(accessTokenContext.getClaims()).isNotNull();
// ID Token context
JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(idTokenContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(idTokenContext.getHeaders()).isNotNull();
assertThat(idTokenContext.getClaims()).isNotNull();
verify(this.jwtEncoder, times(2)).encode(any(), any()); // Access token and ID Token verify(this.jwtEncoder, times(2)).encode(any(), any()); // Access token and ID Token
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);

View File

@ -15,28 +15,34 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.Set; import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -53,6 +59,7 @@ import static org.mockito.Mockito.when;
public class OAuth2ClientCredentialsAuthenticationProviderTests { public class OAuth2ClientCredentialsAuthenticationProviderTests {
private OAuth2AuthorizationService authorizationService; private OAuth2AuthorizationService authorizationService;
private JwtEncoder jwtEncoder; private JwtEncoder jwtEncoder;
private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider; private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider;
@Before @Before
@ -61,6 +68,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
this.jwtEncoder = mock(JwtEncoder.class); this.jwtEncoder = mock(JwtEncoder.class);
this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(
this.authorizationService, this.jwtEncoder); this.authorizationService, this.jwtEncoder);
this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
} }
@Test @Test
@ -77,6 +86,13 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
.hasMessage("jwtEncoder cannot be null"); .hasMessage("jwtEncoder cannot be null");
} }
@Test
public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("jwtCustomizer cannot be null");
}
@Test @Test
public void supportsWhenSupportedAuthenticationThenTrue() { public void supportsWhenSupportedAuthenticationThenTrue() {
assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isTrue(); assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isTrue();
@ -152,7 +168,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
OAuth2ClientCredentialsAuthenticationToken authentication = OAuth2ClientCredentialsAuthenticationToken authentication =
new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(requestedScope));
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -165,11 +181,23 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes()));
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(clientPrincipal);
assertThat(jwtEncodingContext.getAuthorization()).isNull();
assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(jwtEncodingContext.getHeaders()).isNotNull();
assertThat(jwtEncodingContext.getClaims()).isNotNull();
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
@ -182,13 +210,14 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
} }
private static Jwt createJwt() { private static Jwt createJwt(Set<String> scope) {
Instant issuedAt = Instant.now(); Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
return Jwt.withTokenValue("token") return Jwt.withTokenValue("token")
.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
.issuedAt(issuedAt) .issuedAt(issuedAt)
.expiresAt(expiresAt) .expiresAt(expiresAt)
.claim(OAuth2ParameterNames.SCOPE, scope)
.build(); .build();
} }
} }

View File

@ -15,20 +15,30 @@
*/ */
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2; import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -36,14 +46,10 @@ import org.springframework.security.oauth2.server.authorization.TestOAuth2Author
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.HashSet;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -56,25 +62,24 @@ import static org.mockito.Mockito.when;
* Tests for {@link OAuth2RefreshTokenAuthenticationProvider}. * Tests for {@link OAuth2RefreshTokenAuthenticationProvider}.
* *
* @author Alexey Nesterov * @author Alexey Nesterov
* @author Joe Grandja
* @since 0.0.3 * @since 0.0.3
*/ */
public class OAuth2RefreshTokenAuthenticationProviderTests { public class OAuth2RefreshTokenAuthenticationProviderTests {
private OAuth2AuthorizationService authorizationService; private OAuth2AuthorizationService authorizationService;
private JwtEncoder jwtEncoder; private JwtEncoder jwtEncoder;
private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
private OAuth2RefreshTokenAuthenticationProvider authenticationProvider; private OAuth2RefreshTokenAuthenticationProvider authenticationProvider;
@Before @Before
public void setUp() { public void setUp() {
this.authorizationService = mock(OAuth2AuthorizationService.class); this.authorizationService = mock(OAuth2AuthorizationService.class);
this.jwtEncoder = mock(JwtEncoder.class); this.jwtEncoder = mock(JwtEncoder.class);
Jwt jwt = Jwt.withTokenValue("refreshed-access-token") when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(Collections.singleton("scope1")));
.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
.issuedAt(Instant.now())
.expiresAt(Instant.now().plus(1, ChronoUnit.HOURS))
.build();
when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt);
this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider( this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider(
this.authorizationService, this.jwtEncoder); this.authorizationService, this.jwtEncoder);
this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
} }
@Test @Test
@ -93,6 +98,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
.isEqualTo("jwtEncoder cannot be null"); .isEqualTo("jwtEncoder cannot be null");
} }
@Test
public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("jwtCustomizer cannot be null");
}
@Test @Test
public void supportsWhenSupportedAuthenticationThenTrue() { public void supportsWhenSupportedAuthenticationThenTrue() {
assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue(); assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue();
@ -119,6 +131,18 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(authorization);
assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(jwtEncodingContext.getHeaders()).isNotNull();
assertThat(jwtEncodingContext.getClaims()).isNotNull();
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
@ -340,4 +364,15 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
.extracting("errorCode") .extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
} }
private static Jwt createJwt(Set<String> scope) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
return Jwt.withTokenValue("refreshed-access-token")
.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.claim(OAuth2ParameterNames.SCOPE, scope)
.build();
}
} }

View File

@ -0,0 +1,118 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import org.junit.Test;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.TestJoseHeaders;
import org.springframework.security.oauth2.jwt.TestJwtClaimsSets;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link JwtEncodingContext}.
*
* @author Joe Grandja
*/
public class JwtEncodingContextTests {
@Test
public void withWhenHeadersNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> JwtEncodingContext.with(null, TestJwtClaimsSets.jwtClaimsSet()))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("headersBuilder cannot be null");
}
@Test
public void withWhenClaimsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> JwtEncodingContext.with(TestJoseHeaders.joseHeader(), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("claimsBuilder cannot be null");
}
@Test
public void setWhenValueNullThenThrowIllegalArgumentException() {
JwtEncodingContext.Builder builder = JwtEncodingContext
.with(TestJoseHeaders.joseHeader(), TestJwtClaimsSets.jwtClaimsSet());
assertThatThrownBy(() -> builder.registeredClient(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> builder.principal(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> builder.authorization(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> builder.tokenType(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> builder.authorizationGrantType(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> builder.authorizationGrant(null))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> builder.put(null, ""))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void buildWhenAllValuesProvidedThenAllValuesAreSet() {
JoseHeader.Builder headers = TestJoseHeaders.joseHeader();
JwtClaimsSet.Builder claims = TestJwtClaimsSets.jwtClaimsSet();
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "password");
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
OAuth2AuthorizationCodeAuthenticationToken authorizationGrant =
new OAuth2AuthorizationCodeAuthenticationToken(
"code", clientPrincipal, authorizationRequest.getRedirectUri(), null);
JwtEncodingContext context = JwtEncodingContext.with(headers, claims)
.registeredClient(registeredClient)
.principal(principal)
.authorization(authorization)
.tokenType(TokenType.ACCESS_TOKEN)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.authorizationGrant(authorizationGrant)
.put("custom-key-1", "custom-value-1")
.context(ctx -> ctx.put("custom-key-2", "custom-value-2"))
.build();
assertThat(context.getHeaders()).isEqualTo(headers);
assertThat(context.getClaims()).isEqualTo(claims);
assertThat(context.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(context.<Authentication>getPrincipal()).isEqualTo(principal);
assertThat(context.getAuthorization()).isEqualTo(authorization);
assertThat(context.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
assertThat(context.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(context.<Authentication>getAuthorizationGrant()).isEqualTo(authorizationGrant);
assertThat(context.<String>get("custom-key-1")).isEqualTo("custom-value-1");
assertThat(context.<String>get("custom-key-2")).isEqualTo("custom-value-2");
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,15 +15,25 @@
*/ */
package org.springframework.security.oauth2.server.authorization.web; package org.springframework.security.oauth2.server.authorization.web;
import java.nio.charset.StandardCharsets;
import java.util.Set;
import java.util.function.Consumer;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -44,13 +54,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.nio.charset.StandardCharsets;
import java.util.Set;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -464,6 +467,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication);
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode).isNotNull(); assertThat(authorizationCode).isNotNull();
@ -511,6 +516,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication);
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode).isNotNull(); assertThat(authorizationCode).isNotNull();
@ -556,6 +563,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication);
String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE); String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE);
assertThat(state).isNotNull(); assertThat(state).isNotNull();