diff --git a/build.gradle b/build.gradle index 5a903ed..398c182 100644 --- a/build.gradle +++ b/build.gradle @@ -16,6 +16,8 @@ group = 'org.springframework.security.experimental' description = 'Spring Authorization Server' version = '0.0.1-SNAPSHOT' +ext['junit-jupiter.version'] = '5.4.0' + repositories { mavenCentral() } diff --git a/samples/boot/minimal/README.md b/samples/boot/minimal/README.md new file mode 100644 index 0000000..2a7ed4c --- /dev/null +++ b/samples/boot/minimal/README.md @@ -0,0 +1,12 @@ +## Minimal Authorization Server Sample + +#### How to run + +``` +./gradlew spring-authorization-server-samples-boot-minimal:bootRun +``` + +``` +curl http://localhost:8080/.well-known/jwk_uris +``` + diff --git a/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle b/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle index 5300d84..a48ee07 100644 --- a/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle +++ b/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle @@ -1,10 +1,19 @@ apply plugin: 'io.spring.convention.spring-sample-boot' dependencies { - implementation 'org.springframework.boot:spring-boot-starter' + implementation 'org.springframework.boot:spring-boot-starter-web' + implementation 'org.springframework.boot:spring-boot-starter-security' + + implementation 'com.nimbusds:oauth2-oidc-sdk' + testImplementation('org.springframework.boot:spring-boot-starter-test') { exclude group: 'org.junit.vintage', module: 'junit-vintage-engine' } + + testImplementation 'org.springframework.security:spring-security-test' + + testRuntime("org.junit.platform:junit-platform-runner") + testRuntime("org.junit.jupiter:junit-jupiter-engine") } test { diff --git a/samples/boot/minimal/src/main/java/sample/JwkSetEndpointFilter.java b/samples/boot/minimal/src/main/java/sample/JwkSetEndpointFilter.java new file mode 100644 index 0000000..e8db90e --- /dev/null +++ b/samples/boot/minimal/src/main/java/sample/JwkSetEndpointFilter.java @@ -0,0 +1,74 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sample; + +import static org.springframework.http.HttpMethod.GET; +import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; + +import java.io.IOException; +import java.io.Writer; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.UrlPathHelper; + +import com.nimbusds.jose.jwk.JWKSet; + +public class JwkSetEndpointFilter extends OncePerRequestFilter { + + static final String WELL_KNOWN_JWK_URIS = "/.well-known/jwk_uris"; + + private final RequestMatcher requestMatcher = new AntPathRequestMatcher(WELL_KNOWN_JWK_URIS, GET.name(), true, + new UrlPathHelper()); + + private final JWKSet jwkSet; + + public JwkSetEndpointFilter(JWKSet jwkSet) { + Assert.notNull(jwkSet, "jwkSet cannot be null"); + this.jwkSet = jwkSet; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (ifRequestMatches(request)) { + respond(response); + } else { + filterChain.doFilter(request, response); + } + } + + private void respond(HttpServletResponse response) throws IOException { + response.setContentType(APPLICATION_JSON_VALUE); + try (Writer writer = response.getWriter()) { + writer.write(jwkSet.toPublicJWKSet().toJSONObject().toJSONString()); + } + } + + private boolean ifRequestMatches(HttpServletRequest request) { + return this.requestMatcher.matches(request); + } + +} diff --git a/samples/boot/minimal/src/main/java/sample/SecurityConfig.java b/samples/boot/minimal/src/main/java/sample/SecurityConfig.java new file mode 100644 index 0000000..7c7bc6f --- /dev/null +++ b/samples/boot/minimal/src/main/java/sample/SecurityConfig.java @@ -0,0 +1,43 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sample; + +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.web.access.channel.ChannelProcessingFilter; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; + +@EnableWebSecurity +public class SecurityConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + http.addFilterBefore(new JwkSetEndpointFilter(generateJwkSet()), ChannelProcessingFilter.class); + } + + protected JWKSet generateJwkSet() throws JOSEException { + JWK jwk = new RSAKeyGenerator(2048).keyID("minimal-ASA").keyUse(KeyUse.SIGNATURE).generate(); + return new JWKSet(jwk); + } + +} diff --git a/samples/boot/minimal/src/test/java/sample/JwkSetEndpointFilterTest.java b/samples/boot/minimal/src/test/java/sample/JwkSetEndpointFilterTest.java new file mode 100644 index 0000000..aa71d2b --- /dev/null +++ b/samples/boot/minimal/src/test/java/sample/JwkSetEndpointFilterTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sample; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.verify; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static sample.JwkSetEndpointFilter.WELL_KNOWN_JWK_URIS; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.mockito.Mockito; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; + +@TestInstance(Lifecycle.PER_CLASS) +public class JwkSetEndpointFilterTest { + + private MockMvc mvc; + private JWKSet jwkSet; + private JWK jwk; + private JwkSetEndpointFilter filter; + + @BeforeAll + void setup() throws JOSEException { + this.jwk = new RSAKeyGenerator(2048).keyID("endpoint-test").keyUse(KeyUse.SIGNATURE).generate(); + this.jwkSet = new JWKSet(jwk); + this.filter = new JwkSetEndpointFilter(jwkSet); + this.mvc = MockMvcBuilders.standaloneSetup(new FakeController()).addFilters(filter).alwaysDo(print()).build(); + } + + @Test + void constructorWhenJsonWebKeySetIsNullThrowIllegalArgumentException() { + assertThatThrownBy(() -> new JwkSetEndpointFilter(null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + void doFilterWhenPathMatches() throws Exception { + String requestUri = WELL_KNOWN_JWK_URIS; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain, never()).doFilter(Mockito.any(HttpServletRequest.class), + Mockito.any(HttpServletResponse.class)); + } + + @Test + void doFilterWhenPathDoesNotMatch() throws Exception { + String requestUri = "/stuff/" + WELL_KNOWN_JWK_URIS; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain, only()).doFilter(Mockito.any(HttpServletRequest.class), + Mockito.any(HttpServletResponse.class)); + } + + @Test + void testResponseIfRequestMatches() throws Exception { + mvc.perform(get(WELL_KNOWN_JWK_URIS)).andDo(print()).andExpect(status().isOk()) + .andExpect(jsonPath("$.keys").isArray()).andExpect(jsonPath("$.keys").isNotEmpty()) + .andExpect(jsonPath("$.keys[0].kid").value(jwk.getKeyID())) + .andExpect(jsonPath("$.keys[0].kty").value(jwk.getKeyType().toString())); + } + + @Test + void testResponseIfNotRequestMatches() throws Exception { + mvc.perform(get("/fake")).andDo(print()).andExpect(status().isOk()) + .andExpect(content().string(is("fake"))); + } + + @RestController + class FakeController { + + @RequestMapping("/fake") + public String hello() { + return "fake"; + } + } +} diff --git a/samples/boot/minimal/src/test/java/sample/MinimalAuthorizationServerApplicationTests.java b/samples/boot/minimal/src/test/java/sample/MinimalAuthorizationServerApplicationTests.java index d6a6938..9f0a859 100644 --- a/samples/boot/minimal/src/test/java/sample/MinimalAuthorizationServerApplicationTests.java +++ b/samples/boot/minimal/src/test/java/sample/MinimalAuthorizationServerApplicationTests.java @@ -15,18 +15,29 @@ */ package sample; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.HttpStatus.OK; + import org.junit.jupiter.api.Test; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.ApplicationContext; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.boot.web.server.LocalServerPort; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; -import static org.assertj.core.api.Assertions.assertThat; - -@SpringBootTest +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) public class MinimalAuthorizationServerApplicationTests { + private RestTemplate rest = new RestTemplate(); + + @LocalServerPort + private int serverPort; + @Test - public void loadContext(ApplicationContext context) { - assertThat(context).isNotNull(); + void verifyJwkSetEndpointFilterAccessibleWithoutAuthentication() { + ResponseEntity responseEntity = rest.getForEntity( + "http://localhost:" + serverPort + JwkSetEndpointFilter.WELL_KNOWN_JWK_URIS, String.class); + assertThat(responseEntity.getStatusCode()).isEqualTo(OK); } }