diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index a6f25b6..247a437 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -26,6 +26,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer; import org.springframework.security.crypto.key.CryptoKeySource; import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder; +import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; @@ -166,7 +167,7 @@ public final class OAuth2AuthorizationServerConfigurer> RegisteredClientRepository getRegisteredClientRepository(B builder) { - RegisteredClientRepository registeredClientRepository = builder.getSharedObject(RegisteredClientRepository.class); - if (registeredClientRepository == null) { - registeredClientRepository = getRegisteredClientRepositoryBean(builder); - builder.setSharedObject(RegisteredClientRepository.class, registeredClientRepository); - } - return registeredClientRepository; - } - - private static > RegisteredClientRepository getRegisteredClientRepositoryBean(B builder) { - return builder.getSharedObject(ApplicationContext.class).getBean(RegisteredClientRepository.class); - } - - private static > OAuth2AuthorizationService getAuthorizationService(B builder) { - OAuth2AuthorizationService authorizationService = builder.getSharedObject(OAuth2AuthorizationService.class); - if (authorizationService == null) { - authorizationService = getAuthorizationServiceBean(builder); - if (authorizationService == null) { - authorizationService = new InMemoryOAuth2AuthorizationService(); - } - builder.setSharedObject(OAuth2AuthorizationService.class, authorizationService); - } - return authorizationService; - } - - private static > OAuth2AuthorizationService getAuthorizationServiceBean(B builder) { - Map authorizationServiceMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( - builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizationService.class); - if (authorizationServiceMap.size() > 1) { - throw new NoUniqueBeanDefinitionException(OAuth2AuthorizationService.class, authorizationServiceMap.size(), - "Expected single matching bean of type '" + OAuth2AuthorizationService.class.getName() + "' but found " + - authorizationServiceMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizationServiceMap.keySet())); - } - return (!authorizationServiceMap.isEmpty() ? authorizationServiceMap.values().iterator().next() : null); - } - - private static > CryptoKeySource getKeySource(B builder) { - CryptoKeySource keySource = builder.getSharedObject(CryptoKeySource.class); - if (keySource == null) { - keySource = getKeySourceBean(builder); - builder.setSharedObject(CryptoKeySource.class, keySource); - } - return keySource; - } - - private static > CryptoKeySource getKeySourceBean(B builder) { - return builder.getSharedObject(ApplicationContext.class).getBean(CryptoKeySource.class); - } - - private static > ProviderSettings getProviderSettings(B builder) { - ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class); - if (providerSettings == null) { - providerSettings = getProviderSettingsBean(builder); - if (providerSettings == null) { - providerSettings = new ProviderSettings(); - } - builder.setSharedObject(ProviderSettings.class, providerSettings); - } - return providerSettings; - } - - private static > ProviderSettings getProviderSettingsBean(B builder) { - Map providerSettingsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( - builder.getSharedObject(ApplicationContext.class), ProviderSettings.class); - if (providerSettingsMap.size() > 1) { - throw new NoUniqueBeanDefinitionException(ProviderSettings.class, providerSettingsMap.size(), - "Expected single matching bean of type '" + ProviderSettings.class.getName() + "' but found " + - providerSettingsMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(providerSettingsMap.keySet())); - } - return (!providerSettingsMap.isEmpty() ? providerSettingsMap.values().iterator().next() : null); - } - - private void validateProviderSettings(ProviderSettings providerSettings) { + private static void validateProviderSettings(ProviderSettings providerSettings) { if (providerSettings.issuer() != null) { try { new URI(providerSettings.issuer()).toURL(); @@ -334,4 +263,70 @@ public final class OAuth2AuthorizationServerConfigurer> RegisteredClientRepository getRegisteredClientRepository(B builder) { + RegisteredClientRepository registeredClientRepository = builder.getSharedObject(RegisteredClientRepository.class); + if (registeredClientRepository == null) { + registeredClientRepository = getBean(builder, RegisteredClientRepository.class); + builder.setSharedObject(RegisteredClientRepository.class, registeredClientRepository); + } + return registeredClientRepository; + } + + private static > OAuth2AuthorizationService getAuthorizationService(B builder) { + OAuth2AuthorizationService authorizationService = builder.getSharedObject(OAuth2AuthorizationService.class); + if (authorizationService == null) { + authorizationService = getOptionalBean(builder, OAuth2AuthorizationService.class); + if (authorizationService == null) { + authorizationService = new InMemoryOAuth2AuthorizationService(); + } + builder.setSharedObject(OAuth2AuthorizationService.class, authorizationService); + } + return authorizationService; + } + + private static > JwtEncoder getJwtEncoder(B builder) { + JwtEncoder jwtEncoder = getOptionalBean(builder, JwtEncoder.class); + if (jwtEncoder == null) { + CryptoKeySource keySource = getKeySource(builder); + jwtEncoder = new NimbusJwsEncoder(keySource); + } + return jwtEncoder; + } + + private static > CryptoKeySource getKeySource(B builder) { + CryptoKeySource keySource = builder.getSharedObject(CryptoKeySource.class); + if (keySource == null) { + keySource = getBean(builder, CryptoKeySource.class); + builder.setSharedObject(CryptoKeySource.class, keySource); + } + return keySource; + } + + private static > ProviderSettings getProviderSettings(B builder) { + ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class); + if (providerSettings == null) { + providerSettings = getOptionalBean(builder, ProviderSettings.class); + if (providerSettings == null) { + providerSettings = new ProviderSettings(); + } + builder.setSharedObject(ProviderSettings.class, providerSettings); + } + return providerSettings; + } + + private static , T> T getBean(B builder, Class type) { + return builder.getSharedObject(ApplicationContext.class).getBean(type); + } + + private static , T> T getOptionalBean(B builder, Class type) { + Map beansMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( + builder.getSharedObject(ApplicationContext.class), type); + if (beansMap.size() > 1) { + throw new NoUniqueBeanDefinitionException(type, beansMap.size(), + "Expected single matching bean of type '" + type.getName() + "' but found " + + beansMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(beansMap.keySet())); + } + return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null); + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java index d3d5899..943e24d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java @@ -53,6 +53,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; +import java.util.function.BiConsumer; import java.util.stream.Collectors; /** @@ -94,6 +95,7 @@ public final class NimbusJwsEncoder implements JwtEncoder { private static final Converter jwsHeaderConverter = new JwsHeaderConverter(); private static final Converter jwtClaimsSetConverter = new JwtClaimsSetConverter(); private final CryptoKeySource keySource; + private BiConsumer jwtCustomizer = (headers, claims) -> {}; /** * Constructs a {@code NimbusJwsEncoder} using the provided parameters. @@ -105,6 +107,19 @@ public final class NimbusJwsEncoder implements JwtEncoder { this.keySource = keySource; } + /** + * Sets the {@link Jwt} customizer to be provided the + * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder} + * allowing for further customizations. + * + * @param jwtCustomizer the {@link Jwt} customizer to be provided the + * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder} + */ + public void setJwtCustomizer(BiConsumer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + @Override public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { Assert.notNull(headers, "headers cannot be null"); @@ -136,15 +151,18 @@ public final class NimbusJwsEncoder implements JwtEncoder { } } - headers = JoseHeader.from(headers) + JoseHeader.Builder headersBuilder = JoseHeader.from(headers) .type(JOSEObjectType.JWT.getType()) - .keyId(cryptoKey.getId()) - .build(); - JWSHeader jwsHeader = jwsHeaderConverter.convert(headers); + .keyId(cryptoKey.getId()); + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims) + .id(UUID.randomUUID().toString()); - claims = JwtClaimsSet.from(claims) - .id(UUID.randomUUID().toString()) - .build(); + this.jwtCustomizer.accept(headersBuilder, claimsBuilder); + + headers = headersBuilder.build(); + claims = claimsBuilder.build(); + + JWSHeader jwsHeader = jwsHeaderConverter.convert(headers); JWTClaimsSet jwtClaimsSet = jwtClaimsSetConverter.convert(claims); SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index cf4588a..8d83a8c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -33,6 +33,10 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.jose.JoseHeader; +import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; @@ -53,6 +57,7 @@ import org.springframework.util.StringUtils; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.function.BiConsumer; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; @@ -86,6 +91,8 @@ public class OAuth2AuthorizationCodeGrantTests { private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static CryptoKeySource keySource; + private static NimbusJwsEncoder jwtEncoder; + private static BiConsumer jwtCustomizer; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -98,6 +105,9 @@ public class OAuth2AuthorizationCodeGrantTests { registeredClientRepository = mock(RegisteredClientRepository.class); authorizationService = mock(OAuth2AuthorizationService.class); keySource = new StaticKeyGeneratingCryptoKeySource(); + jwtEncoder = new NimbusJwsEncoder(keySource); + jwtCustomizer = mock(BiConsumer.class); + jwtEncoder.setJwtCustomizer(jwtCustomizer); } @Before @@ -223,6 +233,28 @@ public class OAuth2AuthorizationCodeGrantTests { verify(authorizationService, times(2)).save(any()); } + @Test + public void requestWhenCustomJwtEncoderThenUsed() throws Exception { + this.spring.register(AuthorizationServerConfigurationWithJwtEncoder.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(authorizationService.findByToken( + eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), + eq(TokenType.AUTHORIZATION_CODE))) + .thenReturn(authorization); + + this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorization)) + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret()))); + + verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class)); + } + private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); @@ -270,4 +302,14 @@ public class OAuth2AuthorizationCodeGrantTests { return keySource; } } + + @EnableWebSecurity + @Import(OAuth2AuthorizationServerConfiguration.class) + static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration { + + @Bean + JwtEncoder jwtEncoder() { + return jwtEncoder; + } + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java index 9822e61..dcc9595 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java @@ -32,11 +32,14 @@ import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; import java.security.interfaces.RSAPublicKey; import java.util.Collections; import java.util.LinkedHashSet; +import java.util.function.BiConsumer; import java.util.stream.Collectors; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -61,6 +64,13 @@ public class NimbusJwsEncoderTests { .hasMessage("keySource cannot be null"); } + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.jwtEncoder.setJwtCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtCustomizer cannot be null"); + } + @Test public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); @@ -128,6 +138,24 @@ public class NimbusJwsEncoderTests { jwtDecoder.decode(jws.getTokenValue()); } + @Test + public void encodeWhenCustomizerSetThenCalled() { + AsymmetricKey rsaKey = TestCryptoKeys.rsaKey().build(); + when(this.keySource.getKeys()).thenReturn(Collections.singleton(rsaKey)); + + JoseHeader joseHeader = TestJoseHeaders.joseHeader() + .headers(headers -> headers.remove(JoseHeaderNames.CRIT)) + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + BiConsumer jwtCustomizer = mock(BiConsumer.class); + this.jwtEncoder.setJwtCustomizer(jwtCustomizer); + + this.jwtEncoder.encode(joseHeader, jwtClaimsSet); + + verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class)); + } + @Test public void encodeWhenMultipleActiveKeysThenUseFirst() { AsymmetricKey rsaKey1 = TestCryptoKeys.rsaKey().build();