From 09a363ba6784e15d9eae49d16c38925221aa64c5 Mon Sep 17 00:00:00 2001 From: Patryk Kostrzewa Date: Mon, 27 Apr 2020 17:10:49 +0200 Subject: [PATCH] Implement Client Credentials Authentication Fixes gh-39 --- .../OAuth2ClientAuthenticationProvider.java | 39 ++++- .../OAuth2ClientAuthenticationToken.java | 2 +- ...ltOAuth2ClientAuthenticationConverter.java | 72 +++++++++ .../web/OAuth2ClientAuthenticationFilter.java | 139 +++++++++++++++++- ...uth2ClientAuthenticationProviderTests.java | 90 ++++++++++++ ...th2ClientAuthenticationConverterTests.java | 99 +++++++++++++ ...OAuth2ClientAuthenticationFilterTests.java | 99 +++++++++++++ 7 files changed, 531 insertions(+), 9 deletions(-) create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java index c3bdcc5..9aae27f 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java @@ -18,17 +18,52 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.util.Assert; /** + * An {@link AuthenticationProvider} implementation that validates {@link OAuth2ClientAuthenticationToken}s. + * * @author Joe Grandja + * @author Patryk Kostrzewa */ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvider { - private RegisteredClientRepository registeredClientRepository; + private final RegisteredClientRepository registeredClientRepository; + + /** + * @param registeredClientRepository + * the bean to lookup the client details from + */ + public OAuth2ClientAuthenticationProvider(RegisteredClientRepository registeredClientRepository) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + this.registeredClientRepository = registeredClientRepository; + } @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - return authentication; + String clientId = authentication.getName(); + if (authentication.getCredentials() == null) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); + // https://tools.ietf.org/html/rfc6749#section-2.4 + if (registeredClient == null) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + String presentedSecret = authentication.getCredentials() + .toString(); + if (!registeredClient.getClientSecret() + .equals(presentedSecret)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + return new OAuth2ClientAuthenticationToken(registeredClient); } @Override diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java index ea0bb2b..09ec98b 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java @@ -18,11 +18,11 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; - import java.util.Collections; /** * @author Joe Grandja + * @author Patryk Kostrzewa */ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java new file mode 100644 index 0000000..86230a4 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java @@ -0,0 +1,72 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.springframework.http.HttpHeaders; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.util.StringUtils; +import javax.servlet.http.HttpServletRequest; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +/** + * Converts from {@link HttpServletRequest} to {@link OAuth2ClientAuthenticationToken} that can be authenticated. + * + * @author Patryk Kostrzewa + */ +public class DefaultOAuth2ClientAuthenticationConverter implements AuthenticationConverter { + + private static final String AUTHENTICATION_SCHEME_BASIC = "Basic"; + + @Override + public OAuth2ClientAuthenticationToken convert(HttpServletRequest request) { + String header = request.getHeader(HttpHeaders.AUTHORIZATION); + + if (header == null) { + return null; + } + + header = header.trim(); + if (!StringUtils.startsWithIgnoreCase(header, AUTHENTICATION_SCHEME_BASIC)) { + return null; + } + + if (header.equalsIgnoreCase(AUTHENTICATION_SCHEME_BASIC)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } + + byte[] decoded; + try { + byte[] base64Token = header.substring(6) + .getBytes(StandardCharsets.UTF_8); + decoded = Base64.getDecoder() + .decode(base64Token); + } catch (IllegalArgumentException e) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } + + String token = new String(decoded, StandardCharsets.UTF_8); + String[] credentials = token.split(":"); + if (credentials.length != 2) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); + } + return new OAuth2ClientAuthenticationToken(credentials[0], credentials[1]); + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index 1738b1a..fd1f4fa 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -15,9 +15,21 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import org.springframework.http.MediaType; +import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; - import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -26,15 +38,130 @@ import java.io.IOException; /** * @author Joe Grandja + * @author Patryk Kostrzewa */ public class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { - private AuthenticationManager authenticationManager; - @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + public static final String DEFAULT_FILTER_PROCESSES_URL = "/oauth2/token"; + private final AuthenticationManager authenticationManager; + private final RequestMatcher requestMatcher; + private final OAuth2ErrorHttpMessageConverter errorMessageConverter = new OAuth2ErrorHttpMessageConverter(); + private AuthenticationSuccessHandler authenticationSuccessHandler; + private AuthenticationFailureHandler authenticationFailureHandler; + private AuthenticationConverter authenticationConverter = new DefaultOAuth2ClientAuthenticationConverter(); + /** + * Creates an instance which will authenticate against the supplied + * {@code AuthenticationManager}. + * + * @param authenticationManager + * the bean to submit authentication requests to + */ + public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager) { + this(authenticationManager, DEFAULT_FILTER_PROCESSES_URL); } + /** + * Creates an instance which will authenticate against the supplied + * {@code AuthenticationManager}. + * + *

+ * Configures default {@link RequestMatcher} verifying the provided endpoint. + * + * @param authenticationManager + * the bean to submit authentication requests to + * @param filterProcessesUrl + * the filterProcessesUrl to match request URI against + */ + public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, String filterProcessesUrl) { + this(authenticationManager, new AntPathRequestMatcher(filterProcessesUrl, "POST")); + } + + /** + * Creates an instance which will authenticate against the supplied + * {@code AuthenticationManager} and custom {@code RequestMatcher}. + * + * @param authenticationManager + * the bean to submit authentication requests to + * @param requestMatcher + * the {@code RequestMatcher} to match {@code HttpServletRequest} against + */ + public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, + RequestMatcher requestMatcher) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.authenticationManager = authenticationManager; + this.requestMatcher = requestMatcher; + this.authenticationSuccessHandler = this::defaultAuthenticationSuccessHandler; + this.authenticationFailureHandler = this::defaultAuthenticationFailureHandler; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (this.requestMatcher.matches(request)) { + Authentication authentication = this.authenticationConverter.convert(request); + if (authentication == null) { + filterChain.doFilter(request, response); + return; + } + try { + final Authentication result = this.authenticationManager.authenticate(authentication); + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, result); + } catch (OAuth2AuthenticationException failed) { + this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); + return; + } + } + filterChain.doFilter(request, response); + } + + /** + * Used to define custom behaviour on a successful authentication. + * + * @param authenticationSuccessHandler + * the handler to be used + */ + public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Used to define custom behaviour on a failed authentication. + * + * @param authenticationFailureHandler + * the handler to be used + */ + public final void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + /** + * Used to define custom {@link AuthenticationConverter}. + * + * @param authenticationConverter + * the converter to be used + */ + public final void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + private void defaultAuthenticationSuccessHandler(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) { + + SecurityContextHolder.getContext() + .setAuthentication(authentication); + } + + private void defaultAuthenticationFailureHandler(HttpServletRequest request, HttpServletResponse response, + AuthenticationException failed) throws IOException { + + SecurityContextHolder.clearContext(); + this.errorMessageConverter.write(((OAuth2AuthenticationException) failed).getError(), + MediaType.APPLICATION_JSON, new ServletServerHttpResponse(response)); + } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java new file mode 100644 index 0000000..099e4be --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import java.util.Collections; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2ClientAuthenticationProvider}. + * + * @author Patryk Kostrzewa + */ +public class OAuth2ClientAuthenticationProviderTests { + + private RegisteredClient registeredClient; + private RegisteredClientRepository registeredClientRepository; + private OAuth2ClientAuthenticationProvider authenticationProvider; + + @Before + public void setUp() { + this.registeredClient = TestRegisteredClients.registeredClient() + .build(); + this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient); + this.authenticationProvider = new OAuth2ClientAuthenticationProvider(this.registeredClientRepository); + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(null)).isInstanceOf( + IllegalArgumentException.class); + } + + @Test + public void supportsWhenTypeOAuth2ClientAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2ClientAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenNullCredentialsThenThrowOAuth2AuthorizationException() { + assertThatThrownBy(() -> { + this.authenticationProvider.authenticate(new OAuth2ClientAuthenticationToken("id", null)); + }).isInstanceOf(OAuth2AuthenticationException.class); + } + + @Test + public void authenticateWhenNullRegisteredClientThenThrowOAuth2AuthorizationException() { + assertThatThrownBy(() -> { + this.authenticationProvider.authenticate(new OAuth2ClientAuthenticationToken("id", "secret")); + }).isInstanceOf(OAuth2AuthenticationException.class); + } + + @Test + public void authenticateWhenCredentialsNotEqualThenThrowOAuth2AuthorizationException() { + assertThatThrownBy(() -> { + this.authenticationProvider.authenticate( + new OAuth2ClientAuthenticationToken(this.registeredClient.getClientId(), + this.registeredClient.getClientSecret() + "_invalid")); + }).isInstanceOf(OAuth2AuthenticationException.class); + } + + @Test + public void authenticateWhenAuthenticationSuccessResponseThenReturnClientAuthenticationToken() { + OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate( + new OAuth2ClientAuthenticationToken(this.registeredClient.getClientId(), + registeredClient.getClientSecret())); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getAuthorities()).isEqualTo(Collections.emptyList()); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java new file mode 100644 index 0000000..53fe560 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import java.util.Base64; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultOAuth2ClientAuthenticationConverter}. + * + * @author Patryk Kostrzewa + */ +public class DefaultOAuth2ClientAuthenticationConverterTests { + + private DefaultOAuth2ClientAuthenticationConverter converter; + + @Before + public void setup() { + this.converter = new DefaultOAuth2ClientAuthenticationConverter(); + } + + @Test + public void convertWhenConversionSuccessThenReturnClientAuthenticationToken() { + String token = "client:secret"; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder() + .encodeToString(token.getBytes())); + + OAuth2ClientAuthenticationToken authentication = this.converter.convert(request); + + assertThat(authentication).isNotNull(); + assertThat(authentication.getName()).isEqualTo("client"); + } + + @Test + public void convertWithAuthorizationSchemeInMixedCaseWhenConversionSuccessThenReturnClientAuthenticationToken() { + String token = "client:secret"; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "BaSiC " + Base64.getEncoder() + .encodeToString(token.getBytes())); + + final OAuth2ClientAuthenticationToken authentication = this.converter.convert(request); + + assertThat(authentication).isNotNull(); + assertThat(authentication.getName()).isEqualTo("client"); + } + + @Test + public void convertWithIgnoringUnsupportedAuthenticationHeaderWhenConversionSuccessThenReturnClientAuthenticationToken() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Bearer unsupportedToken"); + + OAuth2ClientAuthenticationToken authentication = this.converter.convert(request); + + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenNotValidTokenThenThrowOAuth2AuthenticationException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder() + .encodeToString("client".getBytes())); + assertThatThrownBy(() -> this.converter.convert(request)).isInstanceOf(OAuth2AuthenticationException.class); + } + + @Test + public void convertWhenNotValidBase64ThenThrowOAuth2AuthenticationException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic NOT_VALID_BASE64"); + assertThatThrownBy(() -> this.converter.convert(request)).isInstanceOf(OAuth2AuthenticationException.class); + } + + @Test + public void convertWhenEmptyAuthenticationHeaderTokenThenThrowOAuth2AuthenticationException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic "); + assertThatThrownBy(() -> this.converter.convert(request)).isInstanceOf(OAuth2AuthenticationException.class); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java new file mode 100644 index 0000000..eb99394 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.web.util.matcher.RequestMatcher; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OAuth2ClientAuthenticationFilter}. + * + * @author Patryk Kostrzewa + */ +public class OAuth2ClientAuthenticationFilterTests { + + private OAuth2ClientAuthenticationFilter filter; + private AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + private RequestMatcher requestMatcher = mock(RequestMatcher.class); + private FilterChain filterChain = mock(FilterChain.class); + private String filterProcessesUrl; + + @Before + public void setUp() { + this.filterProcessesUrl = "/oauth2/token"; + this.filter = new OAuth2ClientAuthenticationFilter(authenticationManager, requestMatcher); + } + + @Test + public void constructorManagerAndMatcherWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + new OAuth2ClientAuthenticationFilter(null, (RequestMatcher) null); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorManagerAndFilterUrlWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + new OAuth2ClientAuthenticationFilter(null, (String) null); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void doFilterWhenNotTokenRequestThenNextFilter() throws Exception { + this.filterProcessesUrl = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); + request.setServletPath(this.filterProcessesUrl); + MockHttpServletResponse response = new MockHttpServletResponse(); + + this.filter.doFilter(request, response, this.filterChain); + + verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthenticationRequestGetThenNotProcessed() throws Exception { + String requestUri = OAuth2ClientAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URL; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + + this.filter.doFilter(request, response, this.filterChain); + + verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthenticationIsNullThenNotProcessed() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); + request.setServletPath(this.filterProcessesUrl); + MockHttpServletResponse response = new MockHttpServletResponse(); + + this.filter.doFilter(request, response, this.filterChain); + + verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } +}