diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilter.java index 7f33597..09fd2db 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilter.java @@ -73,7 +73,7 @@ public class NimbusJwkSetEndpointFilter extends OncePerRequestFilter { Assert.notNull(jwkSource, "jwkSource cannot be null"); Assert.hasText(jwkSetEndpointUri, "jwkSetEndpointUri cannot be empty"); this.jwkSource = jwkSource; - this.jwkSelector = new JWKSelector(new JWKMatcher.Builder().publicOnly(true).build()); + this.jwkSelector = new JWKSelector(new JWKMatcher.Builder().build()); this.requestMatcher = new AntPathRequestMatcher(jwkSetEndpointUri, HttpMethod.GET.name()); } @@ -91,12 +91,12 @@ public class NimbusJwkSetEndpointFilter extends OncePerRequestFilter { jwkSet = new JWKSet(this.jwkSource.get(this.jwkSelector, null)); } catch (Exception ex) { - throw new IllegalStateException("Failed to select the JWK public key(s) -> " + ex.getMessage(), ex); + throw new IllegalStateException("Failed to select the JWK(s) -> " + ex.getMessage(), ex); } response.setContentType(MediaType.APPLICATION_JSON_VALUE); try (Writer writer = response.getWriter()) { - writer.write(jwkSet.toString()); + writer.write(jwkSet.toString()); // toString() excludes private keys } } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilterTests.java index ff408d4..fce8e74 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilterTests.java @@ -15,14 +15,15 @@ */ package org.springframework.security.oauth2.server.authorization.web; -import java.util.Arrays; -import java.util.Collections; +import java.util.ArrayList; +import java.util.List; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.OctetSequenceKey; @@ -40,7 +41,6 @@ import org.springframework.security.oauth2.jose.TestJwks; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -51,12 +51,14 @@ import static org.mockito.Mockito.verifyNoInteractions; * @author Joe Grandja */ public class NimbusJwkSetEndpointFilterTests { + private List jwkList; private JWKSource jwkSource; private NimbusJwkSetEndpointFilter filter; @Before public void setUp() { - this.jwkSource = mock(JWKSource.class); + this.jwkList = new ArrayList<>(); + this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); this.filter = new NimbusJwkSetEndpointFilter(this.jwkSource); } @@ -103,8 +105,9 @@ public class NimbusJwkSetEndpointFilterTests { @Test public void doFilterWhenAsymmetricKeysThenJwkSetResponse() throws Exception { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + this.jwkList.add(rsaJwk); ECKey ecJwk = TestJwks.DEFAULT_EC_JWK; - given(this.jwkSource.get(any(), any())).willReturn(Arrays.asList(rsaJwk, ecJwk)); + this.jwkList.add(ecJwk); String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); @@ -137,7 +140,7 @@ public class NimbusJwkSetEndpointFilterTests { @Test public void doFilterWhenSymmetricKeysThenJwkSetResponseEmpty() throws Exception { OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK; - given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(secretJwk)); + this.jwkList.add(secretJwk); String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);