From 4c8f89af5ccf86e4cfced54702867a36075e820c Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 25 May 2020 09:23:18 -0400 Subject: [PATCH] Polish gh-79 --- .../authorization/OAuth2Authorization.java | 10 - .../OAuth2AccessTokenAuthenticationToken.java | 8 +- ...2AuthorizationCodeAuthenticationToken.java | 29 +- .../OAuth2AuthorizationEndpointFilter.java | 19 +- .../web/OAuth2EndpointUtils.java | 49 +++ .../web/OAuth2TokenEndpointFilter.java | 218 ++++++------ .../web/OAuth2TokenEndpointFilterTests.java | 319 ++++++++++-------- 7 files changed, 383 insertions(+), 269 deletions(-) create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2EndpointUtils.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index c298052..f3dad82 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -33,7 +33,6 @@ import java.util.function.Consumer; * * @author Joe Grandja * @author Krisztian Toth - * @author Madhu Bhat * @since 0.0.1 * @see RegisteredClient * @see OAuth2AccessToken @@ -75,15 +74,6 @@ public class OAuth2Authorization implements Serializable { return this.accessToken; } - /** - * Sets the access token {@link OAuth2AccessToken} in the {@link OAuth2Authorization}. - * - * @param accessToken the access token - */ - public final void setAccessToken(OAuth2AccessToken accessToken) { - this.accessToken = accessToken; - } - /** * Returns the attribute(s) associated to the authorization. * diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java index dbf28ce..c9abd28 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java @@ -17,8 +17,8 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import java.util.Collections; @@ -28,7 +28,7 @@ import java.util.Collections; * @author Madhu Bhat */ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private RegisteredClient registeredClient; private Authentication clientPrincipal; private OAuth2AccessToken accessToken; @@ -52,9 +52,9 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication } /** - * Returns the access token {@link OAuth2AccessToken}. + * Returns the {@link OAuth2AccessToken access token}. * - * @return the access token + * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { return this.accessToken; diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java index aaa567b..65e1609 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -18,7 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.security.oauth2.server.authorization.Version; import java.util.Collections; @@ -27,7 +27,7 @@ import java.util.Collections; * @author Madhu Bhat */ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String code; private Authentication clientPrincipal; private String clientId; @@ -37,26 +37,26 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti Authentication clientPrincipal, @Nullable String redirectUri) { super(Collections.emptyList()); this.code = code; - this.redirectUri = redirectUri; this.clientPrincipal = clientPrincipal; + this.redirectUri = redirectUri; } public OAuth2AuthorizationCodeAuthenticationToken(String code, String clientId, @Nullable String redirectUri) { super(Collections.emptyList()); this.code = code; - this.redirectUri = redirectUri; this.clientId = clientId; - } - - @Override - public Object getCredentials() { - return null; + this.redirectUri = redirectUri; } @Override public Object getPrincipal() { - return null; + return this.clientPrincipal != null ? this.clientPrincipal : this.clientId; + } + + @Override + public Object getCredentials() { + return ""; } /** @@ -67,4 +67,13 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti public String getCode() { return this.code; } + + /** + * Returns the redirectUri. + * + * @return the redirectUri + */ + public String getRedirectUri() { + return this.redirectUri; + } } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index b9af398..57e8f12 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -38,7 +38,6 @@ import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; @@ -53,7 +52,6 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.HashSet; -import java.util.Map; import java.util.Set; /** @@ -123,7 +121,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { // Validate the request to ensure that all required parameters are present and valid // --------------- - MultiValueMap parameters = getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE); // client_id (REQUIRED) @@ -258,7 +256,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { } private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) { - MultiValueMap parameters = getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); Set scopes = Collections.emptySet(); if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) { @@ -282,17 +280,4 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { .forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0)))) .build(); } - - private static MultiValueMap getParameters(HttpServletRequest request) { - Map parameterMap = request.getParameterMap(); - MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); - parameterMap.forEach((key, values) -> { - if (values.length > 0) { - for (String value : values) { - parameters.add(key, value); - } - } - }); - return parameters; - } } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2EndpointUtils.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2EndpointUtils.java new file mode 100644 index 0000000..adad007 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2EndpointUtils.java @@ -0,0 +1,49 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import javax.servlet.http.HttpServletRequest; +import java.util.Map; + +/** + * Utility methods for the OAuth 2.0 Protocol Endpoints. + * + * @author Joe Grandja + * @since 0.0.1 + * @see OAuth2AuthorizationEndpointFilter + * @see OAuth2TokenEndpointFilter + */ +final class OAuth2EndpointUtils { + + private OAuth2EndpointUtils() { + } + + static MultiValueMap getParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); + parameterMap.forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 8bf8f28..51edd88 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -15,13 +15,11 @@ */ package org.springframework.security.oauth2.server.authorization.web; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -30,15 +28,17 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; @@ -47,145 +47,171 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; -import java.io.Writer; +import java.time.temporal.ChronoUnit; /** - * This {@code Filter} is used by the client to obtain an access token by presenting - * its authorization grant. + * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, + * which handles the processing of the OAuth 2.0 Access Token Request. * *

- * It converts the OAuth 2.0 Access Token Request to {@link OAuth2AuthorizationCodeAuthenticationToken}, - * which is then authenticated by the {@link AuthenticationManager} and gets back - * {@link OAuth2AccessTokenAuthenticationToken} which has the {@link OAuth2AccessToken} if the request - * was successfully authenticated. The {@link OAuth2AccessToken} is then updated in the in-flight {@link OAuth2Authorization} - * and sent back to the client. In case the authentication fails, an HTTP 401 (Unauthorized) response is returned. + * It converts the OAuth 2.0 Access Token Request to an {@link OAuth2AuthorizationCodeAuthenticationToken}, + * which is then authenticated by the {@link AuthenticationManager}. + * If the authentication succeeds, the {@link AuthenticationManager} returns an + * {@link OAuth2AccessTokenAuthenticationToken}, which contains + * the {@link OAuth2AccessToken} that is returned in the response. + * In case of any error, an {@link OAuth2Error} is returned in the response. * *

* By default, this {@code Filter} responds to access token requests - * at the {@code URI} {@code /oauth2/token} and {@code HttpMethod} {@code POST} - * using the default {@link AntPathRequestMatcher}. + * at the {@code URI} {@code /oauth2/token} and {@code HttpMethod} {@code POST}. * *

- * The default base {@code URI} {@code /oauth2/token} may be overridden - * via the constructor {@link #OAuth2TokenEndpointFilter(OAuth2AuthorizationService, AuthenticationManager, String)}. + * The default endpoint {@code URI} {@code /oauth2/token} may be overridden + * via the constructor {@link #OAuth2TokenEndpointFilter(AuthenticationManager, OAuth2AuthorizationService, String)}. * * @author Joe Grandja * @author Madhu Bhat + * @since 0.0.1 + * @see AuthenticationManager + * @see OAuth2AuthorizationService + * @see Section 4.1 Authorization Code Grant + * @see Section 4.1.3 Access Token Request */ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter { /** * The default endpoint {@code URI} for access token requests. */ - private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token"; + public static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token"; - private Converter authorizationGrantConverter = this::convert; - private AuthenticationManager authenticationManager; - private OAuth2AuthorizationService authorizationService; - private RequestMatcher uriMatcher; - private ObjectMapper objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL); + private final AuthenticationManager authenticationManager; + private final OAuth2AuthorizationService authorizationService; + private final RequestMatcher tokenEndpointMatcher; + private final Converter authorizationGrantAuthenticationConverter = + new AuthorizationCodeAuthenticationConverter(); + private final HttpMessageConverter accessTokenHttpResponseConverter = + new OAuth2AccessTokenResponseHttpMessageConverter(); + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); /** * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters. * - * @param authorizationService the authorization service implementation - * @param authenticationManager the authentication manager implementation + * @param authenticationManager the authentication manager + * @param authorizationService the authorization service */ - public OAuth2TokenEndpointFilter(OAuth2AuthorizationService authorizationService, AuthenticationManager authenticationManager) { - Assert.notNull(authorizationService, "authorizationService cannot be null"); - Assert.notNull(authenticationManager, "authenticationManager cannot be null"); - this.authenticationManager = authenticationManager; - this.authorizationService = authorizationService; - this.uriMatcher = new AntPathRequestMatcher(DEFAULT_TOKEN_ENDPOINT_URI, HttpMethod.POST.name()); + public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager, + OAuth2AuthorizationService authorizationService) { + this(authenticationManager, authorizationService, DEFAULT_TOKEN_ENDPOINT_URI); } /** * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters. * - * @param authorizationService the authorization service implementation - * @param authenticationManager the authentication manager implementation - * @param tokenEndpointUri the token endpoint's uri + * @param authenticationManager the authentication manager + * @param authorizationService the authorization service + * @param tokenEndpointUri the endpoint {@code URI} for access token requests */ - public OAuth2TokenEndpointFilter(OAuth2AuthorizationService authorizationService, AuthenticationManager authenticationManager, - String tokenEndpointUri) { - Assert.notNull(authorizationService, "authorizationService cannot be null"); + public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager, + OAuth2AuthorizationService authorizationService, String tokenEndpointUri) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); Assert.hasText(tokenEndpointUri, "tokenEndpointUri cannot be empty"); this.authenticationManager = authenticationManager; this.authorizationService = authorizationService; - this.uriMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name()); + this.tokenEndpointMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name()); } @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (uriMatcher.matches(request)) { - try { - if (validateAccessTokenRequest(request)) { - OAuth2AuthorizationCodeAuthenticationToken authCodeAuthToken = - (OAuth2AuthorizationCodeAuthenticationToken) authorizationGrantConverter.convert(request); - OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken = - (OAuth2AccessTokenAuthenticationToken) authenticationManager.authenticate(authCodeAuthToken); - if (accessTokenAuthenticationToken.isAuthenticated()) { - OAuth2Authorization authorization = authorizationService - .findByTokenAndTokenType(authCodeAuthToken.getCode(), TokenType.AUTHORIZATION_CODE); - authorization.setAccessToken(accessTokenAuthenticationToken.getAccessToken()); - authorizationService.save(authorization); - writeSuccessResponse(response, accessTokenAuthenticationToken.getAccessToken()); - } else { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); - } - } - } catch (OAuth2AuthenticationException exception) { - SecurityContextHolder.clearContext(); - writeFailureResponse(response, exception.getError()); - } - } else { + + if (!this.tokenEndpointMatcher.matches(request)) { filterChain.doFilter(request, response); + return; + } + + try { + Authentication authorizationGrantAuthentication = + this.authorizationGrantAuthenticationConverter.convert(request); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication); + sendAccessTokenResponse(response, accessTokenAuthentication.getAccessToken()); + } catch (OAuth2AuthenticationException ex) { + SecurityContextHolder.clearContext(); + sendErrorResponse(response, ex.getError()); } } - private boolean validateAccessTokenRequest(HttpServletRequest request) { - if (StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.CODE)) - || StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)) - || StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.GRANT_TYPE))) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); - } else if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(request.getParameter(OAuth2ParameterNames.GRANT_TYPE))) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE)); + private void sendAccessTokenResponse(HttpServletResponse response, OAuth2AccessToken accessToken) throws IOException { + OAuth2AccessTokenResponse.Builder builder = + OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()); + if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) { + builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt())); } - return true; + OAuth2AccessTokenResponse accessTokenResponse = builder.build(); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse); } - private OAuth2AuthorizationCodeAuthenticationToken convert(HttpServletRequest request) { - Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - return new OAuth2AuthorizationCodeAuthenticationToken( - request.getParameter(OAuth2ParameterNames.CODE), - clientPrincipal, - request.getParameter(OAuth2ParameterNames.REDIRECT_URI) - ); + private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + this.errorHttpResponseConverter.write(error, null, httpResponse); } - private void writeSuccessResponse(HttpServletResponse response, OAuth2AccessToken body) throws IOException { - try (Writer out = response.getWriter()) { - response.setStatus(HttpStatus.OK.value()); - response.setContentType(MediaType.APPLICATION_JSON_VALUE); - response.setCharacterEncoding("UTF-8"); - response.setHeader(HttpHeaders.CACHE_CONTROL, "no-store"); - response.setHeader(HttpHeaders.PRAGMA, "no-cache"); - out.write(objectMapper.writeValueAsString(body)); - } + private static OAuth2AuthenticationException throwError(String errorCode, String parameterName) { + OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, + "https://tools.ietf.org/html/rfc6749#section-5.2"); + throw new OAuth2AuthenticationException(error); } - private void writeFailureResponse(HttpServletResponse response, OAuth2Error error) throws IOException { - try (Writer out = response.getWriter()) { - if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_CLIENT)) { - response.setStatus(HttpStatus.UNAUTHORIZED.value()); - } else { - response.setStatus(HttpStatus.BAD_REQUEST.value()); + private static class AuthorizationCodeAuthenticationConverter implements Converter { + + @Override + public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + + // grant_type (REQUIRED) + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); + if (!StringUtils.hasText(grantType) || + parameters.get(OAuth2ParameterNames.GRANT_TYPE).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE); } - response.setContentType(MediaType.APPLICATION_JSON_VALUE); - response.setCharacterEncoding("UTF-8"); - out.write(objectMapper.writeValueAsString(error)); + if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) { + throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); + } + + // client_id (REQUIRED) + String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); + Authentication clientPrincipal = null; + if (StringUtils.hasText(clientId)) { + if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); + } + } else { + clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); + } + + // code (REQUIRED) + String code = parameters.getFirst(OAuth2ParameterNames.CODE); + if (!StringUtils.hasText(code) || + parameters.get(OAuth2ParameterNames.CODE).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CODE); + } + + // redirect_uri (REQUIRED) + // Required only if the "redirect_uri" parameter was included in the authorization request + String redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI); + if (StringUtils.hasText(redirectUri) && + parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); + } + + return clientPrincipal != null ? + new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) : + new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri); } } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 0deddbe..85e2fa6 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -15,36 +15,47 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.springframework.http.HttpHeaders; +import org.mockito.ArgumentCaptor; import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.time.Duration; import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; +import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -53,178 +64,222 @@ import static org.mockito.Mockito.when; * Tests for {@link OAuth2TokenEndpointFilter}. * * @author Madhu Bhat + * @author Joe Grandja */ public class OAuth2TokenEndpointFilterTests { - + private AuthenticationManager authenticationManager; + private OAuth2AuthorizationService authorizationService; private OAuth2TokenEndpointFilter filter; - private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); - private AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - private FilterChain filterChain = mock(FilterChain.class); - private String requestUri; - private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); - private static final String PRINCIPAL_NAME = "principal"; - private static final String AUTHORIZATION_CODE = "code"; + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); + private final HttpMessageConverter accessTokenHttpResponseConverter = + new OAuth2AccessTokenResponseHttpMessageConverter(); @Before public void setUp() { - this.filter = new OAuth2TokenEndpointFilter(this.authorizationService, this.authenticationManager); - this.requestUri = "/oauth2/token"; + this.authenticationManager = mock(AuthenticationManager.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.filter = new OAuth2TokenEndpointFilter(this.authenticationManager, this.authorizationService); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); } @Test - public void constructorServiceAndManagerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - new OAuth2TokenEndpointFilter(null, null); - }).isInstanceOf(IllegalArgumentException.class); + public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationManager cannot be null"); } @Test - public void constructorServiceAndManagerAndEndpointWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - new OAuth2TokenEndpointFilter(null, null, null); - }).isInstanceOf(IllegalArgumentException.class); + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(this.authenticationManager, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); } @Test - public void doFilterWhenNotTokenRequestThenNextFilter() throws Exception { - this.requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", this.requestUri); - request.setServletPath(this.requestUri); + public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(this.authenticationManager, this.authorizationService, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokenEndpointUri cannot be empty"); + } + + @Test + public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, this.filterChain); + this.filter.doFilter(request, response, filterChain); - verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterWhenAccessTokenRequestWithoutGrantTypeThenRespondWithBadRequest() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode"); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri"); - request.setServletPath(this.requestUri); + public void doFilterWhenTokenRequestGetThenNotProcessed() throws Exception { + String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, this.filterChain); + this.filter.doFilter(request, response, filterChain); - verifyNoInteractions(this.filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}"); + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterWhenAccessTokenRequestWithoutCodeThenRespondWithBadRequest() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri); - request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType"); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri"); - request.setServletPath(this.requestUri); - MockHttpServletResponse response = new MockHttpServletResponse(); - - this.filter.doFilter(request, response, this.filterChain); - - verifyNoInteractions(this.filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}"); + public void doFilterWhenTokenRequestMissingGrantTypeThenInvalidRequestError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.GRANT_TYPE)); } @Test - public void doFilterWhenAccessTokenRequestWithoutRedirectUriThenRespondWithBadRequest() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri); - request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType"); - request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode"); - request.setServletPath(this.requestUri); - MockHttpServletResponse response = new MockHttpServletResponse(); - - this.filter.doFilter(request, response, this.filterChain); - - verifyNoInteractions(this.filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}"); + public void doFilterWhenTokenRequestMultipleGrantTypeThenInvalidRequestError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue())); } @Test - public void doFilterWhenAccessTokenRequestWithoutAuthCodeGrantTypeThenRespondWithBadRequest() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri); - request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType"); - request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode"); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri"); - request.setServletPath(this.requestUri); - MockHttpServletResponse response = new MockHttpServletResponse(); - - this.filter.doFilter(request, response, this.filterChain); - - verifyNoInteractions(this.filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"unsupported_grant_type\"}"); + public void doFilterWhenTokenRequestInvalidGrantTypeThenUnsupportedGrantTypeError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, + request -> request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type")); } @Test - public void doFilterWhenAccessTokenRequestIsNotAuthenticatedThenRespondWithUnauthorized() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri); - request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode"); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri"); - request.setServletPath(this.requestUri); - MockHttpServletResponse response = new MockHttpServletResponse(); - Authentication clientPrincipal = mock(Authentication.class); - RegisteredClient registeredClient = mock(RegisteredClient.class); + public void doFilterWhenTokenRequestMultipleClientIdThenInvalidRequestError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, + request -> { + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1"); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"); + }); + } + @Test + public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.CODE)); + } + + @Test + public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.CODE, "code-2")); + } + + @Test + public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception { + doFilterWhenTokenRequestInvalidParameterThenError( + OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com")); + } + + @Test + public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "testToken", Instant.now().minusSeconds(60), Instant.now()); - OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) - .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) - .build(); - OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken = - new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken); - accessTokenAuthenticationToken.setAuthenticated(false); + OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofHours(1)), + new HashSet<>(Arrays.asList("scope1", "scope2"))); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + new OAuth2AccessTokenAuthenticationToken( + registeredClient, clientPrincipal, accessToken); - when(this.authorizationService.findByTokenAndTokenType(anyString(), any(TokenType.class))).thenReturn(authorization); - when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(accessTokenAuthenticationToken); + when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication); - this.filter.doFilter(request, response, this.filterChain); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); - verifyNoInteractions(this.filterChain); - verify(this.authorizationService, times(0)).save(authorization); - verify(this.authenticationManager, times(1)).authenticate(any(Authentication.class)); - assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); - assertThat(response.getContentAsString()) - .isEqualTo("{\"errorCode\":\"invalid_client\"}"); - } - - @Test - public void doFilterWhenValidAccessTokenRequestThenRespondWithAccessToken() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri); - request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode"); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri"); - request.setServletPath(this.requestUri); + MockHttpServletRequest request = createTokenRequest(registeredClient); MockHttpServletResponse response = new MockHttpServletResponse(); - Authentication clientPrincipal = mock(Authentication.class); - RegisteredClient registeredClient = mock(RegisteredClient.class); + FilterChain filterChain = mock(FilterChain.class); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "testToken", Instant.now().minusSeconds(60), Instant.now()); - OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) - .principalName(PRINCIPAL_NAME) - .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) - .build(); - OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken = - new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken); - accessTokenAuthenticationToken.setAuthenticated(true); + this.filter.doFilter(request, response, filterChain); - when(this.authorizationService.findByTokenAndTokenType(anyString(), any(TokenType.class))).thenReturn(authorization); - when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(accessTokenAuthenticationToken); + verifyNoInteractions(filterChain); - this.filter.doFilter(request, response, this.filterChain); + ArgumentCaptor authorizationCodeAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2AuthorizationCodeAuthenticationToken.class); + verify(this.authenticationManager).authenticate(authorizationCodeAuthenticationCaptor.capture()); + + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = + authorizationCodeAuthenticationCaptor.getValue(); + assertThat(authorizationCodeAuthentication.getCode()).isEqualTo( + request.getParameter(OAuth2ParameterNames.CODE)); + assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo( + request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); - verifyNoInteractions(this.filterChain); - verify(this.authorizationService, times(1)).save(authorization); - verify(this.authenticationManager, times(1)).authenticate(any(Authentication.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); - assertThat(response.getContentAsString()).contains("\"tokenValue\":\"testToken\""); - assertThat(response.getContentAsString()).contains("\"tokenType\":{\"value\":\"Bearer\"}"); - assertThat(response.getHeader(HttpHeaders.CACHE_CONTROL)).isEqualTo("no-store"); - assertThat(response.getHeader(HttpHeaders.PRAGMA)).isEqualTo("no-cache"); + OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); + + OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken(); + assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType()); + assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue()); + assertThat(accessTokenResult.getIssuedAt()).isBetween( + accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1)); + assertThat(accessTokenResult.getExpiresAt()).isBetween( + accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1)); + assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); + } + + private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode, + Consumer requestConsumer) throws Exception { + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + + MockHttpServletRequest request = createTokenRequest(registeredClient); + requestConsumer.accept(request); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(errorCode); + assertThat(error.getDescription()).isEqualTo("OAuth 2.0 Parameter: " + parameterName); + } + + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); + } + + private OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); + } + + private static MockHttpServletRequest createTokenRequest(RegisteredClient registeredClient) { + String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); + + String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); + + return request; } }