diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java index a371840..ea0bb2b 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java @@ -44,11 +44,11 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken @Override public Object getCredentials() { - return null; + return this.clientSecret; } @Override public Object getPrincipal() { - return null; + return this.clientId; } } diff --git a/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle b/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle index 70199fb..e1cface 100644 --- a/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle +++ b/samples/boot/minimal/spring-authorization-server-samples-boot-minimal.gradle @@ -4,6 +4,7 @@ dependencies { implementation 'org.springframework.boot:spring-boot-starter-web' implementation 'org.springframework.boot:spring-boot-starter-security' implementation 'com.nimbusds:oauth2-oidc-sdk' + implementation project(':spring-authorization-server-core') testImplementation('org.springframework.boot:spring-boot-starter-test') { exclude group: 'org.junit.vintage', module: 'junit-vintage-engine' diff --git a/samples/boot/minimal/src/main/java/sample/ClientCredentialsAuthenticationFilter.java b/samples/boot/minimal/src/main/java/sample/ClientCredentialsAuthenticationFilter.java new file mode 100644 index 0000000..5b0b6e4 --- /dev/null +++ b/samples/boot/minimal/src/main/java/sample/ClientCredentialsAuthenticationFilter.java @@ -0,0 +1,103 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package sample; + +import org.springframework.http.HttpMethod; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.web.filter.OncePerRequestFilter; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Base64; + +import static java.nio.charset.StandardCharsets.UTF_8; + +/** + * A filter to perform client authentication for the Token Endpoint. + * + * See RFC-6749 2.3.1. + */ +public class ClientCredentialsAuthenticationFilter extends OncePerRequestFilter { + private final AuthenticationManager authenticationManager; + private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/oauth2/token", HttpMethod.POST.name()); + + public ClientCredentialsAuthenticationFilter(AuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) + throws ServletException, IOException { + + if (this.requestMatcher.matches(request)) { + String[] credentials = extractBasicAuthenticationCredentials(request); + String clientId = credentials[0]; + String clientSecret = credentials[1]; + + OAuth2ClientAuthenticationToken authenticationToken = new OAuth2ClientAuthenticationToken(clientId, clientSecret); + + Authentication authentication = this.authenticationManager.authenticate(authenticationToken); + + SecurityContextHolder.getContext().setAuthentication(authentication); + } + + chain.doFilter(request, response); + } + + private String[] extractBasicAuthenticationCredentials(HttpServletRequest request) { + String header = request.getHeader("Authorization"); + if (header != null && header.toLowerCase().startsWith("basic ")) { + return extractAndDecodeHeader(header, request); + } + throw new BadCredentialsException("Missing basic authentication header"); + } + + // Taken from BasicAuthenticationFilter (spring-security-web) + private String[] extractAndDecodeHeader(String header, HttpServletRequest request) { + + byte[] base64Token = header.substring(6).getBytes(UTF_8); + byte[] decoded; + try { + decoded = Base64.getDecoder().decode(base64Token); + } + catch (IllegalArgumentException e) { + throw new BadCredentialsException("Failed to decode basic authentication token"); + } + + String token = new String(decoded, getCredentialsCharset(request)); + + int delim = token.indexOf(":"); + + if (delim == -1) { + throw new BadCredentialsException("Invalid basic authentication token"); + } + return new String[] { token.substring(0, delim), token.substring(delim + 1) }; + } + + protected Charset getCredentialsCharset(HttpServletRequest httpRequest) { + return UTF_8; + } +} diff --git a/samples/boot/minimal/src/test/java/sample/ClientCredentialsAuthenticationFilterTests.java b/samples/boot/minimal/src/test/java/sample/ClientCredentialsAuthenticationFilterTests.java new file mode 100644 index 0000000..c4caf1b --- /dev/null +++ b/samples/boot/minimal/src/test/java/sample/ClientCredentialsAuthenticationFilterTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package sample; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockServletContext; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.util.Assert; + +import static java.net.URI.create; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Base64.getEncoder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; + +public class ClientCredentialsAuthenticationFilterTests { + private static final String CLIENT_ID = "myclientid"; + private static final String CLIENT_SECRET = "myclientsecret"; + private final AuthenticationManager authenticationManager = authentication -> { + Assert.isInstanceOf(OAuth2ClientAuthenticationToken.class, authentication); + OAuth2ClientAuthenticationToken token = (OAuth2ClientAuthenticationToken) authentication; + if (CLIENT_ID.equals(token.getPrincipal()) && CLIENT_SECRET.equals(token.getCredentials())) { + authentication.setAuthenticated(true); + return authentication; + } + throw new BadCredentialsException("Bad credentials"); + }; + private final ClientCredentialsAuthenticationFilter filter = new ClientCredentialsAuthenticationFilter(this.authenticationManager); + + @BeforeEach + public void setup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void doFilterWhenUrlDoesNotMatchThenDontAuthenticate() throws Exception { + MockHttpServletRequest request = post(create("/someotherendpoint")).buildRequest(new MockServletContext()); + request.addHeader("Authorization", basicAuthHeader(CLIENT_ID, CLIENT_SECRET)); + + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain()); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void doFilterWhenRequestMatchesThenAuthenticate() throws Exception { + MockHttpServletRequest request = post(create("/oauth2/token")).buildRequest(new MockServletContext()); + request.addHeader("Authorization", basicAuthHeader(CLIENT_ID, CLIENT_SECRET)); + + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain()); + + assertThat(SecurityContextHolder.getContext().getAuthentication().isAuthenticated()).isTrue(); + } + + @Test + public void doFilterWhenBasicAuthenticationHeaderIsMissingThenThrowBadCredentialsException() { + MockHttpServletRequest request = post(create("/oauth2/token")).buildRequest(new MockServletContext()); + assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain())); + } + + @Test + public void doFilterWhenBasicAuthenticationHeaderHasInvalidSyntaxThenThrowBadCredentialsException() { + MockHttpServletRequest request = post(create("/oauth2/token")).buildRequest(new MockServletContext()); + request.addHeader("Authorization", "Basic invalid"); + + assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain())); + } + + @Test + public void doFilterWhenBasicAuthenticationProvidesIncorrectSecretThenThrowBadCredentialsException() { + MockHttpServletRequest request = post(create("/oauth2/token")).buildRequest(new MockServletContext()); + request.addHeader("Authorization", basicAuthHeader(CLIENT_ID, "incorrectsecret")); + + assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain())); + } + + @Test + public void doFilterWhenBasicAuthenticationProvidesIncorrectClientIdThenThrowBadCredentialsException() { + MockHttpServletRequest request = post(create("/oauth2/token")).buildRequest(new MockServletContext()); + request.addHeader("Authorization", basicAuthHeader("anotherclientid", CLIENT_SECRET)); + + assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain())); + } + + private static String basicAuthHeader(String clientId, String clientSecret) { + return "Basic " + getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes(UTF_8)); + } +}