Fix NimbusJwkSetEndpointFilter

Closes gh-198
This commit is contained in:
Joe Grandja 2021-01-18 19:54:58 -05:00
parent 12f4001c9d
commit b7996e26d0
2 changed files with 12 additions and 9 deletions

View File

@ -73,7 +73,7 @@ public class NimbusJwkSetEndpointFilter extends OncePerRequestFilter {
Assert.notNull(jwkSource, "jwkSource cannot be null"); Assert.notNull(jwkSource, "jwkSource cannot be null");
Assert.hasText(jwkSetEndpointUri, "jwkSetEndpointUri cannot be empty"); Assert.hasText(jwkSetEndpointUri, "jwkSetEndpointUri cannot be empty");
this.jwkSource = jwkSource; this.jwkSource = jwkSource;
this.jwkSelector = new JWKSelector(new JWKMatcher.Builder().publicOnly(true).build()); this.jwkSelector = new JWKSelector(new JWKMatcher.Builder().build());
this.requestMatcher = new AntPathRequestMatcher(jwkSetEndpointUri, HttpMethod.GET.name()); this.requestMatcher = new AntPathRequestMatcher(jwkSetEndpointUri, HttpMethod.GET.name());
} }
@ -91,12 +91,12 @@ public class NimbusJwkSetEndpointFilter extends OncePerRequestFilter {
jwkSet = new JWKSet(this.jwkSource.get(this.jwkSelector, null)); jwkSet = new JWKSet(this.jwkSource.get(this.jwkSelector, null));
} }
catch (Exception ex) { catch (Exception ex) {
throw new IllegalStateException("Failed to select the JWK public key(s) -> " + ex.getMessage(), ex); throw new IllegalStateException("Failed to select the JWK(s) -> " + ex.getMessage(), ex);
} }
response.setContentType(MediaType.APPLICATION_JSON_VALUE); response.setContentType(MediaType.APPLICATION_JSON_VALUE);
try (Writer writer = response.getWriter()) { try (Writer writer = response.getWriter()) {
writer.write(jwkSet.toString()); writer.write(jwkSet.toString()); // toString() excludes private keys
} }
} }
} }

View File

@ -15,14 +15,15 @@
*/ */
package org.springframework.security.oauth2.server.authorization.web; package org.springframework.security.oauth2.server.authorization.web;
import java.util.Arrays; import java.util.ArrayList;
import java.util.Collections; import java.util.List;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey; import com.nimbusds.jose.jwk.OctetSequenceKey;
@ -40,7 +41,6 @@ import org.springframework.security.oauth2.jose.TestJwks;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
@ -51,12 +51,14 @@ import static org.mockito.Mockito.verifyNoInteractions;
* @author Joe Grandja * @author Joe Grandja
*/ */
public class NimbusJwkSetEndpointFilterTests { public class NimbusJwkSetEndpointFilterTests {
private List<JWK> jwkList;
private JWKSource<SecurityContext> jwkSource; private JWKSource<SecurityContext> jwkSource;
private NimbusJwkSetEndpointFilter filter; private NimbusJwkSetEndpointFilter filter;
@Before @Before
public void setUp() { public void setUp() {
this.jwkSource = mock(JWKSource.class); this.jwkList = new ArrayList<>();
this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList));
this.filter = new NimbusJwkSetEndpointFilter(this.jwkSource); this.filter = new NimbusJwkSetEndpointFilter(this.jwkSource);
} }
@ -103,8 +105,9 @@ public class NimbusJwkSetEndpointFilterTests {
@Test @Test
public void doFilterWhenAsymmetricKeysThenJwkSetResponse() throws Exception { public void doFilterWhenAsymmetricKeysThenJwkSetResponse() throws Exception {
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
this.jwkList.add(rsaJwk);
ECKey ecJwk = TestJwks.DEFAULT_EC_JWK; ECKey ecJwk = TestJwks.DEFAULT_EC_JWK;
given(this.jwkSource.get(any(), any())).willReturn(Arrays.asList(rsaJwk, ecJwk)); this.jwkList.add(ecJwk);
String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@ -137,7 +140,7 @@ public class NimbusJwkSetEndpointFilterTests {
@Test @Test
public void doFilterWhenSymmetricKeysThenJwkSetResponseEmpty() throws Exception { public void doFilterWhenSymmetricKeysThenJwkSetResponseEmpty() throws Exception {
OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK; OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK;
given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(secretJwk)); this.jwkList.add(secretJwk);
String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);