From c30488cd3045bb4aaa42d8ca88b6161b5dea2f66 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 17 Aug 2020 06:34:42 -0400 Subject: [PATCH] Add JWK Set Endpoint Filter Closes gh-82 --- .../OAuth2AuthorizationServerConfigurer.java | 4 + ...ecurity-oauth2-authorization-server.gradle | 1 + .../web/JwkSetEndpointFilter.java | 132 +++++++++++++ .../web/JwkSetEndpointFilterTests.java | 186 ++++++++++++++++++ 4 files changed, 323 insertions(+) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/JwkSetEndpointFilter.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/JwkSetEndpointFilterTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index 3ceb350..2201448 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -32,6 +32,7 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.web.JwkSetEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; @@ -130,6 +131,9 @@ public final class OAuth2AuthorizationServerConfigurerJSON Web Key (JWK) + * @see Section 5 JWK Set Format + */ +public class JwkSetEndpointFilter extends OncePerRequestFilter { + /** + * The default endpoint {@code URI} for JWK Set requests. + */ + public static final String DEFAULT_JWK_SET_ENDPOINT_URI = "/oauth2/jwks"; + + private final KeyManager keyManager; + private final RequestMatcher requestMatcher; + + /** + * Constructs a {@code JwkSetEndpointFilter} using the provided parameters. + * + * @param keyManager the key manager + */ + public JwkSetEndpointFilter(KeyManager keyManager) { + this(keyManager, DEFAULT_JWK_SET_ENDPOINT_URI); + } + + /** + * Constructs a {@code JwkSetEndpointFilter} using the provided parameters. + * + * @param keyManager the key manager + * @param jwkSetEndpointUri the endpoint {@code URI} for JWK Set requests + */ + public JwkSetEndpointFilter(KeyManager keyManager, String jwkSetEndpointUri) { + Assert.notNull(keyManager, "keyManager cannot be null"); + Assert.hasText(jwkSetEndpointUri, "jwkSetEndpointUri cannot be empty"); + this.keyManager = keyManager; + this.requestMatcher = new AntPathRequestMatcher(jwkSetEndpointUri, HttpMethod.GET.name()); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (!this.requestMatcher.matches(request)) { + filterChain.doFilter(request, response); + return; + } + + JWKSet jwkSet = buildJwkSet(); + + response.setContentType(MediaType.APPLICATION_JSON_VALUE); + try (Writer writer = response.getWriter()) { + writer.write(jwkSet.toJSONObject().toString()); + } + } + + private JWKSet buildJwkSet() { + return new JWKSet( + this.keyManager.getKeys().stream() + .filter(managedKey -> managedKey.isActive() && managedKey.isAsymmetric()) + .map(this::convert) + .filter(Objects::nonNull) + .collect(Collectors.toList()) + ); + } + + private JWK convert(ManagedKey managedKey) { + JWK jwk = null; + if (managedKey.getPublicKey() instanceof RSAPublicKey) { + RSAPublicKey publicKey = (RSAPublicKey) managedKey.getPublicKey(); + jwk = new RSAKey.Builder(publicKey) + .keyUse(KeyUse.SIGNATURE) + .algorithm(JWSAlgorithm.RS256) + .keyID(managedKey.getKeyId()) + .build(); + } else if (managedKey.getPublicKey() instanceof ECPublicKey) { + ECPublicKey publicKey = (ECPublicKey) managedKey.getPublicKey(); + Curve curve = Curve.forECParameterSpec(publicKey.getParams()); + jwk = new ECKey.Builder(curve, publicKey) + .keyUse(KeyUse.SIGNATURE) + .algorithm(JWSAlgorithm.ES256) + .keyID(managedKey.getKeyId()) + .build(); + } + return jwk; + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/JwkSetEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/JwkSetEndpointFilterTests.java new file mode 100644 index 0000000..cf427ab --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/JwkSetEndpointFilterTests.java @@ -0,0 +1,186 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.RSAKey; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.crypto.keys.KeyManager; +import org.springframework.security.crypto.keys.ManagedKey; +import org.springframework.security.crypto.keys.TestManagedKeys; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.time.Instant; +import java.util.Collections; +import java.util.HashSet; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link JwkSetEndpointFilter}. + * + * @author Joe Grandja + */ +public class JwkSetEndpointFilterTests { + private KeyManager keyManager; + private JwkSetEndpointFilter filter; + + @Before + public void setUp() { + this.keyManager = mock(KeyManager.class); + this.filter = new JwkSetEndpointFilter(this.keyManager); + } + + @Test + public void constructorWhenKeyManagerNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new JwkSetEndpointFilter(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("keyManager cannot be null"); + } + + @Test + public void constructorWhenJwkSetEndpointUriNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new JwkSetEndpointFilter(this.keyManager, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwkSetEndpointUri cannot be empty"); + } + + @Test + public void doFilterWhenNotJwkSetRequestThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenJwkSetRequestPostThenNotProcessed() throws Exception { + String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAsymmetricKeysThenJwkSetResponse() throws Exception { + ManagedKey rsaManagedKey = TestManagedKeys.rsaManagedKey().build(); + ManagedKey ecManagedKey = TestManagedKeys.ecManagedKey().build(); + when(this.keyManager.getKeys()).thenReturn( + Stream.of(rsaManagedKey, ecManagedKey).collect(Collectors.toSet())); + + String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + + JWKSet jwkSet = JWKSet.parse(response.getContentAsString()); + assertThat(jwkSet.getKeys()).hasSize(2); + + RSAKey rsaJwk = (RSAKey) jwkSet.getKeyByKeyId(rsaManagedKey.getKeyId()); + assertThat(rsaJwk).isNotNull(); + assertThat(rsaJwk.toRSAPublicKey()).isEqualTo(rsaManagedKey.getPublicKey()); + assertThat(rsaJwk.toRSAPrivateKey()).isNull(); + assertThat(rsaJwk.getKeyUse()).isEqualTo(KeyUse.SIGNATURE); + assertThat(rsaJwk.getAlgorithm()).isEqualTo(JWSAlgorithm.RS256); + + ECKey ecJwk = (ECKey) jwkSet.getKeyByKeyId(ecManagedKey.getKeyId()); + assertThat(ecJwk).isNotNull(); + assertThat(ecJwk.toECPublicKey()).isEqualTo(ecManagedKey.getPublicKey()); + assertThat(ecJwk.toECPublicKey()).isEqualTo(ecManagedKey.getPublicKey()); + assertThat(ecJwk.toECPrivateKey()).isNull(); + assertThat(ecJwk.getKeyUse()).isEqualTo(KeyUse.SIGNATURE); + assertThat(ecJwk.getAlgorithm()).isEqualTo(JWSAlgorithm.ES256); + } + + @Test + public void doFilterWhenSymmetricKeysThenJwkSetResponseEmpty() throws Exception { + ManagedKey secretManagedKey = TestManagedKeys.secretManagedKey().build(); + when(this.keyManager.getKeys()).thenReturn( + new HashSet<>(Collections.singleton(secretManagedKey))); + + String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + + JWKSet jwkSet = JWKSet.parse(response.getContentAsString()); + assertThat(jwkSet.getKeys()).isEmpty(); + } + + @Test + public void doFilterWhenNoActiveKeysThenJwkSetResponseEmpty() throws Exception { + ManagedKey rsaManagedKey = TestManagedKeys.rsaManagedKey().deactivatedOn(Instant.now()).build(); + ManagedKey ecManagedKey = TestManagedKeys.ecManagedKey().deactivatedOn(Instant.now()).build(); + when(this.keyManager.getKeys()).thenReturn( + Stream.of(rsaManagedKey, ecManagedKey).collect(Collectors.toSet())); + + String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + + JWKSet jwkSet = JWKSet.parse(response.getContentAsString()); + assertThat(jwkSet.getKeys()).isEmpty(); + } +}