From 7a5585f1fc4ec2ee0d26686626d59a852959187a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 29 May 2020 11:05:00 -0400 Subject: [PATCH] Polish gh-72 --- .../OAuth2ClientAuthenticationProvider.java | 24 ++- .../OAuth2ClientAuthenticationToken.java | 45 ++++- ...ntSecretBasicAuthenticationConverter.java} | 55 ++++-- .../web/OAuth2ClientAuthenticationFilter.java | 135 +++++++------- ...uth2ClientAuthenticationProviderTests.java | 61 +++--- .../OAuth2ClientAuthenticationTokenTests.java | 71 +++++++ ...cretBasicAuthenticationConverterTests.java | 106 +++++++++++ ...th2ClientAuthenticationConverterTests.java | 99 ---------- ...OAuth2ClientAuthenticationFilterTests.java | 176 ++++++++++++++---- 9 files changed, 503 insertions(+), 269 deletions(-) rename core/src/main/java/org/springframework/security/oauth2/server/authorization/web/{DefaultOAuth2ClientAuthenticationConverter.java => ClientSecretBasicAuthenticationConverter.java} (53%) create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationTokenTests.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverterTests.java delete mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java index 9aae27f..57e6a64 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java @@ -26,17 +26,22 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.util.Assert; /** - * An {@link AuthenticationProvider} implementation that validates {@link OAuth2ClientAuthenticationToken}s. + * An {@link AuthenticationProvider} implementation that validates {@link OAuth2ClientAuthenticationToken}'s. * * @author Joe Grandja * @author Patryk Kostrzewa + * @since 0.0.1 + * @see AuthenticationProvider + * @see OAuth2ClientAuthenticationToken + * @see RegisteredClientRepository */ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvider { private final RegisteredClientRepository registeredClientRepository; /** - * @param registeredClientRepository - * the bean to lookup the client details from + * Constructs an {@code OAuth2ClientAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients */ public OAuth2ClientAuthenticationProvider(RegisteredClientRepository registeredClientRepository) { Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); @@ -45,21 +50,14 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - String clientId = authentication.getName(); - if (authentication.getCredentials() == null) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); - } - + String clientId = authentication.getPrincipal().toString(); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); - // https://tools.ietf.org/html/rfc6749#section-2.4 if (registeredClient == null) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } - String presentedSecret = authentication.getCredentials() - .toString(); - if (!registeredClient.getClientSecret() - .equals(presentedSecret)) { + String clientSecret = authentication.getCredentials().toString(); + if (!registeredClient.getClientSecret().equals(clientSecret)) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java index 09ec98b..efcabf2 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java @@ -15,40 +15,75 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; +import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; -import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + import java.util.Collections; /** + * An {@link Authentication} implementation used for OAuth 2.0 Client Authentication. + * * @author Joe Grandja * @author Patryk Kostrzewa + * @since 0.0.1 + * @see AbstractAuthenticationToken + * @see RegisteredClient + * @see OAuth2ClientAuthenticationProvider */ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String clientId; private String clientSecret; private RegisteredClient registeredClient; + /** + * Constructs an {@code OAuth2ClientAuthenticationToken} using the provided parameters. + * + * @param clientId the client identifier + * @param clientSecret the client secret + */ public OAuth2ClientAuthenticationToken(String clientId, String clientSecret) { super(Collections.emptyList()); + Assert.hasText(clientId, "clientId cannot be empty"); + Assert.hasText(clientSecret, "clientSecret cannot be empty"); this.clientId = clientId; this.clientSecret = clientSecret; } + /** + * Constructs an {@code OAuth2ClientAuthenticationToken} using the provided parameters. + * + * @param registeredClient the registered client + */ public OAuth2ClientAuthenticationToken(RegisteredClient registeredClient) { super(Collections.emptyList()); + Assert.notNull(registeredClient, "registeredClient cannot be null"); this.registeredClient = registeredClient; setAuthenticated(true); } + @Override + public Object getPrincipal() { + return this.registeredClient != null ? + this.registeredClient.getClientId() : + this.clientId; + } + @Override public Object getCredentials() { return this.clientSecret; } - @Override - public Object getPrincipal() { - return this.clientId; + /** + * Returns the {@link RegisteredClient registered client}. + * + * @return the {@link RegisteredClient} + */ + public @Nullable RegisteredClient getRegisteredClient() { + return this.registeredClient; } } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverter.java similarity index 53% rename from core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java rename to core/src/main/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverter.java index 86230a4..c18cfbb 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverter.java @@ -16,57 +16,72 @@ package org.springframework.security.oauth2.server.authorization.web; import org.springframework.http.HttpHeaders; +import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.util.StringUtils; + import javax.servlet.http.HttpServletRequest; +import java.net.URLDecoder; import java.nio.charset.StandardCharsets; import java.util.Base64; /** - * Converts from {@link HttpServletRequest} to {@link OAuth2ClientAuthenticationToken} that can be authenticated. + * Attempts to extract HTTP Basic credentials from {@link HttpServletRequest} + * and then converts to an {@link OAuth2ClientAuthenticationToken} used for authenticating the client. * * @author Patryk Kostrzewa + * @author Joe Grandja + * @since 0.0.1 + * @see OAuth2ClientAuthenticationToken + * @see OAuth2ClientAuthenticationFilter */ -public class DefaultOAuth2ClientAuthenticationConverter implements AuthenticationConverter { - - private static final String AUTHENTICATION_SCHEME_BASIC = "Basic"; +public class ClientSecretBasicAuthenticationConverter implements AuthenticationConverter { @Override - public OAuth2ClientAuthenticationToken convert(HttpServletRequest request) { + public Authentication convert(HttpServletRequest request) { String header = request.getHeader(HttpHeaders.AUTHORIZATION); - if (header == null) { return null; } - header = header.trim(); - if (!StringUtils.startsWithIgnoreCase(header, AUTHENTICATION_SCHEME_BASIC)) { + String[] parts = header.split("\\s"); + if (!parts[0].equalsIgnoreCase("Basic")) { return null; } - if (header.equalsIgnoreCase(AUTHENTICATION_SCHEME_BASIC)) { + if (parts.length != 2) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); } - byte[] decoded; + byte[] decodedCredentials; try { - byte[] base64Token = header.substring(6) - .getBytes(StandardCharsets.UTF_8); - decoded = Base64.getDecoder() - .decode(base64Token); - } catch (IllegalArgumentException e) { + decodedCredentials = Base64.getDecoder().decode( + parts[1].getBytes(StandardCharsets.UTF_8)); + } catch (IllegalArgumentException ex) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST), ex); + } + + String credentialsString = new String(decodedCredentials, StandardCharsets.UTF_8); + String[] credentials = credentialsString.split(":", 2); + if (credentials.length != 2 || + !StringUtils.hasText(credentials[0]) || + !StringUtils.hasText(credentials[1])) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); } - String token = new String(decoded, StandardCharsets.UTF_8); - String[] credentials = token.split(":"); - if (credentials.length != 2) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); + String clientID; + String clientSecret; + try { + clientID = URLDecoder.decode(credentials[0], StandardCharsets.UTF_8.name()); + clientSecret = URLDecoder.decode(credentials[1], StandardCharsets.UTF_8.name()); + } catch (Exception ex) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST), ex); } - return new OAuth2ClientAuthenticationToken(credentials[0], credentials[1]); + + return new OAuth2ClientAuthenticationToken(clientID, clientSecret); } } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index fd1f4fa..5a3cf49 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -15,21 +15,27 @@ */ package org.springframework.security.oauth2.server.authorization.web; -import org.springframework.http.MediaType; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; -import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -37,54 +43,29 @@ import javax.servlet.http.HttpServletResponse; import java.io.IOException; /** + * A {@code Filter} that processes an authentication request for an OAuth 2.0 Client. + * * @author Joe Grandja * @author Patryk Kostrzewa + * @since 0.0.1 + * @see AuthenticationManager + * @see OAuth2ClientAuthenticationProvider + * @see Section 2.3 Client Authentication + * @see Section 3.2.1 Token Endpoint Client Authentication */ public class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { - - public static final String DEFAULT_FILTER_PROCESSES_URL = "/oauth2/token"; private final AuthenticationManager authenticationManager; private final RequestMatcher requestMatcher; - private final OAuth2ErrorHttpMessageConverter errorMessageConverter = new OAuth2ErrorHttpMessageConverter(); + private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); + private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler; private AuthenticationFailureHandler authenticationFailureHandler; - private AuthenticationConverter authenticationConverter = new DefaultOAuth2ClientAuthenticationConverter(); /** - * Creates an instance which will authenticate against the supplied - * {@code AuthenticationManager}. + * Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided parameters. * - * @param authenticationManager - * the bean to submit authentication requests to - */ - public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager) { - this(authenticationManager, DEFAULT_FILTER_PROCESSES_URL); - } - - /** - * Creates an instance which will authenticate against the supplied - * {@code AuthenticationManager}. - * - *

- * Configures default {@link RequestMatcher} verifying the provided endpoint. - * - * @param authenticationManager - * the bean to submit authentication requests to - * @param filterProcessesUrl - * the filterProcessesUrl to match request URI against - */ - public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, String filterProcessesUrl) { - this(authenticationManager, new AntPathRequestMatcher(filterProcessesUrl, "POST")); - } - - /** - * Creates an instance which will authenticate against the supplied - * {@code AuthenticationManager} and custom {@code RequestMatcher}. - * - * @param authenticationManager - * the bean to submit authentication requests to - * @param requestMatcher - * the {@code RequestMatcher} to match {@code HttpServletRequest} against + * @param authenticationManager the {@link AuthenticationManager} used for authenticating the client + * @param requestMatcher the {@link RequestMatcher} used for matching against the {@code HttpServletRequest} */ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, RequestMatcher requestMatcher) { @@ -92,8 +73,9 @@ public class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { Assert.notNull(requestMatcher, "requestMatcher cannot be null"); this.authenticationManager = authenticationManager; this.requestMatcher = requestMatcher; - this.authenticationSuccessHandler = this::defaultAuthenticationSuccessHandler; - this.authenticationFailureHandler = this::defaultAuthenticationFailureHandler; + this.authenticationConverter = new ClientSecretBasicAuthenticationConverter(); + this.authenticationSuccessHandler = this::onAuthenticationSuccess; + this.authenticationFailureHandler = this::onAuthenticationFailure; } @Override @@ -101,14 +83,12 @@ public class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { throws ServletException, IOException { if (this.requestMatcher.matches(request)) { - Authentication authentication = this.authenticationConverter.convert(request); - if (authentication == null) { - filterChain.doFilter(request, response); - return; - } try { - final Authentication result = this.authenticationManager.authenticate(authentication); - this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, result); + Authentication authenticationRequest = this.authenticationConverter.convert(request); + if (authenticationRequest != null) { + Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest); + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult); + } } catch (OAuth2AuthenticationException failed) { this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); return; @@ -118,10 +98,19 @@ public class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { } /** - * Used to define custom behaviour on a successful authentication. + * Sets the {@link AuthenticationConverter} used for converting a {@link HttpServletRequest} to an {@link OAuth2ClientAuthenticationToken}. * - * @param authenticationSuccessHandler - * the handler to be used + * @param authenticationConverter used for converting a {@link HttpServletRequest} to an {@link OAuth2ClientAuthenticationToken} + */ + public final void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling successful authentications. + * + * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling successful authentications */ public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); @@ -129,39 +118,43 @@ public class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { } /** - * Used to define custom behaviour on a failed authentication. + * Sets the {@link AuthenticationFailureHandler} used for handling failed authentications. * - * @param authenticationFailureHandler - * the handler to be used + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling failed authentications */ public final void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); this.authenticationFailureHandler = authenticationFailureHandler; } - /** - * Used to define custom {@link AuthenticationConverter}. - * - * @param authenticationConverter - * the converter to be used - */ - public final void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { - Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); - this.authenticationConverter = authenticationConverter; - } - - private void defaultAuthenticationSuccessHandler(HttpServletRequest request, HttpServletResponse response, + private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { - SecurityContextHolder.getContext() - .setAuthentication(authentication); + SecurityContext context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(authentication); + SecurityContextHolder.setContext(context); } - private void defaultAuthenticationFailureHandler(HttpServletRequest request, HttpServletResponse response, + private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException { SecurityContextHolder.clearContext(); - this.errorMessageConverter.write(((OAuth2AuthenticationException) failed).getError(), - MediaType.APPLICATION_JSON, new ServletServerHttpResponse(response)); + + // TODO + // The authorization server MAY return an HTTP 401 (Unauthorized) status code + // to indicate which HTTP authentication schemes are supported. + // If the client attempted to authenticate via the "Authorization" request header field, + // the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and + // include the "WWW-Authenticate" response header field + // matching the authentication scheme used by the client. + + OAuth2Error error = ((OAuth2AuthenticationException) failed).getError(); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { + httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED); + } else { + httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + } + this.errorHttpResponseConverter.write(error, null, httpResponse); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java index 099e4be..03c39dd 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java @@ -18,11 +18,12 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import java.util.Collections; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -30,25 +31,25 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * Tests for {@link OAuth2ClientAuthenticationProvider}. * * @author Patryk Kostrzewa + * @author Joe Grandja */ public class OAuth2ClientAuthenticationProviderTests { - private RegisteredClient registeredClient; private RegisteredClientRepository registeredClientRepository; private OAuth2ClientAuthenticationProvider authenticationProvider; @Before public void setUp() { - this.registeredClient = TestRegisteredClients.registeredClient() - .build(); + this.registeredClient = TestRegisteredClients.registeredClient().build(); this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient); this.authenticationProvider = new OAuth2ClientAuthenticationProvider(this.registeredClientRepository); } @Test - public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(null)).isInstanceOf( - IllegalArgumentException.class); + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); } @Test @@ -57,34 +58,36 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenNullCredentialsThenThrowOAuth2AuthorizationException() { - assertThatThrownBy(() -> { - this.authenticationProvider.authenticate(new OAuth2ClientAuthenticationToken("id", null)); - }).isInstanceOf(OAuth2AuthenticationException.class); + public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + this.registeredClient.getClientId() + "-invalid", this.registeredClient.getClientSecret()); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test - public void authenticateWhenNullRegisteredClientThenThrowOAuth2AuthorizationException() { - assertThatThrownBy(() -> { - this.authenticationProvider.authenticate(new OAuth2ClientAuthenticationToken("id", "secret")); - }).isInstanceOf(OAuth2AuthenticationException.class); + public void authenticateWhenInvalidClientSecretThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret() + "-invalid"); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); } @Test - public void authenticateWhenCredentialsNotEqualThenThrowOAuth2AuthorizationException() { - assertThatThrownBy(() -> { - this.authenticationProvider.authenticate( - new OAuth2ClientAuthenticationToken(this.registeredClient.getClientId(), - this.registeredClient.getClientSecret() + "_invalid")); - }).isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void authenticateWhenAuthenticationSuccessResponseThenReturnClientAuthenticationToken() { - OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2ClientAuthenticationToken(this.registeredClient.getClientId(), - registeredClient.getClientSecret())); + public void authenticateWhenValidCredentialsThenAuthenticated() { + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + this.registeredClient.getClientId(), this.registeredClient.getClientSecret()); + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getAuthorities()).isEqualTo(Collections.emptyList()); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials()).isNull(); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationTokenTests.java new file mode 100644 index 0000000..2147626 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationTokenTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.junit.Test; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2ClientAuthenticationToken}. + * + * @author Joe Grandja + */ +public class OAuth2ClientAuthenticationTokenTests { + + @Test + public void constructorWhenClientIdNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null, "secret")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientId cannot be empty"); + } + + @Test + public void constructorWhenClientSecretNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("clientId", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientSecret cannot be empty"); + } + + @Test + public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClient cannot be null"); + } + + @Test + public void constructorWhenClientCredentialsProvidedThenCreated() { + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken("clientId", "secret"); + assertThat(authentication.isAuthenticated()).isFalse(); + assertThat(authentication.getPrincipal().toString()).isEqualTo("clientId"); + assertThat(authentication.getCredentials()).isEqualTo("secret"); + assertThat(authentication.getRegisteredClient()).isNull(); + } + + @Test + public void constructorWhenRegisteredClientProvidedThenCreated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient); + assertThat(authentication.isAuthenticated()).isTrue(); + assertThat(authentication.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authentication.getCredentials()).isNull(); + assertThat(authentication.getRegisteredClient()).isEqualTo(registeredClient); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverterTests.java new file mode 100644 index 0000000..39e521e --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/ClientSecretBasicAuthenticationConverterTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link ClientSecretBasicAuthenticationConverter}. + * + * @author Patryk Kostrzewa + * @author Joe Grandja + */ +public class ClientSecretBasicAuthenticationConverterTests { + private ClientSecretBasicAuthenticationConverter converter = new ClientSecretBasicAuthenticationConverter(); + + @Test + public void convertWhenAuthorizationHeaderEmptyThenReturnNull() { + MockHttpServletRequest request = new MockHttpServletRequest(); + Authentication authentication = this.converter.convert(request); + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenAuthorizationHeaderNotBasicThenReturnNull() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Bearer token"); + Authentication authentication = this.converter.convert(request); + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenAuthorizationHeaderBasicWithMissingCredentialsThenThrowOAuth2AuthenticationException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic "); + assertThatThrownBy(() -> this.converter.convert(request)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + } + + @Test + public void convertWhenAuthorizationHeaderBasicWithInvalidBase64ThenThrowOAuth2AuthenticationException() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic clientId:secret"); + assertThatThrownBy(() -> this.converter.convert(request)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + } + + @Test + public void convertWhenAuthorizationHeaderBasicWithMissingSecretThenThrowOAuth2AuthenticationException() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth("clientId", "")); + assertThatThrownBy(() -> this.converter.convert(request)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + } + + @Test + public void convertWhenAuthorizationHeaderBasicWithValidCredentialsThenReturnClientAuthenticationToken() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth("clientId", "secret")); + OAuth2ClientAuthenticationToken authentication = (OAuth2ClientAuthenticationToken) this.converter.convert(request); + assertThat(authentication.getPrincipal()).isEqualTo("clientId"); + assertThat(authentication.getCredentials()).isEqualTo("secret"); + } + + private static String encodeBasicAuth(String clientId, String secret) throws Exception { + clientId = URLEncoder.encode(clientId, StandardCharsets.UTF_8.name()); + secret = URLEncoder.encode(secret, StandardCharsets.UTF_8.name()); + String credentialsString = clientId + ":" + secret; + byte[] encodedBytes = Base64.getEncoder().encode(credentialsString.getBytes(StandardCharsets.UTF_8)); + return new String(encodedBytes, StandardCharsets.UTF_8); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java deleted file mode 100644 index 53fe560..0000000 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/DefaultOAuth2ClientAuthenticationConverterTests.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.web; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.http.HttpHeaders; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; -import java.util.Base64; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Tests for {@link DefaultOAuth2ClientAuthenticationConverter}. - * - * @author Patryk Kostrzewa - */ -public class DefaultOAuth2ClientAuthenticationConverterTests { - - private DefaultOAuth2ClientAuthenticationConverter converter; - - @Before - public void setup() { - this.converter = new DefaultOAuth2ClientAuthenticationConverter(); - } - - @Test - public void convertWhenConversionSuccessThenReturnClientAuthenticationToken() { - String token = "client:secret"; - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder() - .encodeToString(token.getBytes())); - - OAuth2ClientAuthenticationToken authentication = this.converter.convert(request); - - assertThat(authentication).isNotNull(); - assertThat(authentication.getName()).isEqualTo("client"); - } - - @Test - public void convertWithAuthorizationSchemeInMixedCaseWhenConversionSuccessThenReturnClientAuthenticationToken() { - String token = "client:secret"; - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(HttpHeaders.AUTHORIZATION, "BaSiC " + Base64.getEncoder() - .encodeToString(token.getBytes())); - - final OAuth2ClientAuthenticationToken authentication = this.converter.convert(request); - - assertThat(authentication).isNotNull(); - assertThat(authentication.getName()).isEqualTo("client"); - } - - @Test - public void convertWithIgnoringUnsupportedAuthenticationHeaderWhenConversionSuccessThenReturnClientAuthenticationToken() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(HttpHeaders.AUTHORIZATION, "Bearer unsupportedToken"); - - OAuth2ClientAuthenticationToken authentication = this.converter.convert(request); - - assertThat(authentication).isNull(); - } - - @Test - public void convertWhenNotValidTokenThenThrowOAuth2AuthenticationException() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder() - .encodeToString("client".getBytes())); - assertThatThrownBy(() -> this.converter.convert(request)).isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void convertWhenNotValidBase64ThenThrowOAuth2AuthenticationException() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(HttpHeaders.AUTHORIZATION, "Basic NOT_VALID_BASE64"); - assertThatThrownBy(() -> this.converter.convert(request)).isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void convertWhenEmptyAuthenticationHeaderTokenThenThrowOAuth2AuthenticationException() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(HttpHeaders.AUTHORIZATION, "Basic "); - assertThatThrownBy(() -> this.converter.convert(request)).isInstanceOf(OAuth2AuthenticationException.class); - } -} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java index eb99394..3d8d6fb 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java @@ -15,85 +15,197 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; + import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; + +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; /** * Tests for {@link OAuth2ClientAuthenticationFilter}. * * @author Patryk Kostrzewa + * @author Joe Grandja */ public class OAuth2ClientAuthenticationFilterTests { - + private String filterProcessesUrl = "/oauth2/token"; + private AuthenticationManager authenticationManager; + private RequestMatcher requestMatcher; + private AuthenticationConverter authenticationConverter; private OAuth2ClientAuthenticationFilter filter; - private AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - private RequestMatcher requestMatcher = mock(RequestMatcher.class); - private FilterChain filterChain = mock(FilterChain.class); - private String filterProcessesUrl; + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); @Before public void setUp() { - this.filterProcessesUrl = "/oauth2/token"; - this.filter = new OAuth2ClientAuthenticationFilter(authenticationManager, requestMatcher); + this.authenticationManager = mock(AuthenticationManager.class); + this.requestMatcher = new AntPathRequestMatcher(this.filterProcessesUrl, HttpMethod.POST.name()); + this.filter = new OAuth2ClientAuthenticationFilter(this.authenticationManager, this.requestMatcher); + this.authenticationConverter = mock(AuthenticationConverter.class); + this.filter.setAuthenticationConverter(this.authenticationConverter); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); } @Test - public void constructorManagerAndMatcherWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - new OAuth2ClientAuthenticationFilter(null, (RequestMatcher) null); - }).isInstanceOf(IllegalArgumentException.class); + public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(null, this.requestMatcher)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationManager cannot be null"); } @Test - public void constructorManagerAndFilterUrlWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - new OAuth2ClientAuthenticationFilter(null, (String) null); - }).isInstanceOf(IllegalArgumentException.class); + public void constructorWhenRequestMatcherNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(this.authenticationManager, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("requestMatcher cannot be null"); } @Test - public void doFilterWhenNotTokenRequestThenNextFilter() throws Exception { - this.filterProcessesUrl = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); - request.setServletPath(this.filterProcessesUrl); - MockHttpServletResponse response = new MockHttpServletResponse(); - - this.filter.doFilter(request, response, this.filterChain); - - verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationConverter cannot be null"); } @Test - public void doFilterWhenAuthenticationRequestGetThenNotProcessed() throws Exception { - String requestUri = OAuth2ClientAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URL; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationSuccessHandler cannot be null"); + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationFailureHandler cannot be null"); + } + + @Test + public void doFilterWhenRequestDoesNotMatchThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, this.filterChain); + this.filter.doFilter(request, response, filterChain); - verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterWhenAuthenticationIsNullThenNotProcessed() throws Exception { + public void doFilterWhenRequestMatchesAndEmptyCredentialsThenNotProcessed() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); request.setServletPath(this.filterProcessesUrl); MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, this.filterChain); + this.filter.doFilter(request, response, filterChain); - verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenRequestMatchesAndInvalidCredentialsThenInvalidRequestError() throws Exception { + when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenThrow( + new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST))); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); + request.setServletPath(this.filterProcessesUrl); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + } + + @Test + public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception { + when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn( + new OAuth2ClientAuthenticationToken("clientId", "invalid-secret")); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenThrow( + new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT))); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); + request.setServletPath(this.filterProcessesUrl); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + } + + @Test + public void doFilterWhenRequestMatchesAndValidCredentialsThenProcessed() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn( + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), registeredClient.getClientSecret())); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn( + new OAuth2ClientAuthenticationToken(registeredClient)); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl); + request.setServletPath(this.filterProcessesUrl); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + assertThat(authentication).isInstanceOf(OAuth2ClientAuthenticationToken.class); + assertThat(((OAuth2ClientAuthenticationToken) authentication).getRegisteredClient()).isEqualTo(registeredClient); + } + + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); } }