/* * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.oauth2.jwt; import java.security.interfaces.ECPrivateKey; import java.security.interfaces.ECPublicKey; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSelector; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.OctetSequenceKey; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; import org.junit.Before; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; /** * Tests for {@link NimbusJwsEncoder}. * * @author Joe Grandja */ public class NimbusJwsEncoderTests { private List jwkList; private JWKSource jwkSource; private NimbusJwsEncoder jwsEncoder; @Before public void setUp() { this.jwkList = new ArrayList<>(); this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); } @Test public void constructorWhenJwkSourceNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwsEncoder(null)) .withMessage("jwkSource cannot be null"); } @Test public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(null, jwtClaimsSet)) .withMessage("headers cannot be null"); } @Test public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(joseHeader, null)) .withMessage("claims cannot be null"); } @Test public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception { this.jwkSource = mock(JWKSource.class); this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willThrow(new KeySourceException("key source error")); JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) .withMessageContaining("Failed to select a JWK signing key -> key source error"); } @Test public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; this.jwkList.add(rsaJwk); this.jwkList.add(rsaJwk); JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); } @Test public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) .withMessageContaining("Failed to select a JWK signing key"); } @Test public void encodeWhenJwkKidNullThenThrowJwtEncodingException() throws Exception { // @formatter:off RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) .keyID(null) .build(); // @formatter:on this.jwkList.add(rsaJwk); JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) .withMessageContaining("The \"kid\" (key ID) from the selected JWK cannot be empty"); } @Test public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exception { // @formatter:off RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) .keyUse(KeyUse.ENCRYPTION) .build(); // @formatter:on this.jwkSource = mock(JWKSource.class); this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)).withMessageContaining( "Failed to create a JWS Signer -> The JWK use must be sig (signature) or unspecified"); } @Test public void encodeWhenSuccessThenDecodes() throws Exception { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; this.jwkList.add(rsaJwk); JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); // Assert headers/claims were added assertThat(encodedJws.getHeaders().get(JoseHeaderNames.TYP)).isEqualTo("JWT"); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); assertThat(encodedJws.getId()).isNotNull(); NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk.toRSAPublicKey()).build(); jwtDecoder.decode(encodedJws.getTokenValue()); } @Test public void encodeWhenKeysRotatedThenNewKeyUsed() throws Exception { TestJWKSource jwkSource = new TestJWKSource(); JWKSource jwkSourceDelegate = spy(new JWKSource() { @Override public List get(JWKSelector jwkSelector, SecurityContext context) { return jwkSource.get(jwkSelector, context); } }); NimbusJwsEncoder jwsEncoder = new NimbusJwsEncoder(jwkSourceDelegate); JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor(); willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any()); JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet); JWK jwk1 = jwkListResultCaptor.getResult().get(0); NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk1).toRSAPublicKey()).build(); jwtDecoder.decode(encodedJws.getTokenValue()); jwkSource.rotate(); // Trigger key rotation encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet); JWK jwk2 = jwkListResultCaptor.getResult().get(0); jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk2).toRSAPublicKey()).build(); jwtDecoder.decode(encodedJws.getTokenValue()); assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID()); } private static final class JwkListResultCaptor implements Answer> { private List result; private List getResult() { return this.result; } @SuppressWarnings("unchecked") @Override public List answer(InvocationOnMock invocationOnMock) throws Throwable { this.result = (List) invocationOnMock.callRealMethod(); return this.result; } } private static final class TestJWKSource implements JWKSource { private int keyId = 1000; private JWKSet jwkSet; private TestJWKSource() { init(); } @Override public List get(JWKSelector jwkSelector, SecurityContext context) { return jwkSelector.select(this.jwkSet); } private void init() { // @formatter:off RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) .keyID("rsa-jwk-" + this.keyId++) .build(); ECKey ecJwk = TestJwks.jwk((ECPublicKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPublic(), (ECPrivateKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPrivate()) .keyID("ec-jwk-" + this.keyId++) .build(); OctetSequenceKey secretJwk = TestJwks.jwk(TestKeys.DEFAULT_SECRET_KEY) .keyID("secret-jwk-" + this.keyId++) .build(); // @formatter:on this.jwkSet = new JWKSet(Arrays.asList(rsaJwk, ecJwk, secretJwk)); } private void rotate() { init(); } } }