From edefabdc6b7f531bf3ce7e3974a0e054517a8cc2 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Jul 2020 13:59:55 -0400 Subject: [PATCH] Introduce JwtEncoder with JWS implementation Closes gh-81 --- .../OAuth2AuthorizationServerConfigurer.java | 35 +- .../OAuth2AuthorizationCodeGrantTests.java | 9 + .../OAuth2ClientCredentialsGrantTests.java | 9 + crypto/spring-security-crypto2.gradle | 11 + .../crypto/keys/KeyGeneratorUtils.java | 81 ++++ .../security/crypto/keys/KeyManager.java | 58 +++ .../security/crypto/keys/ManagedKey.java | 246 ++++++++++++ .../keys/StaticKeyGeneratingKeyManager.java | 89 +++++ .../security/crypto/keys/ManagedKeyTests.java | 120 ++++++ .../security/crypto/keys/TestManagedKeys.java | 50 +++ jose/spring-security-oauth2-jose2.gradle | 14 + .../security/oauth2/jose/JoseHeader.java | 368 ++++++++++++++++++ .../security/oauth2/jose/JoseHeaderNames.java | 96 +++++ .../oauth2/jose/jws/NimbusJwsEncoder.java | 325 ++++++++++++++++ .../security/oauth2/jwt/JwtClaimsSet.java | 198 ++++++++++ .../security/oauth2/jwt/JwtEncoder.java | 56 +++ .../oauth2/jwt/JwtEncodingException.java | 46 +++ .../security/oauth2/jose/JoseHeaderTests.java | 107 +++++ .../security/oauth2/jose/TestJoseHeaders.java | 57 +++ .../jose/jws/NimbusJwsEncoderTests.java | 159 ++++++++ .../oauth2/jwt/JwtClaimsSetTests.java | 90 +++++ .../oauth2/jwt/TestJwtClaimsSets.java | 50 +++ ...ecurity-oauth2-authorization-server.gradle | 1 + .../OAuth2AuthorizationAttributeNames.java | 6 + ...thorizationCodeAuthenticationProvider.java | 48 ++- ...ientCredentialsAuthenticationProvider.java | 50 ++- ...zationCodeAuthenticationProviderTests.java | 32 +- ...redentialsAuthenticationProviderTests.java | 36 +- 28 files changed, 2424 insertions(+), 23 deletions(-) create mode 100644 crypto/spring-security-crypto2.gradle create mode 100644 crypto/src/main/java/org/springframework/security/crypto/keys/KeyGeneratorUtils.java create mode 100644 crypto/src/main/java/org/springframework/security/crypto/keys/KeyManager.java create mode 100644 crypto/src/main/java/org/springframework/security/crypto/keys/ManagedKey.java create mode 100644 crypto/src/main/java/org/springframework/security/crypto/keys/StaticKeyGeneratingKeyManager.java create mode 100644 crypto/src/test/java/org/springframework/security/crypto/keys/ManagedKeyTests.java create mode 100644 crypto/src/test/java/org/springframework/security/crypto/keys/TestManagedKeys.java create mode 100644 jose/spring-security-oauth2-jose2.gradle create mode 100644 jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeader.java create mode 100644 jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeaderNames.java create mode 100644 jose/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java create mode 100644 jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java create mode 100644 jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java create mode 100644 jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java create mode 100644 jose/src/test/java/org/springframework/security/oauth2/jose/JoseHeaderTests.java create mode 100644 jose/src/test/java/org/springframework/security/oauth2/jose/TestJoseHeaders.java create mode 100644 jose/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java create mode 100644 jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java create mode 100644 jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index 1e01e87..3ceb350 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -24,6 +24,8 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; @@ -81,6 +83,18 @@ public final class OAuth2AuthorizationServerConfigurer keyManager(KeyManager keyManager) { + Assert.notNull(keyManager, "keyManager cannot be null"); + this.getBuilder().setSharedObject(KeyManager.class, keyManager); + return this; + } + @Override public void init(B builder) { OAuth2ClientAuthenticationProvider clientAuthenticationProvider = @@ -88,15 +102,19 @@ public final class OAuth2AuthorizationServerConfigurer exceptionHandling = builder.getConfigurer(ExceptionHandlingConfigurer.class); @@ -168,4 +186,17 @@ public final class OAuth2AuthorizationServerConfigurer> KeyManager getKeyManager(B builder) { + KeyManager keyManager = builder.getSharedObject(KeyManager.class); + if (keyManager == null) { + keyManager = getKeyManagerBean(builder); + builder.setSharedObject(KeyManager.class, keyManager); + } + return keyManager; + } + + private static > KeyManager getKeyManagerBean(B builder) { + return builder.getSharedObject(ApplicationContext.class).getBean(KeyManager.class); + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 6207ad0..32ff7a0 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -26,6 +26,8 @@ import org.springframework.http.HttpHeaders; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -73,6 +75,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. public class OAuth2AuthorizationCodeGrantTests { private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; + private static KeyManager keyManager; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -84,6 +87,7 @@ public class OAuth2AuthorizationCodeGrantTests { public static void init() { registeredClientRepository = mock(RegisteredClientRepository.class); authorizationService = mock(OAuth2AuthorizationService.class); + keyManager = new StaticKeyGeneratingKeyManager(); } @Before @@ -200,5 +204,10 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2AuthorizationService authorizationService() { return authorizationService; } + + @Bean + KeyManager keyManager() { + return keyManager; + } } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java index 97d039a..8c4c867 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java @@ -26,6 +26,8 @@ import org.springframework.http.HttpHeaders; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -60,6 +62,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. public class OAuth2ClientCredentialsGrantTests { private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; + private static KeyManager keyManager; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -71,6 +74,7 @@ public class OAuth2ClientCredentialsGrantTests { public static void init() { registeredClientRepository = mock(RegisteredClientRepository.class); authorizationService = mock(OAuth2AuthorizationService.class); + keyManager = new StaticKeyGeneratingKeyManager(); } @Before @@ -135,5 +139,10 @@ public class OAuth2ClientCredentialsGrantTests { OAuth2AuthorizationService authorizationService() { return authorizationService; } + + @Bean + KeyManager keyManager() { + return keyManager; + } } } diff --git a/crypto/spring-security-crypto2.gradle b/crypto/spring-security-crypto2.gradle new file mode 100644 index 0000000..e95743d --- /dev/null +++ b/crypto/spring-security-crypto2.gradle @@ -0,0 +1,11 @@ +apply plugin: 'io.spring.convention.spring-module' + +dependencies { + compile project(':spring-security-core2') + compile 'org.springframework.security:spring-security-core' + compile springCoreDependency + + testCompile 'junit:junit' + testCompile 'org.assertj:assertj-core' + testCompile 'org.mockito:mockito-core' +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/keys/KeyGeneratorUtils.java b/crypto/src/main/java/org/springframework/security/crypto/keys/KeyGeneratorUtils.java new file mode 100644 index 0000000..8822acb --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/keys/KeyGeneratorUtils.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.crypto.keys; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.spec.ECFieldFp; +import java.security.spec.ECParameterSpec; +import java.security.spec.ECPoint; +import java.security.spec.EllipticCurve; + +/** + * @author Joe Grandja + * @since 0.0.1 + */ +final class KeyGeneratorUtils { + + static SecretKey generateSecretKey() { + SecretKey hmacKey; + try { + hmacKey = KeyGenerator.getInstance("HmacSha256").generateKey(); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + return hmacKey; + } + + static KeyPair generateRsaKey() { + KeyPair keyPair; + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + keyPair = keyPairGenerator.generateKeyPair(); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + return keyPair; + } + + static KeyPair generateEcKey() { + EllipticCurve ellipticCurve = new EllipticCurve( + new ECFieldFp( + new BigInteger("115792089210356248762697446949407573530086143415290314195533631308867097853951")), + new BigInteger("115792089210356248762697446949407573530086143415290314195533631308867097853948"), + new BigInteger("41058363725152142129326129780047268409114441015993725554835256314039467401291")); + ECPoint ecPoint = new ECPoint( + new BigInteger("48439561293906451759052585252797914202762949526041747995844080717082404635286"), + new BigInteger("36134250956749795798585127919587881956611106672985015071877198253568414405109")); + ECParameterSpec ecParameterSpec = new ECParameterSpec( + ellipticCurve, + ecPoint, + new BigInteger("115792089210356248762697446949407573529996955224135760342422259061068512044369"), + 1); + + KeyPair keyPair; + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC"); + keyPairGenerator.initialize(ecParameterSpec); + keyPair = keyPairGenerator.generateKeyPair(); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + return keyPair; + } +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/keys/KeyManager.java b/crypto/src/main/java/org/springframework/security/crypto/keys/KeyManager.java new file mode 100644 index 0000000..5d16f62 --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/keys/KeyManager.java @@ -0,0 +1,58 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.crypto.keys; + +import org.springframework.lang.Nullable; + +import java.util.Set; + +/** + * Implementations of this interface are responsible for the management of {@link ManagedKey}(s), + * e.g. {@code javax.crypto.SecretKey}, {@code java.security.PrivateKey}, {@code java.security.PublicKey}, etc. + * + * @author Joe Grandja + * @since 0.0.1 + * @see ManagedKey + */ +public interface KeyManager { + + /** + * Returns the {@link ManagedKey} identified by the provided {@code keyId}, + * or {@code null} if not found. + * + * @param keyId the key ID + * @return the {@link ManagedKey}, or {@code null} if not found + */ + @Nullable + ManagedKey findByKeyId(String keyId); + + /** + * Returns a {@code Set} of {@link ManagedKey}(s) having the provided key {@code algorithm}, + * or an empty {@code Set} if not found. + * + * @param algorithm the key algorithm + * @return a {@code Set} of {@link ManagedKey}(s), or an empty {@code Set} if not found + */ + Set findByAlgorithm(String algorithm); + + /** + * Returns a {@code Set} of the {@link ManagedKey}(s). + * + * @return a {@code Set} of the {@link ManagedKey}(s) + */ + Set getKeys(); + +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/keys/ManagedKey.java b/crypto/src/main/java/org/springframework/security/crypto/keys/ManagedKey.java new file mode 100644 index 0000000..0b6e96b --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/keys/ManagedKey.java @@ -0,0 +1,246 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.crypto.keys; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.SpringSecurityCoreVersion2; +import org.springframework.util.Assert; + +import javax.crypto.SecretKey; +import java.io.Serializable; +import java.security.Key; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.time.Instant; +import java.util.Objects; + +/** + * A {@code java.security.Key} that is managed by a {@link KeyManager}. + * + * @author Joe Grandja + * @since 0.0.1 + * @see KeyManager + */ +public final class ManagedKey implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID; + private Key key; + private PublicKey publicKey; + private String keyId; + private Instant activatedOn; + private Instant deactivatedOn; + + private ManagedKey() { + } + + /** + * Returns {@code true} if this is a symmetric key, {@code false} otherwise. + * + * @return {@code true} if this is a symmetric key, {@code false} otherwise + */ + public boolean isSymmetric() { + return SecretKey.class.isAssignableFrom(this.key.getClass()); + } + + /** + * Returns {@code true} if this is a asymmetric key, {@code false} otherwise. + * + * @return {@code true} if this is a asymmetric key, {@code false} otherwise + */ + public boolean isAsymmetric() { + return PrivateKey.class.isAssignableFrom(this.key.getClass()); + } + + /** + * Returns a type of {@code java.security.Key}, + * e.g. {@code javax.crypto.SecretKey} or {@code java.security.PrivateKey}. + * + * @param the type of {@code java.security.Key} + * @return the type of {@code java.security.Key} + */ + @SuppressWarnings("unchecked") + public T getKey() { + return (T) this.key; + } + + /** + * Returns the {@code java.security.PublicKey} if this is a asymmetric key, {@code null} otherwise. + * + * @return the {@code java.security.PublicKey} if this is a asymmetric key, {@code null} otherwise + */ + @Nullable + public PublicKey getPublicKey() { + return this.publicKey; + } + + /** + * Returns the key ID. + * + * @return the key ID + */ + public String getKeyId() { + return this.keyId; + } + + /** + * Returns the time when this key was activated. + * + * @return the time when this key was activated + */ + public Instant getActivatedOn() { + return this.activatedOn; + } + + /** + * Returns the time when this key was deactivated, {@code null} if still active. + * + * @return the time when this key was deactivated, {@code null} if still active + */ + @Nullable + public Instant getDeactivatedOn() { + return this.deactivatedOn; + } + + /** + * Returns {@code true} if this key is active, {@code false} otherwise. + * + * @return {@code true} if this key is active, {@code false} otherwise + */ + public boolean isActive() { + return getDeactivatedOn() == null; + } + + /** + * Returns the key algorithm. + * + * @return the key algorithm + */ + public String getAlgorithm() { + return this.key.getAlgorithm(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ManagedKey that = (ManagedKey) obj; + return Objects.equals(this.keyId, that.keyId); + } + + @Override + public int hashCode() { + return Objects.hash(this.keyId); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@code javax.crypto.SecretKey}. + * + * @param secretKey the {@code javax.crypto.SecretKey} + * @return the {@link Builder} + */ + public static Builder withSymmetricKey(SecretKey secretKey) { + return new Builder(secretKey); + } + + /** + * Returns a new {@link Builder}, initialized with the provided + * {@code java.security.PublicKey} and {@code java.security.PrivateKey}. + * + * @param publicKey the {@code java.security.PublicKey} + * @param privateKey the {@code java.security.PrivateKey} + * @return the {@link Builder} + */ + public static Builder withAsymmetricKey(PublicKey publicKey, PrivateKey privateKey) { + return new Builder(publicKey, privateKey); + } + + /** + * A builder for {@link ManagedKey}. + */ + public static class Builder { + private Key key; + private PublicKey publicKey; + private String keyId; + private Instant activatedOn; + private Instant deactivatedOn; + + private Builder(SecretKey secretKey) { + Assert.notNull(secretKey, "secretKey cannot be null"); + this.key = secretKey; + } + + private Builder(PublicKey publicKey, PrivateKey privateKey) { + Assert.notNull(publicKey, "publicKey cannot be null"); + Assert.notNull(privateKey, "privateKey cannot be null"); + this.key = privateKey; + this.publicKey = publicKey; + } + + /** + * Sets the key ID. + * + * @param keyId the key ID + * @return the {@link Builder} + */ + public Builder keyId(String keyId) { + this.keyId = keyId; + return this; + } + + /** + * Sets the time when this key was activated. + * + * @param activatedOn the time when this key was activated + * @return the {@link Builder} + */ + public Builder activatedOn(Instant activatedOn) { + this.activatedOn = activatedOn; + return this; + } + + /** + * Sets the time when this key was deactivated. + * + * @param deactivatedOn the time when this key was deactivated + * @return the {@link Builder} + */ + public Builder deactivatedOn(Instant deactivatedOn) { + this.deactivatedOn = deactivatedOn; + return this; + } + + /** + * Builds a new {@link ManagedKey}. + * + * @return a {@link ManagedKey} + */ + public ManagedKey build() { + Assert.hasText(this.keyId, "keyId cannot be empty"); + Assert.notNull(this.activatedOn, "activatedOn cannot be null"); + + ManagedKey managedKey = new ManagedKey(); + managedKey.key = this.key; + managedKey.publicKey = this.publicKey; + managedKey.keyId = this.keyId; + managedKey.activatedOn = this.activatedOn; + managedKey.deactivatedOn = this.deactivatedOn; + return managedKey; + } + } +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/keys/StaticKeyGeneratingKeyManager.java b/crypto/src/main/java/org/springframework/security/crypto/keys/StaticKeyGeneratingKeyManager.java new file mode 100644 index 0000000..23ec759 --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/keys/StaticKeyGeneratingKeyManager.java @@ -0,0 +1,89 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.crypto.keys; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import javax.crypto.SecretKey; +import java.security.KeyPair; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateRsaKey; +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateSecretKey; + +/** + * An implementation of a {@link KeyManager} that generates the {@link ManagedKey}(s) when constructed. + * + *

+ * NOTE: This implementation should ONLY be used during development/testing. + * + * @author Joe Grandja + * @since 0.0.1 + * @see KeyManager + */ +public final class StaticKeyGeneratingKeyManager implements KeyManager { + private final Map keys; + + public StaticKeyGeneratingKeyManager() { + this.keys = Collections.unmodifiableMap(new HashMap<>(generateKeys())); + } + + @Nullable + @Override + public ManagedKey findByKeyId(String keyId) { + Assert.hasText(keyId, "keyId cannot be empty"); + return this.keys.get(keyId); + } + + @Override + public Set findByAlgorithm(String algorithm) { + Assert.hasText(algorithm, "algorithm cannot be empty"); + return this.keys.values().stream() + .filter(managedKey -> managedKey.getAlgorithm().equals(algorithm)) + .collect(Collectors.toSet()); + } + + @Override + public Set getKeys() { + return new HashSet<>(this.keys.values()); + } + + private static Map generateKeys() { + KeyPair rsaKeyPair = generateRsaKey(); + ManagedKey rsaManagedKey = ManagedKey.withAsymmetricKey(rsaKeyPair.getPublic(), rsaKeyPair.getPrivate()) + .keyId(UUID.randomUUID().toString()) + .activatedOn(Instant.now()) + .build(); + + SecretKey hmacKey = generateSecretKey(); + ManagedKey secretManagedKey = ManagedKey.withSymmetricKey(hmacKey) + .keyId(UUID.randomUUID().toString()) + .activatedOn(Instant.now()) + .build(); + + return Stream.of(rsaManagedKey, secretManagedKey) + .collect(Collectors.toMap(ManagedKey::getKeyId, v -> v)); + } +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/keys/ManagedKeyTests.java b/crypto/src/test/java/org/springframework/security/crypto/keys/ManagedKeyTests.java new file mode 100644 index 0000000..100cc04 --- /dev/null +++ b/crypto/src/test/java/org/springframework/security/crypto/keys/ManagedKeyTests.java @@ -0,0 +1,120 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.crypto.keys; + +import org.junit.BeforeClass; +import org.junit.Test; + +import javax.crypto.SecretKey; +import java.security.Key; +import java.security.KeyPair; +import java.security.PrivateKey; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateRsaKey; +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateSecretKey; + +/** + * Tests for {@link ManagedKey}. + * + * @author Joe Grandja + */ +public class ManagedKeyTests { + private static SecretKey secretKey; + private static KeyPair rsaKeyPair; + + @BeforeClass + public static void init() { + secretKey = generateSecretKey(); + rsaKeyPair = generateRsaKey(); + } + + @Test + public void withSymmetricKeyWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ManagedKey.withSymmetricKey(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("secretKey cannot be null"); + } + + @Test + public void buildWhenKeyIdNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ManagedKey.withSymmetricKey(secretKey).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("keyId cannot be empty"); + } + + @Test + public void buildWhenActivatedOnNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ManagedKey.withSymmetricKey(secretKey).keyId("keyId").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("activatedOn cannot be null"); + } + + @Test + public void buildWhenSymmetricKeyAllAttributesProvidedThenAllAttributesAreSet() { + ManagedKey expectedManagedKey = TestManagedKeys.secretManagedKey().build(); + + ManagedKey managedKey = ManagedKey.withSymmetricKey(expectedManagedKey.getKey()) + .keyId(expectedManagedKey.getKeyId()) + .activatedOn(expectedManagedKey.getActivatedOn()) + .build(); + + assertThat(managedKey.isSymmetric()).isTrue(); + assertThat(managedKey.getKey()).isInstanceOf(SecretKey.class); + assertThat(managedKey.getKey()).isEqualTo(expectedManagedKey.getKey()); + assertThat(managedKey.getPublicKey()).isNull(); + assertThat(managedKey.getKeyId()).isEqualTo(expectedManagedKey.getKeyId()); + assertThat(managedKey.getActivatedOn()).isEqualTo(expectedManagedKey.getActivatedOn()); + assertThat(managedKey.getDeactivatedOn()).isEqualTo(expectedManagedKey.getDeactivatedOn()); + assertThat(managedKey.isActive()).isTrue(); + assertThat(managedKey.getAlgorithm()).isEqualTo(expectedManagedKey.getAlgorithm()); + } + + @Test + public void withAsymmetricKeyWhenPublicKeyNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ManagedKey.withAsymmetricKey(null, rsaKeyPair.getPrivate())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("publicKey cannot be null"); + } + + @Test + public void withAsymmetricKeyWhenPrivateKeyNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ManagedKey.withAsymmetricKey(rsaKeyPair.getPublic(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("privateKey cannot be null"); + } + + @Test + public void buildWhenAsymmetricKeyAllAttributesProvidedThenAllAttributesAreSet() { + ManagedKey expectedManagedKey = TestManagedKeys.rsaManagedKey().build(); + + ManagedKey managedKey = ManagedKey.withAsymmetricKey(expectedManagedKey.getPublicKey(), expectedManagedKey.getKey()) + .keyId(expectedManagedKey.getKeyId()) + .activatedOn(expectedManagedKey.getActivatedOn()) + .build(); + + assertThat(managedKey.isAsymmetric()).isTrue(); + assertThat(managedKey.getKey()).isInstanceOf(PrivateKey.class); + assertThat(managedKey.getKey()).isEqualTo(expectedManagedKey.getKey()); + assertThat(managedKey.getPublicKey()).isNotNull(); + assertThat(managedKey.getKeyId()).isEqualTo(expectedManagedKey.getKeyId()); + assertThat(managedKey.getActivatedOn()).isEqualTo(expectedManagedKey.getActivatedOn()); + assertThat(managedKey.getDeactivatedOn()).isEqualTo(expectedManagedKey.getDeactivatedOn()); + assertThat(managedKey.isActive()).isTrue(); + assertThat(managedKey.getAlgorithm()).isEqualTo(expectedManagedKey.getAlgorithm()); + } +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/keys/TestManagedKeys.java b/crypto/src/test/java/org/springframework/security/crypto/keys/TestManagedKeys.java new file mode 100644 index 0000000..4ab92cf --- /dev/null +++ b/crypto/src/test/java/org/springframework/security/crypto/keys/TestManagedKeys.java @@ -0,0 +1,50 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.crypto.keys; + +import java.security.KeyPair; +import java.time.Instant; +import java.util.UUID; + +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateEcKey; +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateRsaKey; +import static org.springframework.security.crypto.keys.KeyGeneratorUtils.generateSecretKey; + +/** + * @author Joe Grandja + */ +public class TestManagedKeys { + + public static ManagedKey.Builder secretManagedKey() { + return ManagedKey.withSymmetricKey(generateSecretKey()) + .keyId(UUID.randomUUID().toString()) + .activatedOn(Instant.now()); + } + + public static ManagedKey.Builder rsaManagedKey() { + KeyPair rsaKeyPair = generateRsaKey(); + return ManagedKey.withAsymmetricKey(rsaKeyPair.getPublic(), rsaKeyPair.getPrivate()) + .keyId(UUID.randomUUID().toString()) + .activatedOn(Instant.now()); + } + + public static ManagedKey.Builder ecManagedKey() { + KeyPair ecKeyPair = generateEcKey(); + return ManagedKey.withAsymmetricKey(ecKeyPair.getPublic(), ecKeyPair.getPrivate()) + .keyId(UUID.randomUUID().toString()) + .activatedOn(Instant.now()); + } +} diff --git a/jose/spring-security-oauth2-jose2.gradle b/jose/spring-security-oauth2-jose2.gradle new file mode 100644 index 0000000..6ab2ad0 --- /dev/null +++ b/jose/spring-security-oauth2-jose2.gradle @@ -0,0 +1,14 @@ +apply plugin: 'io.spring.convention.spring-module' + +dependencies { + compile project(':spring-security-crypto2') + compile 'org.springframework.security:spring-security-oauth2-core' + compile 'org.springframework.security:spring-security-oauth2-jose' + compile springCoreDependency + compile 'com.nimbusds:nimbus-jose-jwt' + + testCompile project(path: ':spring-security-crypto2', configuration: 'tests') + testCompile 'junit:junit' + testCompile 'org.assertj:assertj-core' + testCompile 'org.mockito:mockito-core' +} diff --git a/jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeader.java b/jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeader.java new file mode 100644 index 0000000..993f861 --- /dev/null +++ b/jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeader.java @@ -0,0 +1,368 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jose; + +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.Assert; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.springframework.security.oauth2.jose.JoseHeaderNames.ALG; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.CRIT; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.CTY; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.JKU; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.JWK; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.KID; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.TYP; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.X5C; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.X5T; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.X5T_S256; +import static org.springframework.security.oauth2.jose.JoseHeaderNames.X5U; + +/** + * The JOSE header is a JSON object representing the header parameters of a JSON Web Token, + * whether the JWT is a JWS or JWE, that describe the cryptographic operations applied to the JWT + * and optionally, additional properties of the JWT. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 0.0.1 + * @see Jwt + * @see JWT JOSE Header + * @see JWS JOSE Header + * @see JWE JOSE Header + */ +public final class JoseHeader { + private final Map headers; + + private JoseHeader(Map headers) { + this.headers = Collections.unmodifiableMap(new LinkedHashMap<>(headers)); + } + + /** + * Returns the JWS algorithm used to digitally sign the JWS. + * + * @return the JWS algorithm + */ + public JwsAlgorithm getJwsAlgorithm() { + return getHeader(ALG); + } + + /** + * Returns the JWK Set URL that refers to the resource of a set of JSON-encoded public keys, + * one of which corresponds to the key used to digitally sign the JWS or encrypt the JWE. + * + * @return the JWK Set URL + */ + public String getJwkSetUri() { + return getHeader(JKU); + } + + /** + * Returns the JSON Web Key which is the public key that corresponds to the key + * used to digitally sign the JWS or encrypt the JWE. + * + * @return the JSON Web Key + */ + public Map getJwk() { + return getHeader(JWK); + } + + /** + * Returns the key ID that is a hint indicating which key was used to secure the JWS or JWE. + * + * @return the key ID + */ + public String getKeyId() { + return getHeader(KID); + } + + /** + * Returns the X.509 URL that refers to the resource for the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @return the X.509 URL + */ + public String getX509Uri() { + return getHeader(X5U); + } + + /** + * Returns the X.509 certificate chain that contains the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @return the X.509 certificate chain + */ + public List getX509CertificateChain() { + return getHeader(X5C); + } + + /** + * Returns the X.509 certificate SHA-1 thumbprint that is a base64url-encoded SHA-1 thumbprint (a.k.a. digest) + * of the DER encoding of the X.509 certificate corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @return the X.509 certificate SHA-1 thumbprint + */ + public String getX509SHA1Thumbprint() { + return getHeader(X5T); + } + + /** + * Returns the X.509 certificate SHA-256 thumbprint that is a base64url-encoded SHA-256 thumbprint (a.k.a. digest) + * of the DER encoding of the X.509 certificate corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @return the X.509 certificate SHA-256 thumbprint + */ + public String getX509SHA256Thumbprint() { + return getHeader(X5T_S256); + } + + /** + * Returns the critical headers that indicates which extensions to the JWS/JWE/JWA specifications + * are being used that MUST be understood and processed. + * + * @return the critical headers + */ + public Set getCritical() { + return getHeader(CRIT); + } + + /** + * Returns the type header that declares the media type of the JWS/JWE. + * + * @return the type header + */ + public String getType() { + return getHeader(TYP); + } + + /** + * Returns the content type header that declares the media type of the secured content (the payload). + * + * @return the content type header + */ + public String getContentType() { + return getHeader(CTY); + } + + /** + * Returns the headers. + * + * @return the headers + */ + public Map getHeaders() { + return this.headers; + } + + /** + * Returns the header value. + * + * @param name the header name + * @param the type of the header value + * @return the header value + */ + @SuppressWarnings("unchecked") + public T getHeader(String name) { + Assert.hasText(name, "name cannot be empty"); + return (T) getHeaders().get(name); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@link JwsAlgorithm}. + * + * @param jwsAlgorithm the {@link JwsAlgorithm} + * @return the {@link Builder} + */ + public static Builder withAlgorithm(JwsAlgorithm jwsAlgorithm) { + return new Builder(jwsAlgorithm); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@code headers}. + * + * @param headers the headers + * @return the {@link Builder} + */ + public static Builder from(JoseHeader headers) { + return new Builder(headers); + } + + /** + * A builder for {@link JoseHeader}. + */ + public static class Builder { + private final Map headers = new LinkedHashMap<>(); + + private Builder(JwsAlgorithm jwsAlgorithm) { + Assert.notNull(jwsAlgorithm, "jwsAlgorithm cannot be null"); + header(ALG, jwsAlgorithm); + } + + private Builder(JoseHeader headers) { + Assert.notNull(headers, "headers cannot be null"); + this.headers.putAll(headers.getHeaders()); + } + + /** + * Sets the JWK Set URL that refers to the resource of a set of JSON-encoded public keys, + * one of which corresponds to the key used to digitally sign the JWS or encrypt the JWE. + * + * @param jwkSetUri the JWK Set URL + * @return the {@link Builder} + */ + public Builder jwkSetUri(String jwkSetUri) { + return header(JKU, jwkSetUri); + } + + /** + * Sets the JSON Web Key which is the public key that corresponds to the key + * used to digitally sign the JWS or encrypt the JWE. + * + * @param jwk the JSON Web Key + * @return the {@link Builder} + */ + public Builder jwk(Map jwk) { + return header(JWK, jwk); + } + + /** + * Sets the key ID that is a hint indicating which key was used to secure the JWS or JWE. + * + * @param keyId the key ID + * @return the {@link Builder} + */ + public Builder keyId(String keyId) { + return header(KID, keyId); + } + + /** + * Sets the X.509 URL that refers to the resource for the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @param x509Uri the X.509 URL + * @return the {@link Builder} + */ + public Builder x509Uri(String x509Uri) { + return header(X5U, x509Uri); + } + + /** + * Sets the X.509 certificate chain that contains the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @param x509CertificateChain the X.509 certificate chain + * @return the {@link Builder} + */ + public Builder x509CertificateChain(List x509CertificateChain) { + return header(X5C, x509CertificateChain); + } + + /** + * Sets the X.509 certificate SHA-1 thumbprint that is a base64url-encoded SHA-1 thumbprint (a.k.a. digest) + * of the DER encoding of the X.509 certificate corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @param x509SHA1Thumbprint the X.509 certificate SHA-1 thumbprint + * @return the {@link Builder} + */ + public Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { + return header(X5T, x509SHA1Thumbprint); + } + + /** + * Sets the X.509 certificate SHA-256 thumbprint that is a base64url-encoded SHA-256 thumbprint (a.k.a. digest) + * of the DER encoding of the X.509 certificate corresponding to the key used to digitally sign the JWS or encrypt the JWE. + * + * @param x509SHA256Thumbprint the X.509 certificate SHA-256 thumbprint + * @return the {@link Builder} + */ + public Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) { + return header(X5T_S256, x509SHA256Thumbprint); + } + + /** + * Sets the critical headers that indicates which extensions to the JWS/JWE/JWA specifications + * are being used that MUST be understood and processed. + * + * @param headerNames the critical header names + * @return the {@link Builder} + */ + public Builder critical(Set headerNames) { + return header(CRIT, headerNames); + } + + /** + * Sets the type header that declares the media type of the JWS/JWE. + * + * @param type the type header + * @return the {@link Builder} + */ + public Builder type(String type) { + return header(TYP, type); + } + + /** + * Sets the content type header that declares the media type of the secured content (the payload). + * + * @param contentType the content type header + * @return the {@link Builder} + */ + public Builder contentType(String contentType) { + return header(CTY, contentType); + } + + /** + * Sets the header. + * + * @param name the header name + * @param value the header value + * @return the {@link Builder} + */ + public Builder header(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.headers.put(name, value); + return this; + } + + /** + * A {@code Consumer} to be provided access to the headers + * allowing the ability to add, replace, or remove. + * + * @param headersConsumer a {@code Consumer} of the headers + * @return the {@link Builder} + */ + public Builder headers(Consumer> headersConsumer) { + headersConsumer.accept(this.headers); + return this; + } + + /** + * Builds a new {@link JoseHeader}. + * + * @return a {@link JoseHeader} + */ + public JoseHeader build() { + Assert.notEmpty(this.headers, "headers cannot be empty"); + return new JoseHeader(this.headers); + } + } +} diff --git a/jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeaderNames.java b/jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeaderNames.java new file mode 100644 index 0000000..d259d99 --- /dev/null +++ b/jose/src/main/java/org/springframework/security/oauth2/jose/JoseHeaderNames.java @@ -0,0 +1,96 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jose; + +/** + * The Registered Header Parameter Names defined by the JSON Web Token (JWT), + * JSON Web Signature (JWS) and JSON Web Encryption (JWE) specifications + * that may be contained in the JOSE Header of a JWT. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 0.0.1 + * @see JoseHeader + * @see JWT JOSE Header + * @see JWS JOSE Header + * @see JWE JOSE Header + */ +public interface JoseHeaderNames { + + /** + * {@code alg} - the algorithm header identifies the cryptographic algorithm used to secure a JWS or JWE + */ + String ALG = "alg"; + + /** + * {@code jku} - the JWK Set URL header is a URI that refers to a resource for a set of JSON-encoded public keys, + * one of which corresponds to the key used to digitally sign a JWS or encrypt a JWE + */ + String JKU = "jku"; + + /** + * {@code jwk} - the JSON Web Key header is the public key that corresponds to the key + * used to digitally sign a JWS or encrypt a JWE + */ + String JWK = "jwk"; + + /** + * {@code kid} - the key ID header is a hint indicating which key was used to secure a JWS or JWE + */ + String KID = "kid"; + + /** + * {@code x5u} - the X.509 URL header is a URI that refers to a resource for the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign a JWS or encrypt a JWE + */ + String X5U = "x5u"; + + /** + * {@code x5c} - the X.509 certificate chain header contains the X.509 public key certificate + * or certificate chain corresponding to the key used to digitally sign a JWS or encrypt a JWE + */ + String X5C = "x5c"; + + /** + * {@code x5t} - the X.509 certificate SHA-1 thumbprint header is a base64url-encoded SHA-1 thumbprint (a.k.a. digest) + * of the DER encoding of the X.509 certificate corresponding to the key used to digitally sign a JWS or encrypt a JWE + */ + String X5T = "x5t"; + + /** + * {@code x5t#S256} - the X.509 certificate SHA-256 thumbprint header is a base64url-encoded SHA-256 thumbprint (a.k.a. digest) + * of the DER encoding of the X.509 certificate corresponding to the key used to digitally sign a JWS or encrypt a JWE + */ + String X5T_S256 = "x5t#S256"; + + /** + * {@code typ} - the type header is used by JWS/JWE applications to declare the media type of a JWS/JWE + */ + String TYP = "typ"; + + /** + * {@code cty} - the content type header is used by JWS/JWE applications to declare the media type + * of the secured content (the payload) + */ + String CTY = "cty"; + + /** + * {@code crit} - the critical header indicates that extensions to the JWS/JWE/JWA specifications + * are being used that MUST be understood and processed + */ + String CRIT = "crit"; + +} diff --git a/jose/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java b/jose/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java new file mode 100644 index 0000000..2eb6c97 --- /dev/null +++ b/jose/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java @@ -0,0 +1,325 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jose.jws; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.KeyLengthException; +import com.nimbusds.jose.crypto.MACSigner; +import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.util.Base64; +import com.nimbusds.jose.util.Base64URL; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import net.minidev.json.JSONObject; +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.ManagedKey; +import org.springframework.security.oauth2.jose.JoseHeader; +import org.springframework.security.oauth2.jose.JoseHeaderNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.jwt.JwtEncodingException; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import javax.crypto.SecretKey; +import java.net.URI; +import java.net.URL; +import java.security.PrivateKey; +import java.time.Instant; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) + * using the JSON Web Signature (JWS) Compact Serialization format. + * The private/secret key used for signing the JWS is obtained + * from the {@link KeyManager} supplied via the constructor. + * + *

+ * NOTE: This implementation uses the Nimbus JOSE + JWT SDK. + * + * @author Joe Grandja + * @since 0.0.1 + * @see JwtEncoder + * @see KeyManager + * @see JSON Web Token (JWT) + * @see JSON Web Signature (JWS) + * @see JWS Compact Serialization + * @see Nimbus JOSE + JWT SDK + */ +public final class NimbusJwsEncoder implements JwtEncoder { + private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = + "An error occurred while attempting to encode the Jwt: %s"; + private static final String RSA_KEY_TYPE = "RSA"; + private static final String EC_KEY_TYPE = "EC"; + private static final Map jcaKeyAlgorithmMappings = new HashMap() { + { + put(MacAlgorithm.HS256, "HmacSHA256"); + put(MacAlgorithm.HS384, "HmacSHA384"); + put(MacAlgorithm.HS512, "HmacSHA512"); + put(SignatureAlgorithm.RS256, RSA_KEY_TYPE); + put(SignatureAlgorithm.RS384, RSA_KEY_TYPE); + put(SignatureAlgorithm.RS512, RSA_KEY_TYPE); + put(SignatureAlgorithm.ES256, EC_KEY_TYPE); + put(SignatureAlgorithm.ES384, EC_KEY_TYPE); + put(SignatureAlgorithm.ES512, EC_KEY_TYPE); + } + }; + private static final Converter jwsHeaderConverter = new JwsHeaderConverter(); + private static final Converter jwtClaimsSetConverter = new JwtClaimsSetConverter(); + private final KeyManager keyManager; + + /** + * Constructs a {@code NimbusJwsEncoder} using the provided parameters. + * + * @param keyManager the key manager + */ + public NimbusJwsEncoder(KeyManager keyManager) { + Assert.notNull(keyManager, "keyManager cannot be null"); + this.keyManager = keyManager; + } + + @Override + public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { + Assert.notNull(headers, "headers cannot be null"); + Assert.notNull(claims, "claims cannot be null"); + + ManagedKey managedKey = selectKey(headers); + if (managedKey == null) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, + "Unsupported key for algorithm '" + headers.getJwsAlgorithm().getName() + "'")); + } + + JWSSigner jwsSigner; + if (managedKey.isAsymmetric()) { + if (!managedKey.getAlgorithm().equals(RSA_KEY_TYPE)) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, + "Unsupported key type '" + managedKey.getAlgorithm() + "'")); + } + PrivateKey privateKey = managedKey.getKey(); + jwsSigner = new RSASSASigner(privateKey); + } else { + SecretKey secretKey = managedKey.getKey(); + try { + jwsSigner = new MACSigner(secretKey); + } catch (KeyLengthException ex) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + } + } + + headers = JoseHeader.from(headers) + .type(JOSEObjectType.JWT.getType()) + .keyId(managedKey.getKeyId()) + .build(); + JWSHeader jwsHeader = jwsHeaderConverter.convert(headers); + + claims = JwtClaimsSet.from(claims) + .id(UUID.randomUUID().toString()) + .build(); + JWTClaimsSet jwtClaimsSet = jwtClaimsSetConverter.convert(claims); + + SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet); + try { + signedJWT.sign(jwsSigner); + } catch (JOSEException ex) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + } + String jws = signedJWT.serialize(); + + return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), + headers.getHeaders(), claims.getClaims()); + } + + private ManagedKey selectKey(JoseHeader headers) { + JwsAlgorithm jwsAlgorithm = headers.getJwsAlgorithm(); + String keyAlgorithm = jcaKeyAlgorithmMappings.get(jwsAlgorithm); + if (!StringUtils.hasText(keyAlgorithm)) { + return null; + } + + Set matchingKeys = this.keyManager.findByAlgorithm(keyAlgorithm); + if (CollectionUtils.isEmpty(matchingKeys)) { + return null; + } + + return matchingKeys.stream() + .filter(ManagedKey::isActive) + .max(this::mostRecentActivated) + .orElse(null); + } + + private int mostRecentActivated(ManagedKey managedKey1, ManagedKey managedKey2) { + return managedKey1.getActivatedOn().isAfter(managedKey2.getActivatedOn()) ? 1 : -1; + } + + private static class JwsHeaderConverter implements Converter { + + @Override + public JWSHeader convert(JoseHeader headers) { + JWSHeader.Builder builder = new JWSHeader.Builder( + JWSAlgorithm.parse(headers.getJwsAlgorithm().getName())); + + Set critical = headers.getCritical(); + if (!CollectionUtils.isEmpty(critical)) { + builder.criticalParams(critical); + } + + String contentType = headers.getContentType(); + if (StringUtils.hasText(contentType)) { + builder.contentType(contentType); + } + + String jwkSetUri = headers.getJwkSetUri(); + if (StringUtils.hasText(jwkSetUri)) { + try { + builder.jwkURL(new URI(jwkSetUri)); + } catch (Exception ex) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to convert '" + JoseHeaderNames.JKU + "' JOSE header"), ex); + } + } + + Map jwk = headers.getJwk(); + if (!CollectionUtils.isEmpty(jwk)) { + try { + builder.jwk(JWK.parse(new JSONObject(jwk))); + } catch (Exception ex) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to convert '" + JoseHeaderNames.JWK + "' JOSE header"), ex); + } + } + + String keyId = headers.getKeyId(); + if (StringUtils.hasText(keyId)) { + builder.keyID(keyId); + } + + String type = headers.getType(); + if (StringUtils.hasText(type)) { + builder.type(new JOSEObjectType(type)); + } + + List x509CertificateChain = headers.getX509CertificateChain(); + if (!CollectionUtils.isEmpty(x509CertificateChain)) { + builder.x509CertChain( + x509CertificateChain.stream() + .map(Base64::new) + .collect(Collectors.toList())); + } + + String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint(); + if (StringUtils.hasText(x509SHA1Thumbprint)) { + builder.x509CertThumbprint(new Base64URL(x509SHA1Thumbprint)); + } + + String x509SHA256Thumbprint = headers.getX509SHA256Thumbprint(); + if (StringUtils.hasText(x509SHA256Thumbprint)) { + builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint)); + } + + String x509Uri = headers.getX509Uri(); + if (StringUtils.hasText(x509Uri)) { + try { + builder.x509CertURL(new URI(x509Uri)); + } catch (Exception ex) { + throw new JwtEncodingException(String.format( + ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to convert '" + JoseHeaderNames.X5U + "' JOSE header"), ex); + } + } + + Map customHeaders = headers.getHeaders().entrySet().stream() + .filter(header -> !JWSHeader.getRegisteredParameterNames().contains(header.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + if (!CollectionUtils.isEmpty(customHeaders)) { + builder.customParams(customHeaders); + } + + return builder.build(); + } + } + + private static class JwtClaimsSetConverter implements Converter { + + @Override + public JWTClaimsSet convert(JwtClaimsSet claims) { + JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder(); + + URL issuer = claims.getIssuer(); + if (issuer != null) { + builder.issuer(issuer.toExternalForm()); + } + + String subject = claims.getSubject(); + if (StringUtils.hasText(subject)) { + builder.subject(subject); + } + + List audience = claims.getAudience(); + if (!CollectionUtils.isEmpty(audience)) { + builder.audience(audience); + } + + Instant issuedAt = claims.getIssuedAt(); + if (issuedAt != null) { + builder.issueTime(Date.from(issuedAt)); + } + + Instant expiresAt = claims.getExpiresAt(); + if (expiresAt != null) { + builder.expirationTime(Date.from(expiresAt)); + } + + Instant notBefore = claims.getNotBefore(); + if (notBefore != null) { + builder.notBeforeTime(Date.from(notBefore)); + } + + String jwtId = claims.getId(); + if (StringUtils.hasText(jwtId)) { + builder.jwtID(jwtId); + } + + Map customClaims = claims.getClaims().entrySet().stream() + .filter(claim -> !JWTClaimsSet.getRegisteredNames().contains(claim.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + if (!CollectionUtils.isEmpty(customClaims)) { + customClaims.forEach(builder::claim); + } + + return builder.build(); + } + } +} diff --git a/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java b/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java new file mode 100644 index 0000000..d19f2de --- /dev/null +++ b/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java @@ -0,0 +1,198 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import org.springframework.util.Assert; + +import java.net.URL; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import static org.springframework.security.oauth2.jwt.JwtClaimNames.AUD; +import static org.springframework.security.oauth2.jwt.JwtClaimNames.EXP; +import static org.springframework.security.oauth2.jwt.JwtClaimNames.IAT; +import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; +import static org.springframework.security.oauth2.jwt.JwtClaimNames.JTI; +import static org.springframework.security.oauth2.jwt.JwtClaimNames.NBF; +import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; + +/** + * The {@link Jwt JWT} Claims Set is a JSON object representing the claims conveyed by a JSON Web Token. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 0.0.1 + * @see Jwt + * @see JwtClaimAccessor + * @see JWT Claims Set + */ +public final class JwtClaimsSet implements JwtClaimAccessor { + private final Map claims; + + private JwtClaimsSet(Map claims) { + this.claims = Collections.unmodifiableMap(new LinkedHashMap<>(claims)); + } + + @Override + public Map getClaims() { + return this.claims; + } + + /** + * Returns a new {@link Builder}. + * + * @return the {@link Builder} + */ + public static Builder withClaims() { + return new Builder(); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@code claims}. + * + * @param claims a JWT claims set + * @return the {@link Builder} + */ + public static Builder from(JwtClaimsSet claims) { + return new Builder(claims); + } + + /** + * A builder for {@link JwtClaimsSet}. + */ + public static class Builder { + private final Map claims = new LinkedHashMap<>(); + + private Builder() { + } + + private Builder(JwtClaimsSet claims) { + Assert.notNull(claims, "claims cannot be null"); + this.claims.putAll(claims.getClaims()); + } + + /** + * Sets the issuer {@code (iss)} claim, which identifies the principal that issued the JWT. + * + * @param issuer the issuer identifier + * @return the {@link Builder} + */ + public Builder issuer(URL issuer) { + return claim(ISS, issuer); + } + + /** + * Sets the subject {@code (sub)} claim, which identifies the principal that is the subject of the JWT. + * + * @param subject the subject identifier + * @return the {@link Builder} + */ + public Builder subject(String subject) { + return claim(SUB, subject); + } + + /** + * Sets the audience {@code (aud)} claim, which identifies the recipient(s) that the JWT is intended for. + * + * @param audience the audience that this JWT is intended for + * @return the {@link Builder} + */ + public Builder audience(List audience) { + return claim(AUD, audience); + } + + /** + * Sets the expiration time {@code (exp)} claim, which identifies the time + * on or after which the JWT MUST NOT be accepted for processing. + * + * @param expiresAt the time on or after which the JWT MUST NOT be accepted for processing + * @return the {@link Builder} + */ + public Builder expiresAt(Instant expiresAt) { + return claim(EXP, expiresAt); + } + + /** + * Sets the not before {@code (nbf)} claim, which identifies the time + * before which the JWT MUST NOT be accepted for processing. + * + * @param notBefore the time before which the JWT MUST NOT be accepted for processing + * @return the {@link Builder} + */ + public Builder notBefore(Instant notBefore) { + return claim(NBF, notBefore); + } + + /** + * Sets the issued at {@code (iat)} claim, which identifies the time at which the JWT was issued. + * + * @param issuedAt the time at which the JWT was issued + * @return the {@link Builder} + */ + public Builder issuedAt(Instant issuedAt) { + return claim(IAT, issuedAt); + } + + /** + * Sets the JWT ID {@code (jti)} claim, which provides a unique identifier for the JWT. + * + * @param jti the unique identifier for the JWT + * @return the {@link Builder} + */ + public Builder id(String jti) { + return claim(JTI, jti); + } + + /** + * Sets the claim. + * + * @param name the claim name + * @param value the claim value + * @return the {@link Builder} + */ + public Builder claim(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.claims.put(name, value); + return this; + } + + /** + * A {@code Consumer} to be provided access to the claims set + * allowing the ability to add, replace, or remove. + * + * @param claimsConsumer a {@code Consumer} of the claims set + */ + public Builder claims(Consumer> claimsConsumer) { + claimsConsumer.accept(this.claims); + return this; + } + + /** + * Builds a new {@link JwtClaimsSet}. + * + * @return a {@link JwtClaimsSet} + */ + public JwtClaimsSet build() { + Assert.notEmpty(this.claims, "claims cannot be empty"); + return new JwtClaimsSet(this.claims); + } + } +} diff --git a/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java b/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java new file mode 100644 index 0000000..846a0df --- /dev/null +++ b/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java @@ -0,0 +1,56 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import org.springframework.security.oauth2.jose.JoseHeader; + +/** + * Implementations of this interface are responsible for encoding + * a JSON Web Token (JWT) to it's compact claims representation format. + * + *

+ * JWTs may be represented using the JWS Compact Serialization format for a + * JSON Web Signature (JWS) structure or JWE Compact Serialization format for a + * JSON Web Encryption (JWE) structure. Therefore, implementors are responsible + * for signing a JWS and/or encrypting a JWE. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 0.0.1 + * @see Jwt + * @see JoseHeader + * @see JwtClaimsSet + * @see JwtDecoder + * @see JSON Web Token (JWT) + * @see JSON Web Signature (JWS) + * @see JSON Web Encryption (JWE) + * @see JWS Compact Serialization + * @see JWE Compact Serialization + */ +@FunctionalInterface +public interface JwtEncoder { + + /** + * Encode the JWT to it's compact claims representation format. + * + * @param headers the JOSE header + * @param claims the JWT Claims Set + * @return a {@link Jwt} + * @throws JwtEncodingException if an error occurs while attempting to encode the JWT + */ + Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException; + +} diff --git a/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java b/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java new file mode 100644 index 0000000..4766d9a --- /dev/null +++ b/jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java @@ -0,0 +1,46 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +/** + * This exception is thrown when an error occurs + * while attempting to encode a JSON Web Token (JWT). + * + * @author Joe Grandja + * @since 0.0.1 + */ +public class JwtEncodingException extends JwtException { + + /** + * Constructs a {@code JwtEncodingException} using the provided parameters. + * + * @param message the detail message + */ + public JwtEncodingException(String message) { + super(message); + } + + /** + * Constructs a {@code JwtEncodingException} using the provided parameters. + * + * @param message the detail message + * @param cause the root cause + */ + public JwtEncodingException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/jose/src/test/java/org/springframework/security/oauth2/jose/JoseHeaderTests.java b/jose/src/test/java/org/springframework/security/oauth2/jose/JoseHeaderTests.java new file mode 100644 index 0000000..dbc16e8 --- /dev/null +++ b/jose/src/test/java/org/springframework/security/oauth2/jose/JoseHeaderTests.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jose; + +import org.junit.Test; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link JoseHeader}. + * + * @author Joe Grandja + */ +public class JoseHeaderTests { + + @Test + public void withAlgorithmWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JoseHeader.withAlgorithm(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwsAlgorithm cannot be null"); + } + + @Test + public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { + JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); + + JoseHeader joseHeader = JoseHeader.withAlgorithm(expectedJoseHeader.getJwsAlgorithm()) + .jwkSetUri(expectedJoseHeader.getJwkSetUri()) + .jwk(expectedJoseHeader.getJwk()) + .keyId(expectedJoseHeader.getKeyId()) + .x509Uri(expectedJoseHeader.getX509Uri()) + .x509CertificateChain(expectedJoseHeader.getX509CertificateChain()) + .x509SHA1Thumbprint(expectedJoseHeader.getX509SHA1Thumbprint()) + .x509SHA256Thumbprint(expectedJoseHeader.getX509SHA256Thumbprint()) + .critical(expectedJoseHeader.getCritical()) + .type(expectedJoseHeader.getType()) + .contentType(expectedJoseHeader.getContentType()) + .headers(headers -> headers.put("custom-header-name", "custom-header-value")) + .build(); + + assertThat(joseHeader.getJwsAlgorithm()).isEqualTo(expectedJoseHeader.getJwsAlgorithm()); + assertThat(joseHeader.getJwkSetUri()).isEqualTo(expectedJoseHeader.getJwkSetUri()); + assertThat(joseHeader.getJwk()).isEqualTo(expectedJoseHeader.getJwk()); + assertThat(joseHeader.getKeyId()).isEqualTo(expectedJoseHeader.getKeyId()); + assertThat(joseHeader.getX509Uri()).isEqualTo(expectedJoseHeader.getX509Uri()); + assertThat(joseHeader.getX509CertificateChain()).isEqualTo(expectedJoseHeader.getX509CertificateChain()); + assertThat(joseHeader.getX509SHA1Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA1Thumbprint()); + assertThat(joseHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA256Thumbprint()); + assertThat(joseHeader.getCritical()).isEqualTo(expectedJoseHeader.getCritical()); + assertThat(joseHeader.getType()).isEqualTo(expectedJoseHeader.getType()); + assertThat(joseHeader.getContentType()).isEqualTo(expectedJoseHeader.getContentType()); + assertThat(joseHeader.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); + } + + @Test + public void fromWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JoseHeader.from(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("headers cannot be null"); + } + + @Test + public void fromWhenHeadersProvidedThenCopied() { + JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); + JoseHeader joseHeader = JoseHeader.from(expectedJoseHeader).build(); + assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); + } + + @Test + public void headerWhenNameNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header(null, "value")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be empty"); + } + + @Test + public void headerWhenValueNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header("name", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + + @Test + public void getHeaderWhenNullThenThrowIllegalArgumentException() { + JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); + + assertThatThrownBy(() -> joseHeader.getHeader(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be empty"); + } +} diff --git a/jose/src/test/java/org/springframework/security/oauth2/jose/TestJoseHeaders.java b/jose/src/test/java/org/springframework/security/oauth2/jose/TestJoseHeaders.java new file mode 100644 index 0000000..70f3c15 --- /dev/null +++ b/jose/src/test/java/org/springframework/security/oauth2/jose/TestJoseHeaders.java @@ -0,0 +1,57 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jose; + +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +/** + * @author Joe Grandja + */ +public class TestJoseHeaders { + + public static JoseHeader.Builder joseHeader() { + return joseHeader(SignatureAlgorithm.RS256); + } + + public static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) { + return JoseHeader.withAlgorithm(signatureAlgorithm) + .jwkSetUri("https://provider.com/oauth2/jwks") + .jwk(rsaJwk()) + .keyId(UUID.randomUUID().toString()) + .x509Uri("https://provider.com/oauth2/x509") + .x509CertificateChain(Arrays.asList("x509Cert1", "x509Cert2")) + .x509SHA1Thumbprint("x509SHA1Thumbprint") + .x509SHA256Thumbprint("x509SHA256Thumbprint") + .critical(Collections.singleton("custom-header-name")) + .type("JWT") + .contentType("jwt-content-type") + .header("custom-header-name", "custom-header-value"); + } + + private static Map rsaJwk() { + Map rsaJwk = new HashMap<>(); + rsaJwk.put("kty", "RSA"); + rsaJwk.put("n", "modulus"); + rsaJwk.put("e", "exponent"); + return rsaJwk; + } +} diff --git a/jose/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java b/jose/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java new file mode 100644 index 0000000..2fc2235 --- /dev/null +++ b/jose/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java @@ -0,0 +1,159 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jose.jws; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.ManagedKey; +import org.springframework.security.crypto.keys.TestManagedKeys; +import org.springframework.security.oauth2.jose.JoseHeader; +import org.springframework.security.oauth2.jose.JoseHeaderNames; +import org.springframework.security.oauth2.jose.TestJoseHeaders; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncodingException; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; + +import java.security.interfaces.RSAPublicKey; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link NimbusJwsEncoder}. + * + * @author Joe Grandja + */ +public class NimbusJwsEncoderTests { + private KeyManager keyManager; + private NimbusJwsEncoder jwtEncoder; + + @Before + public void setUp() { + this.keyManager = mock(KeyManager.class); + this.jwtEncoder = new NimbusJwsEncoder(this.keyManager); + } + + @Test + public void constructorWhenKeyManagerNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new NimbusJwsEncoder(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("keyManager cannot be null"); + } + + @Test + public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatThrownBy(() -> this.jwtEncoder.encode(null, jwtClaimsSet)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("headers cannot be null"); + } + + @Test + public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { + JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); + + assertThatThrownBy(() -> this.jwtEncoder.encode(joseHeader, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("claims cannot be null"); + } + + @Test + public void encodeWhenUnsupportedKeyThenThrowJwtEncodingException() { + JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatThrownBy(() -> this.jwtEncoder.encode(joseHeader, jwtClaimsSet)) + .isInstanceOf(JwtEncodingException.class) + .hasMessageContaining("Unsupported key for algorithm 'RS256'"); + } + + @Test + public void encodeWhenUnsupportedKeyAlgorithmThenThrowJwtEncodingException() { + JoseHeader joseHeader = TestJoseHeaders.joseHeader(SignatureAlgorithm.ES256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatThrownBy(() -> this.jwtEncoder.encode(joseHeader, jwtClaimsSet)) + .isInstanceOf(JwtEncodingException.class) + .hasMessageContaining("Unsupported key for algorithm 'ES256'"); + } + + @Test + public void encodeWhenUnsupportedKeyTypeThenThrowJwtEncodingException() { + ManagedKey managedKey = TestManagedKeys.ecManagedKey().build(); + when(this.keyManager.findByAlgorithm(any())).thenReturn(Collections.singleton(managedKey)); + + JoseHeader joseHeader = TestJoseHeaders.joseHeader(SignatureAlgorithm.ES256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + assertThatThrownBy(() -> this.jwtEncoder.encode(joseHeader, jwtClaimsSet)) + .isInstanceOf(JwtEncodingException.class) + .hasMessageContaining("Unsupported key type 'EC'"); + } + + @Test + public void encodeWhenSuccessThenDecodes() { + ManagedKey managedKey = TestManagedKeys.rsaManagedKey().build(); + when(this.keyManager.findByAlgorithm(any())).thenReturn(Collections.singleton(managedKey)); + + JoseHeader joseHeader = TestJoseHeaders.joseHeader() + .headers(headers -> headers.remove(JoseHeaderNames.CRIT)) + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt jws = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); + + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey((RSAPublicKey) managedKey.getPublicKey()).build(); + jwtDecoder.decode(jws.getTokenValue()); + } + + @Test + public void encodeWhenMultipleActiveKeysThenUseMostRecent() { + ManagedKey managedKeyActivated2DaysAgo = TestManagedKeys.rsaManagedKey() + .activatedOn(Instant.now().minus(2, ChronoUnit.DAYS)) + .build(); + ManagedKey managedKeyActivated1DayAgo = TestManagedKeys.rsaManagedKey() + .activatedOn(Instant.now().minus(1, ChronoUnit.DAYS)) + .build(); + ManagedKey managedKeyActivatedToday = TestManagedKeys.rsaManagedKey() + .activatedOn(Instant.now()) + .build(); + + when(this.keyManager.findByAlgorithm(any())).thenReturn( + Stream.of(managedKeyActivated2DaysAgo, managedKeyActivated1DayAgo, managedKeyActivatedToday) + .collect(Collectors.toSet())); + + JoseHeader joseHeader = TestJoseHeaders.joseHeader() + .headers(headers -> headers.remove(JoseHeaderNames.CRIT)) + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt jws = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); + + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey((RSAPublicKey) managedKeyActivatedToday.getPublicKey()).build(); + jwtDecoder.decode(jws.getTokenValue()); + } +} diff --git a/jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java b/jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java new file mode 100644 index 0000000..12340c0 --- /dev/null +++ b/jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link JwtClaimsSet}. + * + * @author Joe Grandja + */ +public class JwtClaimsSetTests { + + @Test + public void buildWhenClaimsEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JwtClaimsSet.withClaims().build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("claims cannot be empty"); + } + + @Test + public void buildWhenAllClaimsProvidedThenAllClaimsAreSet() { + JwtClaimsSet expectedJwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() + .issuer(expectedJwtClaimsSet.getIssuer()) + .subject(expectedJwtClaimsSet.getSubject()) + .audience(expectedJwtClaimsSet.getAudience()) + .issuedAt(expectedJwtClaimsSet.getIssuedAt()) + .notBefore(expectedJwtClaimsSet.getNotBefore()) + .expiresAt(expectedJwtClaimsSet.getExpiresAt()) + .id(expectedJwtClaimsSet.getId()) + .claims(claims -> claims.put("custom-claim-name", "custom-claim-value")) + .build(); + + assertThat(jwtClaimsSet.getIssuer()).isEqualTo(expectedJwtClaimsSet.getIssuer()); + assertThat(jwtClaimsSet.getSubject()).isEqualTo(expectedJwtClaimsSet.getSubject()); + assertThat(jwtClaimsSet.getAudience()).isEqualTo(expectedJwtClaimsSet.getAudience()); + assertThat(jwtClaimsSet.getIssuedAt()).isEqualTo(expectedJwtClaimsSet.getIssuedAt()); + assertThat(jwtClaimsSet.getNotBefore()).isEqualTo(expectedJwtClaimsSet.getNotBefore()); + assertThat(jwtClaimsSet.getExpiresAt()).isEqualTo(expectedJwtClaimsSet.getExpiresAt()); + assertThat(jwtClaimsSet.getId()).isEqualTo(expectedJwtClaimsSet.getId()); + assertThat(jwtClaimsSet.getClaim("custom-claim-name")).isEqualTo("custom-claim-value"); + assertThat(jwtClaimsSet.getClaims()).isEqualTo(expectedJwtClaimsSet.getClaims()); + } + + @Test + public void fromWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JwtClaimsSet.from(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("claims cannot be null"); + } + + @Test + public void fromWhenClaimsProvidedThenCopied() { + JwtClaimsSet expectedJwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.from(expectedJwtClaimsSet).build(); + assertThat(jwtClaimsSet.getClaims()).isEqualTo(expectedJwtClaimsSet.getClaims()); + } + + @Test + public void claimWhenNameNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JwtClaimsSet.withClaims().claim(null, "value")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be empty"); + } + + @Test + public void claimWhenValueNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> JwtClaimsSet.withClaims().claim("name", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } +} diff --git a/jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java b/jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java new file mode 100644 index 0000000..42cb13e --- /dev/null +++ b/jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java @@ -0,0 +1,50 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.UUID; + +/** + * @author Joe Grandja + */ +public class TestJwtClaimsSets { + + public static JwtClaimsSet.Builder jwtClaimsSet() { + URL issuer = null; + try { + issuer = URI.create("https://provider.com").toURL(); + } catch (MalformedURLException e) { } + + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); + + return JwtClaimsSet.withClaims() + .issuer(issuer) + .subject("subject") + .audience(Collections.singletonList("client-1")) + .issuedAt(issuedAt) + .notBefore(issuedAt) + .expiresAt(expiresAt) + .id(UUID.randomUUID().toString()) + .claim("custom-claim-name", "custom-claim-value"); + } +} diff --git a/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle b/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle index 56947d7..b952955 100644 --- a/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle +++ b/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle @@ -2,6 +2,7 @@ apply plugin: 'io.spring.convention.spring-module' dependencies { compile project(':spring-security-core2') + compile project(':spring-security-oauth2-jose2') compile 'org.springframework.security:spring-security-core' compile 'org.springframework.security:spring-security-web' compile 'org.springframework.security:spring-security-oauth2-core' diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java index 09b45b9..d691a6e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -39,4 +40,9 @@ public interface OAuth2AuthorizationAttributeNames { */ String AUTHORIZATION_REQUEST = OAuth2Authorization.class.getName().concat(".AUTHORIZATION_REQUEST"); + /** + * The name of the attribute used for the attributes/claims of the {@link OAuth2AccessToken}. + */ + String ACCESS_TOKEN_ATTRIBUTES = OAuth2Authorization.class.getName().concat(".ACCESS_TOKEN_ATTRIBUTES"); + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index eb87295..6646b17 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -18,13 +18,17 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; -import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.JoseHeader; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -33,9 +37,12 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.Base64; +import java.util.Collections; /** * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant. @@ -46,26 +53,30 @@ import java.util.Base64; * @see OAuth2AccessTokenAuthenticationToken * @see RegisteredClientRepository * @see OAuth2AuthorizationService + * @see JwtEncoder * @see Section 4.1 Authorization Code Grant * @see Section 4.1.3 Access Token Request */ public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { private final RegisteredClientRepository registeredClientRepository; private final OAuth2AuthorizationService authorizationService; - private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + private final JwtEncoder jwtEncoder; /** * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. * * @param registeredClientRepository the repository of registered clients * @param authorizationService the authorization service + * @param jwtEncoder the jwt encoder */ public OAuth2AuthorizationCodeAuthenticationProvider(RegisteredClientRepository registeredClientRepository, - OAuth2AuthorizationService authorizationService) { + OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) { Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; + this.jwtEncoder = jwtEncoder; } @Override @@ -105,13 +116,34 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); } - String tokenValue = this.accessTokenGenerator.generateKey(); + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + + // TODO Allow configuration for issuer claim + URL issuer = null; + try { + issuer = URI.create("https://oauth2.provider.com").toURL(); + } catch (MalformedURLException e) { } + Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token lifespan + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live + + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() + .issuer(issuer) + .subject(authorization.getPrincipalName()) + .audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId())) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .notBefore(issuedAt) + .claim(OAuth2ParameterNames.SCOPE, authorizationRequest.getScopes()) + .build(); + + Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - tokenValue, issuedAt, expiresAt, authorizationRequest.getScopes()); + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); authorization = OAuth2Authorization.from(authorization) + .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .accessToken(accessToken) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index 29d0fc2..622475b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -18,21 +18,29 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; -import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.JoseHeader; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.Base64; +import java.util.Collections; import java.util.LinkedHashSet; import java.util.Set; import java.util.stream.Collectors; @@ -45,21 +53,26 @@ import java.util.stream.Collectors; * @see OAuth2ClientCredentialsAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationService + * @see JwtEncoder * @see Section 4.4 Client Credentials Grant * @see Section 4.4.2 Access Token Request */ public class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider { private final OAuth2AuthorizationService authorizationService; - private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + private final JwtEncoder jwtEncoder; /** * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters. * * @param authorizationService the authorization service + * @param jwtEncoder the jwt encoder */ - public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService) { + public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService, + JwtEncoder jwtEncoder) { Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); this.authorizationService = authorizationService; + this.jwtEncoder = jwtEncoder; } @Override @@ -87,13 +100,34 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes()); } - String tokenValue = this.accessTokenGenerator.generateKey(); + JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + + // TODO Allow configuration for issuer claim + URL issuer = null; + try { + issuer = URI.create("https://oauth2.provider.com").toURL(); + } catch (MalformedURLException e) { } + Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token lifespan + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live + + JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims() + .issuer(issuer) + .subject(clientPrincipal.getName()) + .audience(Collections.singletonList(registeredClient.getClientId())) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .notBefore(issuedAt) + .claim(OAuth2ParameterNames.SCOPE, scopes) + .build(); + + Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - tokenValue, issuedAt, expiresAt, scopes); + jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) + .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .principalName(clientPrincipal.getName()) .accessToken(accessToken) .build(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 78ca0b3..52c5592 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -22,6 +22,10 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.jose.JoseHeaderNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -32,8 +36,12 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -48,6 +56,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { private RegisteredClient registeredClient; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; + private JwtEncoder jwtEncoder; private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; @Before @@ -55,24 +64,32 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { this.registeredClient = TestRegisteredClients.registeredClient().build(); this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient); this.authorizationService = mock(OAuth2AuthorizationService.class); + this.jwtEncoder = mock(JwtEncoder.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( - this.registeredClientRepository, this.authorizationService); + this.registeredClientRepository, this.authorizationService, this.jwtEncoder); } @Test public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null, this.authorizationService)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null, this.authorizationService, this.jwtEncoder)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("registeredClientRepository cannot be null"); } @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, null)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, null, this.jwtEncoder)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizationService cannot be null"); } + @Test + public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, this.authorizationService, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtEncoder cannot be null"); + } + @Test public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue(); @@ -163,6 +180,15 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri()); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); + Jwt jwt = Jwt.withTokenValue("token") + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .build(); + when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index 34990f4..b1e6368 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -21,18 +21,26 @@ import org.mockito.ArgumentCaptor; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.jose.JoseHeaderNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Tests for {@link OAuth2ClientCredentialsAuthenticationProvider}. @@ -43,22 +51,32 @@ import static org.mockito.Mockito.verify; public class OAuth2ClientCredentialsAuthenticationProviderTests { private RegisteredClient registeredClient; private OAuth2AuthorizationService authorizationService; + private JwtEncoder jwtEncoder; private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider; @Before public void setUp() { this.registeredClient = TestRegisteredClients.registeredClient().build(); this.authorizationService = mock(OAuth2AuthorizationService.class); - this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService); + this.jwtEncoder = mock(JwtEncoder.class); + this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( + this.authorizationService, this.jwtEncoder); } @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(null)) + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(null, this.jwtEncoder)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizationService cannot be null"); } + @Test + public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtEncoder cannot be null"); + } + @Test public void supportsWhenSupportedAuthenticationThenTrue() { assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isTrue(); @@ -115,6 +133,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScope); @@ -125,6 +145,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal); + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -139,4 +161,14 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getAccessToken()); } + + private static Jwt createJwt() { + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); + return Jwt.withTokenValue("token") + .header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName()) + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .build(); + } }