From fbc98d511c69c628bfe475d5c1cab2504e1b786b Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 19 May 2020 16:55:50 -0400 Subject: [PATCH] Polish gh-77 --- .../InMemoryOAuth2AuthorizationService.java | 3 +- .../authorization/OAuth2Authorization.java | 3 +- .../OAuth2AuthorizationEndpointFilter.java | 382 +++++++++-------- .../OAuth2AuthorizationRequestConverter.java | 55 --- ...MemoryOAuth2AuthorizationServiceTests.java | 7 +- .../OAuth2AuthorizationTests.java | 5 +- .../client/TestRegisteredClients.java | 36 -- ...OAuth2AuthorizationEndpointFilterTest.java | 371 ---------------- ...Auth2AuthorizationEndpointFilterTests.java | 399 ++++++++++++++++++ 9 files changed, 625 insertions(+), 636 deletions(-) delete mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java delete mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 8a13b07..b87c590 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -15,6 +15,7 @@ */ package org.springframework.security.oauth2.server.authorization; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import java.util.List; @@ -65,7 +66,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) { if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { - return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue())); + return token.equals(authorization.getAttributes().get(OAuth2ParameterNames.class.getName().concat(".CODE"))); } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { return authorization.getAccessToken() != null && authorization.getAccessToken().getTokenValue().equals(token); diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index 0b5713d..2391c23 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.util.Assert; @@ -196,7 +197,7 @@ public class OAuth2Authorization implements Serializable { */ public OAuth2Authorization build() { Assert.hasText(this.principalName, "principalName cannot be empty"); - Assert.notNull(this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()), "authorization code cannot be null"); + Assert.notNull(this.attributes.get(OAuth2ParameterNames.class.getName().concat(".CODE")), "authorization code cannot be null"); OAuth2Authorization authorization = new OAuth2Authorization(); authorization.registeredClientId = this.registeredClientId; diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 270e92d..aac79d1 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -15,29 +15,21 @@ */ package org.springframework.security.oauth2.server.authorization.web; -import java.io.IOException; -import java.util.stream.Stream; - -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.web.DefaultRedirectStrategy; @@ -45,201 +37,257 @@ import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponentsBuilder; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + /** + * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, + * which handles the processing of the OAuth 2.0 Authorization Request. + * * @author Joe Grandja * @author Paurav Munshi * @since 0.0.1 + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + * @see OAuth2Authorization + * @see Section 4.1 Authorization Code Grant + * @see Section 4.1.1 Authorization Request */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { + /** + * The default endpoint {@code URI} for authorization requests. + */ + public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize"; - private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; - - private Converter authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); - private RegisteredClientRepository registeredClientRepository; - private OAuth2AuthorizationService authorizationService; - private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(); - private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); - private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT); + private final RegisteredClientRepository registeredClientRepository; + private final OAuth2AuthorizationService authorizationService; + private final RequestMatcher authorizationEndpointMatcher; + private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + /** + * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository, OAuth2AuthorizationService authorizationService) { - Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null."); - Assert.notNull(authorizationService, "authorizationService cannot be null."); + this(registeredClientRepository, authorizationService, DEFAULT_AUTHORIZATION_ENDPOINT_URI); + } + + /** + * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + * @param authorizationEndpointUri the endpoint {@code URI} for authorization requests + */ + public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService, String authorizationEndpointUri) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty"); this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; - } - - public final void setAuthorizationRequestConverter( - Converter authorizationRequestConverter) { - Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null"); - this.authorizationRequestConverter = authorizationRequestConverter; - } - - public final void setCodeGenerator(StringKeyGenerator codeGenerator) { - Assert.notNull(codeGenerator, "codeGenerator cannot be set to null"); - this.codeGenerator = codeGenerator; - } - - public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) { - Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null"); - this.authorizationRedirectStrategy = authorizationRedirectStrategy; - } - - public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) { - Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null"); - this.authorizationEndpointMatcher = authorizationEndpointMatcher; + this.authorizationEndpointMatcher = new AntPathRequestMatcher( + authorizationEndpointUri, HttpMethod.GET.name()); } @Override - protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { - boolean pathMatch = this.authorizationEndpointMatcher.matches(request); - String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); - boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType); - if (pathMatch && responseTypeMatch) { - return false; - }else { - return true; - } - } - - @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - RegisteredClient client = null; - OAuth2AuthorizationRequest authorizationRequest = null; - OAuth2Authorization authorization = null; - - try { - checkUserAuthenticated(); - Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - client = fetchRegisteredClient(request); - - authorizationRequest = this.authorizationRequestConverter.convert(request); - validateAuthorizationRequest(authorizationRequest, client); - - String code = this.codeGenerator.generateKey(); - authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code); - this.authorizationService.save(authorization); - - String redirectUri = getRedirectUri(authorizationRequest, client); - sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code); + if (!this.authorizationEndpointMatcher.matches(request) || !isPrincipalAuthenticated()) { + filterChain.doFilter(request, response); + return; } - catch(OAuth2AuthorizationException authorizationException) { - OAuth2Error authorizationError = authorizationException.getError(); - if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) - || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) { - sendErrorInResponse(response, authorizationError); - } - else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) - || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) { - String redirectUri = getRedirectUri(authorizationRequest, client); - sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri); - } - else { - throw new ServletException(authorizationException); +// TODO +// The authorization server validates the request to ensure that all +// required parameters are present and valid. If the request is valid, +// the authorization server authenticates the resource owner and obtains +// an authorization decision (by asking the resource owner or by +// establishing approval via other means). + + MultiValueMap parameters = getParameters(request); + String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE); + + // client_id (REQUIRED) + String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); + if (!StringUtils.hasText(clientId) || + parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); + sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect + return; + } + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); + if (registeredClient == null) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); + sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect + return; + } else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { + OAuth2Error error = createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID); + sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect + return; + } + + // redirect_uri (OPTIONAL) + String redirectUriParameter = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI); + if (StringUtils.hasText(redirectUriParameter)) { + if (!registeredClient.getRedirectUris().contains(redirectUriParameter) || + parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); + sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect + return; } + } else if (registeredClient.getRedirectUris().size() != 1) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); + sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect + return; } + String redirectUri = StringUtils.hasText(redirectUriParameter) ? + redirectUriParameter : registeredClient.getRedirectUris().iterator().next(); + + // response_type (REQUIRED) + String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE); + if (!StringUtils.hasText(responseType) || + parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) { + OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE); + sendErrorResponse(request, response, error, stateParameter, redirectUri); + return; + } else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) { + OAuth2Error error = createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE); + sendErrorResponse(request, response, error, stateParameter, redirectUri); + return; + } + + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + String code = this.codeGenerator.generateKey(); + OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request); + + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) + .principalName(principal.getName()) + .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), code) + .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest) + .build(); + + this.authorizationService.save(authorization); + +// TODO security checks for code parameter +// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks. +// A maximum authorization code lifetime of 10 minutes is RECOMMENDED. +// The client MUST NOT use the authorization code more than once. +// If an authorization code is used more than once, the authorization server MUST deny the request +// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code. +// The authorization code is bound to the client identifier and redirection URI. + + sendAuthorizationResponse(request, response, authorizationRequest, code, redirectUri); } - private void checkUserAuthenticated() { - Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication(); - if (currentAuth==null || !currentAuth.isAuthenticated()) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); - } - } + private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response, + OAuth2AuthorizationRequest authorizationRequest, String code, String redirectUri) throws IOException { - private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { - String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); - if (StringUtils.isEmpty(clientId)) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); - } - - RegisteredClient client = this.registeredClientRepository.findByClientId(clientId); - if (client==null) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); - } - - boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) - .anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE)); - if (!isAuthorizationGrantAllowed) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); - } - - return client; - - } - - private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client, - OAuth2AuthorizationRequest authorizationRequest, String code) { - OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client) - .principalName(auth.getPrincipal().toString()) - .attribute(TokenType.AUTHORIZATION_CODE.getValue(), code) - .attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes())) - .build(); - - return authorization; - } - - - private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { - String redirectUri = authorizationRequest.getRedirectUri(); - if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); - } - if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); - } - } - - private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { - return !StringUtils.isEmpty(authorizationRequest.getRedirectUri()) - ? authorizationRequest.getRedirectUri() - : client.getRedirectUris().stream().findFirst().get(); - } - - private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response, - OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException { - UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) + UriComponentsBuilder uriBuilder = UriComponentsBuilder + .fromUriString(redirectUri) .queryParam(OAuth2ParameterNames.CODE, code); - if (!StringUtils.isEmpty(authorizationRequest.getState())) { - redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + if (StringUtils.hasText(authorizationRequest.getState())) { + uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); } - - String finalRedirectUri = redirectUriBuilder.toUriString(); - this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri); + this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); } - private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { - int errorStatus = -1; - String errorCode = authorizationError.getErrorCode(); - if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) { - errorStatus=HttpStatus.FORBIDDEN.value(); + private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, + OAuth2Error error, String state, String redirectUri) throws IOException { + + if (redirectUri == null) { + // TODO Send default html error response + response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString()); + return; } - else { - errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); + + UriComponentsBuilder uriBuilder = UriComponentsBuilder + .fromUriString(redirectUri) + .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode()); + if (StringUtils.hasText(error.getDescription())) { + uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription()); } - response.sendError(errorStatus, authorizationError.getErrorCode()); + if (StringUtils.hasText(error.getUri())) { + uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri()); + } + if (StringUtils.hasText(state)) { + uriBuilder.queryParam(OAuth2ParameterNames.STATE, state); + } + this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); } - private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, - OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError, - String redirectUri) throws IOException { - UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) - .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode()); + private static boolean isPrincipalAuthenticated() { + return isPrincipalAuthenticated(SecurityContextHolder.getContext().getAuthentication()); + } - if (!StringUtils.isEmpty(authorizationRequest.getState())) { - redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + private static boolean isPrincipalAuthenticated(Authentication principal) { + return principal != null && + !AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) && + principal.isAuthenticated(); + } + + private static OAuth2Error createError(String errorCode, String parameterName) { + return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, + "https://tools.ietf.org/html/rfc6749#section-4.1.2.1"); + } + + private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) { + MultiValueMap parameters = getParameters(request); + + Set scopes = Collections.emptySet(); + if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) { + String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); + scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); } - String finalRedirectURI = redirectUriBuilder.toUriString(); - this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); + return OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(request.getRequestURL().toString()) + .clientId(parameters.getFirst(OAuth2ParameterNames.CLIENT_ID)) + .redirectUri(parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)) + .scopes(scopes) + .state(parameters.getFirst(OAuth2ParameterNames.STATE)) + .additionalParameters(additionalParameters -> + parameters.entrySet().stream() + .filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) && + !e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) && + !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) && + !e.getKey().equals(OAuth2ParameterNames.SCOPE) && + !e.getKey().equals(OAuth2ParameterNames.STATE)) + .forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0)))) + .build(); + } + + private static MultiValueMap getParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); + parameterMap.forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; } } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java deleted file mode 100644 index 619a742..0000000 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.web; - -import java.util.Arrays; -import java.util.Collections; -import java.util.LinkedHashSet; -import java.util.Set; - -import javax.servlet.http.HttpServletRequest; - -import org.springframework.core.convert.converter.Converter; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.StringUtils; - -/** - * @author Paurav Munshi - * @since 0.0.1 - * @see Converter - */ -public class OAuth2AuthorizationRequestConverter implements Converter { - - @Override - public OAuth2AuthorizationRequest convert(HttpServletRequest request) { - String scope = request.getParameter(OAuth2ParameterNames.SCOPE); - Set scopes = !StringUtils.isEmpty(scope) - ? new LinkedHashSet(Arrays.asList(scope.split(" "))) - : Collections.emptySet(); - - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID)) - .redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)) - .scopes(scopes) - .state(request.getParameter(OAuth2ParameterNames.STATE)) - .authorizationUri(request.getServletPath()) - .build(); - - return authorizationRequest; - } - -} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index d1fb762..f4bc5bd 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization; import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -61,7 +62,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void saveWhenAuthorizationProvidedThenSaved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) + .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE) .build(); this.authorizationService.save(expectedAuthorization); @@ -88,7 +89,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) + .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE) .build(); this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); @@ -103,7 +104,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { "access-token", Instant.now().minusSeconds(60), Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) + .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE) .accessToken(accessToken) .build(); this.authorizationService.save(authorization); diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index a35efcc..dfd3d90 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization; import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -84,13 +85,13 @@ public class OAuth2AuthorizationTests { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) .accessToken(ACCESS_TOKEN) - .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE) + .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE) .build(); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN); assertThat(authorization.getAttributes()).containsExactly( - entry(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)); + entry(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)); } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index 502fa48..5aa7bc0 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -46,40 +46,4 @@ public class TestRegisteredClients { .scope("profile") .scope("email"); } - - public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() { - return RegisteredClient.withId("valid_client_id") - .clientId("valid_client") - .clientSecret("valid_secret") - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .redirectUri("http://localhost:8080/test-application/callback") - .scope("openid") - .scope("profile") - .scope("email"); - } - - public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() { - return RegisteredClient.withId("valid_client_multi_uri_id") - .clientId("valid_client_multi_uri") - .clientSecret("valid_secret") - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .redirectUri("http://localhost:8080/test-application/callback") - .redirectUri("http://localhost:8080/another-test-application/callback") - .scope("openid") - .scope("profile") - .scope("email"); - } - - public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() { - return RegisteredClient.withId("valid_cc_client_id") - .clientId("valid_cc_client") - .clientSecret("valid_secret") - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .scope("openid") - .scope("profile") - .scope("email"); - } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java deleted file mode 100644 index 4b2e86a..0000000 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ /dev/null @@ -1,371 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.web; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.http.HttpStatus; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.crypto.keygen.StringKeyGenerator; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; - - -/** - * Tests for {@link OAuth2AuthorizationEndpointFilter}. - * - * @author Paurav Munshi - * @since 0.0.1 - */ - -public class OAuth2AuthorizationEndpointFilterTest { - - private static final String VALID_CLIENT = "valid_client"; - private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri"; - private static final String VALID_CC_CLIENT = "valid_cc_client"; - - private OAuth2AuthorizationEndpointFilter filter; - - private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); - private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); - private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); - private Authentication authentication = mock(Authentication.class); - - @Before - public void setUp() { - this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService); - this.filter.setCodeGenerator(this.codeGenerator); - - SecurityContextHolder.getContext().setAuthentication(this.authentication); - } - - @Test - public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { - assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { - assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { - assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { - assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { - assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { - assertThatThrownBy(() ->this.filter.setCodeGenerator(null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(this.codeGenerator.generateKey()).thenReturn("sample_code"); - when(this.authentication.getPrincipal()).thenReturn("test-user"); - when(this.authentication.isAuthenticated()).thenReturn(true); - - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); - verify(this.authorizationService).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); - - } - - @Test - public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(this.codeGenerator.generateKey()).thenReturn("sample_code"); - when(this.authentication.getPrincipal()).thenReturn("test-user"); - when(this.authentication.isAuthenticated()).thenReturn(true); - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); - verify(this.authorizationService).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); - - } - - @Test - public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI); - request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build(); - when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); - when(this.authentication.isAuthenticated()).thenReturn(true); - - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication, times(1)).isAuthenticated(); - verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - - } - - @Test - public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(this.authentication.isAuthenticated()).thenReturn(true); - - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication, times(1)).isAuthenticated(); - verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - - } - - @Test - public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build(); - when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); - when(this.authentication.isAuthenticated()).thenReturn(true); - - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication, times(1)).isAuthenticated(); - verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); - - } - - @Test - public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.CLIENT_ID, ""); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - when(this.authentication.isAuthenticated()).thenReturn(true); - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); - assertThat(response.getContentAsString()).isEmpty(); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - - } - - @Test - public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); - when(this.codeGenerator.generateKey()).thenReturn("sample_code"); - when(this.authentication.isAuthenticated()).thenReturn(true); - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client"); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); - assertThat(response.getContentAsString()).isEmpty(); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); - - } - - @Test - public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - when(authentication.isAuthenticated()).thenReturn(false); - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); - assertThat(response.getContentAsString()).isEmpty(); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); - - } - - @Test - public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setServletPath("/custom/authorize"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); - spyFilter.doFilter(request, response, filterChain); - - verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); - verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); - } - - @Test - public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, ""); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); - spyFilter.doFilter(request, response, filterChain); - - verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); - verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); - verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - - @Test - public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception { - MockHttpServletRequest request = getValidMockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); - spyFilter.doFilter(request, response, filterChain); - - verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); - verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); - verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - - private MockHttpServletRequest getValidMockHttpServletRequest() { - - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT); - request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code"); - request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email"); - request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback"); - request.setParameter(OAuth2ParameterNames.STATE, "teststate"); - request.setServletPath("/oauth2/authorize"); - - return request; - - - } - -} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java new file mode 100644 index 0000000..af95436 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -0,0 +1,399 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.util.StringUtils; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2AuthorizationEndpointFilter}. + * + * @author Paurav Munshi + * @author Joe Grandja + * @since 0.0.1 + */ +public class OAuth2AuthorizationEndpointFilterTests { + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private OAuth2AuthorizationEndpointFilter filter; + private TestingAuthenticationToken authentication; + + @Before + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService); + this.authentication = new TestingAuthenticationToken("principalName", "password"); + this.authentication.setAuthenticated(true); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(this.authentication); + SecurityContextHolder.setContext(securityContext); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationEndpointUri cannot be empty"); + } + + @Test + public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception { + String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception { + String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.authentication.setAuthenticated(false); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.removeParameter(OAuth2ParameterNames.CLIENT_ID); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); + } + + @Test + public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); + } + + @Test + public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); + } + + @Test + public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeThenUnauthorizedClientError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantTypes(Set::clear) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id"); + } + + @Test + public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); + } + + @Test + public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); + } + + @Test + public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegisteredThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUri("https://example2.com").build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.removeParameter(OAuth2ParameterNames.REDIRECT_URI); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); + } + + @Test + public void doFilterWhenAuthorizationRequestMissingResponseTypeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20response_type&" + + "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + + "state=state"); + } + + @Test + public void doFilterWhenAuthorizationRequestMultipleResponseTypeThenInvalidRequestError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=invalid_request&" + + "error_description=OAuth%202.0%20Parameter:%20response_type&" + + "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + + "state=state"); + } + + @Test + public void doFilterWhenAuthorizationRequestInvalidResponseTypeThenUnsupportedResponseTypeError() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + + "error=unsupported_response_type&" + + "error_description=OAuth%202.0%20Parameter:%20response_type&" + + "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + + "state=state"); + } + + @Test + public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) + .thenReturn(registeredClient); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization authorization = authorizationCaptor.getValue(); + assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); + + String code = authorization.getAttribute(OAuth2ParameterNames.class.getName().concat(".CODE")); + assertThat(code).isNotNull(); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + assertThat(authorizationRequest).isNotNull(); + assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize"); + assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); + assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next()); + assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes()); + assertThat(authorizationRequest.getState()).isEqualTo("state"); + assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); + } + + private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { + String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); + + String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); + request.addParameter(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + return request; + } +}