2020-04-24 03:45:34 +07:00
/*
 * Copyright 2020-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.security.oauth2.server.authorization.web ;
2021-01-18 21:31:06 +07:00
import java.io.IOException ;
import java.nio.charset.StandardCharsets ;
import java.time.Instant ;
import java.time.temporal.ChronoUnit ;
import java.util.Arrays ;
import java.util.Base64 ;
import java.util.Collections ;
import java.util.HashSet ;
import java.util.List ;
import java.util.Set ;
import javax.servlet.FilterChain ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
2020-05-20 03:55:50 +07:00
import org.springframework.http.HttpMethod ;
2020-04-30 10:29:20 +07:00
import org.springframework.http.HttpStatus ;
2020-09-22 22:57:50 +07:00
import org.springframework.http.MediaType ;
2020-05-20 03:55:50 +07:00
import org.springframework.security.authentication.AnonymousAuthenticationToken ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.context.SecurityContextHolder ;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator ;
2020-04-24 03:45:34 +07:00
import org.springframework.security.crypto.keygen.StringKeyGenerator ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.oauth2.core.AuthorizationGrantType ;
import org.springframework.security.oauth2.core.OAuth2Error ;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes ;
2020-04-24 03:45:34 +07:00
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest ;
2020-05-20 03:55:50 +07:00
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
2020-06-23 02:35:01 +07:00
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames ;
2020-12-04 22:02:07 +07:00
import org.springframework.security.oauth2.core.oidc.OidcScopes ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization ;
2020-05-21 16:46:59 +07:00
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames ;
2020-04-24 03:45:34 +07:00
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService ;
2020-09-22 22:57:50 +07:00
import org.springframework.security.oauth2.server.authorization.TokenType ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient ;
2020-04-24 03:45:34 +07:00
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository ;
2020-10-23 01:03:24 +07:00
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.web.DefaultRedirectStrategy ;
import org.springframework.security.web.RedirectStrategy ;
2020-12-04 22:02:07 +07:00
import org.springframework.security.web.util.matcher.AndRequestMatcher ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.web.util.matcher.AntPathRequestMatcher ;
2020-12-04 22:02:07 +07:00
import org.springframework.security.web.util.matcher.NegatedRequestMatcher ;
import org.springframework.security.web.util.matcher.OrRequestMatcher ;
2020-04-30 10:29:20 +07:00
import org.springframework.security.web.util.matcher.RequestMatcher ;
import org.springframework.util.Assert ;
2020-09-22 22:57:50 +07:00
import org.springframework.util.CollectionUtils ;
2020-05-20 03:55:50 +07:00
import org.springframework.util.MultiValueMap ;
2020-04-30 10:29:20 +07:00
import org.springframework.util.StringUtils ;
2020-04-24 03:45:34 +07:00
import org.springframework.web.filter.OncePerRequestFilter ;
2020-04-30 10:29:20 +07:00
import org.springframework.web.util.UriComponentsBuilder ;
2020-04-24 03:45:34 +07:00
/**
 * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
 * which handles the processing of the OAuth 2.0 Authorization Request
 * and of the user consent form submitted in response to it.
 *
 * @author Joe Grandja
 * @author Paurav Munshi
 * @author Daniel Garnier-Moiroux
 * @since 0.0.1
 * @see RegisteredClientRepository
 * @see OAuth2AuthorizationService
 * @see OAuth2Authorization
 * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
 * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Request</a>
 * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
 */
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
2020-05-20 03:55:50 +07:00
/ * *
* The default endpoint { @code URI } for authorization requests .
* /
public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = " /oauth2/authorize " ;
2020-10-01 02:20:38 +07:00
private static final String PKCE_ERROR_URI = " https://tools.ietf.org/html/rfc7636#section-4.4.1 " ;
2020-05-20 03:55:50 +07:00
private final RegisteredClientRepository registeredClientRepository ;
private final OAuth2AuthorizationService authorizationService ;
2020-09-22 22:57:50 +07:00
private final RequestMatcher authorizationRequestMatcher ;
private final RequestMatcher userConsentMatcher ;
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator ( Base64 . getUrlEncoder ( ) . withoutPadding ( ) , 96 ) ;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator ( Base64 . getUrlEncoder ( ) ) ;
2020-05-20 03:55:50 +07:00
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy ( ) ;
/**
 * Creates an {@code OAuth2AuthorizationEndpointFilter} that serves authorization requests at
 * {@link #DEFAULT_AUTHORIZATION_ENDPOINT_URI}.
 *
 * @param registeredClientRepository the repository of registered clients
 * @param authorizationService the authorization service
 */
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
		OAuth2AuthorizationService authorizationService) {
	// Delegate so argument validation and matcher construction live in one place.
	this(registeredClientRepository, authorizationService, DEFAULT_AUTHORIZATION_ENDPOINT_URI);
}
2020-05-20 03:55:50 +07:00
/**
 * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
 * <p>
 * Requests to {@code authorizationEndpointUri} are dispatched as follows:
 * <ul>
 * <li>{@code GET} requests are treated as authorization requests;</li>
 * <li>{@code POST} requests whose {@code scope} parameter contains the {@code openid} scope token,
 * and which do not carry the consent action parameter, are also treated as authorization requests;</li>
 * <li>{@code POST} requests that carry the consent action parameter are treated as user consent submissions.</li>
 * </ul>
 *
 * @param registeredClientRepository the repository of registered clients
 * @param authorizationService the authorization service
 * @param authorizationEndpointUri the endpoint {@code URI} for authorization requests
 * @throws IllegalArgumentException if a collaborator is {@code null} or {@code authorizationEndpointUri} is empty
 */
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
		OAuth2AuthorizationService authorizationService, String authorizationEndpointUri) {
	Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
	Assert.notNull(authorizationService, "authorizationService cannot be null");
	Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty");
	this.registeredClientRepository = registeredClientRepository;
	this.authorizationService = authorizationService;

	RequestMatcher authorizationRequestGetMatcher = new AntPathRequestMatcher(
			authorizationEndpointUri, HttpMethod.GET.name());
	RequestMatcher authorizationRequestPostMatcher = new AntPathRequestMatcher(
			authorizationEndpointUri, HttpMethod.POST.name());
	// 'scope' is a space-delimited list (RFC 6749, Section 3.3): match 'openid' as a whole token,
	// the same way the scopes are split later on. A substring test would also match unrelated
	// scope tokens that merely contain "openid".
	RequestMatcher openidScopeMatcher = request -> {
		String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
		return StringUtils.hasText(scope) &&
				Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")).contains(OidcScopes.OPENID);
	};
	RequestMatcher consentActionMatcher = request ->
			request.getParameter(UserConsentPage.CONSENT_ACTION_PARAMETER_NAME) != null;
	this.authorizationRequestMatcher = new OrRequestMatcher(
			authorizationRequestGetMatcher,
			new AndRequestMatcher(
					authorizationRequestPostMatcher, openidScopeMatcher,
					new NegatedRequestMatcher(consentActionMatcher)));
	this.userConsentMatcher = new AndRequestMatcher(
			authorizationRequestPostMatcher, consentActionMatcher);
}
2020-04-24 03:45:34 +07:00
@Override
2020-05-20 03:55:50 +07:00
protected void doFilterInternal ( HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
2020-04-24 03:45:34 +07:00
throws ServletException , IOException {
2020-09-22 22:57:50 +07:00
if ( this . authorizationRequestMatcher . matches ( request ) ) {
processAuthorizationRequest ( request , response , filterChain ) ;
} else if ( this . userConsentMatcher . matches ( request ) ) {
processUserConsent ( request , response ) ;
} else {
2020-05-20 03:55:50 +07:00
filterChain . doFilter ( request , response ) ;
2020-09-22 22:57:50 +07:00
}
}
private void processAuthorizationRequest ( HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
throws ServletException , IOException {
OAuth2AuthorizationRequestContext authorizationRequestContext =
new OAuth2AuthorizationRequestContext (
request . getRequestURL ( ) . toString ( ) ,
OAuth2EndpointUtils . getParameters ( request ) ) ;
validateAuthorizationRequest ( authorizationRequestContext ) ;
if ( authorizationRequestContext . hasError ( ) ) {
if ( authorizationRequestContext . isRedirectOnError ( ) ) {
sendErrorResponse ( request , response , authorizationRequestContext . resolveRedirectUri ( ) ,
authorizationRequestContext . getError ( ) , authorizationRequestContext . getState ( ) ) ;
} else {
sendErrorResponse ( response , authorizationRequestContext . getError ( ) ) ;
}
2020-05-20 03:55:50 +07:00
return ;
}
2020-04-30 10:29:20 +07:00
2020-05-24 17:07:34 +07:00
// ---------------
2020-09-22 22:57:50 +07:00
// The request is valid - ensure the resource owner is authenticated
2020-05-24 17:07:34 +07:00
// ---------------
2020-05-20 03:55:50 +07:00
2020-09-22 22:57:50 +07:00
Authentication principal = SecurityContextHolder . getContext ( ) . getAuthentication ( ) ;
if ( ! isPrincipalAuthenticated ( principal ) ) {
// Pass through the chain with the expectation that the authentication process
// will commence via AuthenticationEntryPoint
filterChain . doFilter ( request , response ) ;
return ;
}
RegisteredClient registeredClient = authorizationRequestContext . getRegisteredClient ( ) ;
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext . buildAuthorizationRequest ( ) ;
OAuth2Authorization . Builder builder = OAuth2Authorization . withRegisteredClient ( registeredClient )
. principalName ( principal . getName ( ) )
2021-02-09 02:57:15 +07:00
. authorizationGrantType ( AuthorizationGrantType . AUTHORIZATION_CODE )
2021-01-18 21:31:06 +07:00
. attribute ( OAuth2AuthorizationAttributeNames . PRINCIPAL , principal )
2020-09-22 22:57:50 +07:00
. attribute ( OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST , authorizationRequest ) ;
if ( registeredClient . getClientSettings ( ) . requireUserConsent ( ) ) {
String state = this . stateGenerator . generateKey ( ) ;
OAuth2Authorization authorization = builder
. attribute ( OAuth2AuthorizationAttributeNames . STATE , state )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
// TODO Need to remove 'in-flight' authorization if consent step is not completed (e.g. approved or cancelled)
UserConsentPage . displayConsent ( request , response , registeredClient , authorization ) ;
} else {
2020-10-23 01:03:24 +07:00
Instant issuedAt = Instant . now ( ) ;
Instant expiresAt = issuedAt . plus ( 5 , ChronoUnit . MINUTES ) ; // TODO Allow configuration for authorization code time-to-live
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode (
this . codeGenerator . generateKey ( ) , issuedAt , expiresAt ) ;
2020-09-22 22:57:50 +07:00
OAuth2Authorization authorization = builder
2021-02-06 01:20:17 +07:00
. token ( authorizationCode )
2020-09-22 22:57:50 +07:00
. attribute ( OAuth2AuthorizationAttributeNames . AUTHORIZED_SCOPES , authorizationRequest . getScopes ( ) )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
// TODO security checks for code parameter
// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
// A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
// The client MUST NOT use the authorization code more than once.
// If an authorization code is used more than once, the authorization server MUST deny the request
// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
// The authorization code is bound to the client identifier and redirection URI.
sendAuthorizationResponse ( request , response ,
2020-10-23 01:03:24 +07:00
authorizationRequestContext . resolveRedirectUri ( ) , authorizationCode , authorizationRequest . getState ( ) ) ;
2020-09-22 22:57:50 +07:00
}
}
private void processUserConsent ( HttpServletRequest request , HttpServletResponse response )
throws IOException {
UserConsentRequestContext userConsentRequestContext =
new UserConsentRequestContext (
request . getRequestURL ( ) . toString ( ) ,
OAuth2EndpointUtils . getParameters ( request ) ) ;
validateUserConsentRequest ( userConsentRequestContext ) ;
if ( userConsentRequestContext . hasError ( ) ) {
if ( userConsentRequestContext . isRedirectOnError ( ) ) {
sendErrorResponse ( request , response , userConsentRequestContext . resolveRedirectUri ( ) ,
userConsentRequestContext . getError ( ) , userConsentRequestContext . getState ( ) ) ;
} else {
sendErrorResponse ( response , userConsentRequestContext . getError ( ) ) ;
}
return ;
}
if ( ! UserConsentPage . isConsentApproved ( request ) ) {
this . authorizationService . remove ( userConsentRequestContext . getAuthorization ( ) ) ;
OAuth2Error error = createError ( OAuth2ErrorCodes . ACCESS_DENIED , OAuth2ParameterNames . CLIENT_ID ) ;
sendErrorResponse ( request , response , userConsentRequestContext . resolveRedirectUri ( ) ,
error , userConsentRequestContext . getAuthorizationRequest ( ) . getState ( ) ) ;
return ;
}
2020-10-23 01:03:24 +07:00
Instant issuedAt = Instant . now ( ) ;
Instant expiresAt = issuedAt . plus ( 5 , ChronoUnit . MINUTES ) ; // TODO Allow configuration for authorization code time-to-live
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode (
this . codeGenerator . generateKey ( ) , issuedAt , expiresAt ) ;
2020-09-22 22:57:50 +07:00
OAuth2Authorization authorization = OAuth2Authorization . from ( userConsentRequestContext . getAuthorization ( ) )
2021-02-06 01:20:17 +07:00
. token ( authorizationCode )
2020-09-22 22:57:50 +07:00
. attributes ( attrs - > {
attrs . remove ( OAuth2AuthorizationAttributeNames . STATE ) ;
attrs . put ( OAuth2AuthorizationAttributeNames . AUTHORIZED_SCOPES , userConsentRequestContext . getScopes ( ) ) ;
} )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
sendAuthorizationResponse ( request , response , userConsentRequestContext . resolveRedirectUri ( ) ,
2020-10-23 01:03:24 +07:00
authorizationCode , userConsentRequestContext . getAuthorizationRequest ( ) . getState ( ) ) ;
2020-09-22 22:57:50 +07:00
}
private void validateAuthorizationRequest ( OAuth2AuthorizationRequestContext authorizationRequestContext ) {
// ---------------
// Validate the request to ensure all required parameters are present and valid
// ---------------
2020-05-20 03:55:50 +07:00
// client_id (REQUIRED)
2020-09-22 22:57:50 +07:00
if ( ! StringUtils . hasText ( authorizationRequestContext . getClientId ( ) ) | |
authorizationRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . CLIENT_ID ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
2020-05-20 03:55:50 +07:00
return ;
}
2020-09-22 22:57:50 +07:00
RegisteredClient registeredClient = this . registeredClientRepository . findByClientId (
authorizationRequestContext . getClientId ( ) ) ;
2020-05-20 03:55:50 +07:00
if ( registeredClient = = null ) {
2020-09-22 22:57:50 +07:00
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
2020-05-20 03:55:50 +07:00
return ;
} else if ( ! registeredClient . getAuthorizationGrantTypes ( ) . contains ( AuthorizationGrantType . AUTHORIZATION_CODE ) ) {
2020-09-22 22:57:50 +07:00
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . UNAUTHORIZED_CLIENT , OAuth2ParameterNames . CLIENT_ID ) ) ;
2020-05-20 03:55:50 +07:00
return ;
2020-04-30 10:29:20 +07:00
}
2020-09-22 22:57:50 +07:00
authorizationRequestContext . setRegisteredClient ( registeredClient ) ;
2020-04-30 10:29:20 +07:00
2020-05-20 03:55:50 +07:00
// redirect_uri (OPTIONAL)
2020-09-22 22:57:50 +07:00
if ( StringUtils . hasText ( authorizationRequestContext . getRedirectUri ( ) ) ) {
if ( ! registeredClient . getRedirectUris ( ) . contains ( authorizationRequestContext . getRedirectUri ( ) ) | |
authorizationRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . REDIRECT_URI ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . REDIRECT_URI ) ) ;
2020-05-20 03:55:50 +07:00
return ;
2020-04-30 10:29:20 +07:00
}
2020-12-04 22:02:07 +07:00
} else if ( authorizationRequestContext . isAuthenticationRequest ( ) | | // redirect_uri is REQUIRED for OpenID Connect
registeredClient . getRedirectUris ( ) . size ( ) ! = 1 ) {
2020-09-22 22:57:50 +07:00
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . REDIRECT_URI ) ) ;
2020-05-20 03:55:50 +07:00
return ;
2020-04-30 10:29:20 +07:00
}
2020-09-22 22:57:50 +07:00
authorizationRequestContext . setRedirectOnError ( true ) ;
2020-05-20 03:55:50 +07:00
// response_type (REQUIRED)
2020-09-22 22:57:50 +07:00
if ( ! StringUtils . hasText ( authorizationRequestContext . getResponseType ( ) ) | |
authorizationRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . RESPONSE_TYPE ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . RESPONSE_TYPE ) ) ;
2020-05-20 03:55:50 +07:00
return ;
2020-09-22 22:57:50 +07:00
} else if ( ! authorizationRequestContext . getResponseType ( ) . equals ( OAuth2AuthorizationResponseType . CODE . getValue ( ) ) ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . UNSUPPORTED_RESPONSE_TYPE , OAuth2ParameterNames . RESPONSE_TYPE ) ) ;
return ;
}
// scope (OPTIONAL)
Set < String > requestedScopes = authorizationRequestContext . getScopes ( ) ;
Set < String > allowedScopes = registeredClient . getScopes ( ) ;
if ( ! requestedScopes . isEmpty ( ) & & ! allowedScopes . containsAll ( requestedScopes ) ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_SCOPE , OAuth2ParameterNames . SCOPE ) ) ;
2020-05-20 03:55:50 +07:00
return ;
2020-04-30 10:29:20 +07:00
}
2020-04-24 03:45:34 +07:00
2020-06-23 02:35:01 +07:00
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
2020-09-22 22:57:50 +07:00
String codeChallenge = authorizationRequestContext . getParameters ( ) . getFirst ( PkceParameterNames . CODE_CHALLENGE ) ;
2020-06-23 02:35:01 +07:00
if ( StringUtils . hasText ( codeChallenge ) ) {
2020-09-22 22:57:50 +07:00
if ( authorizationRequestContext . getParameters ( ) . get ( PkceParameterNames . CODE_CHALLENGE ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE , PKCE_ERROR_URI ) ) ;
2020-06-23 02:35:01 +07:00
return ;
}
2020-09-22 22:57:50 +07:00
String codeChallengeMethod = authorizationRequestContext . getParameters ( ) . getFirst ( PkceParameterNames . CODE_CHALLENGE_METHOD ) ;
if ( StringUtils . hasText ( codeChallengeMethod ) ) {
if ( authorizationRequestContext . getParameters ( ) . get ( PkceParameterNames . CODE_CHALLENGE_METHOD ) . size ( ) ! = 1 | |
( ! " S256 " . equals ( codeChallengeMethod ) & & ! " plain " . equals ( codeChallengeMethod ) ) ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE_METHOD , PKCE_ERROR_URI ) ) ;
return ;
}
2020-06-23 02:35:01 +07:00
}
} else if ( registeredClient . getClientSettings ( ) . requireProofKey ( ) ) {
2020-09-22 22:57:50 +07:00
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE , PKCE_ERROR_URI ) ) ;
2020-06-23 02:35:01 +07:00
return ;
}
2020-09-22 22:57:50 +07:00
}
2020-06-23 02:35:01 +07:00
2020-09-22 22:57:50 +07:00
private void validateUserConsentRequest ( UserConsentRequestContext userConsentRequestContext ) {
2020-05-24 17:07:34 +07:00
// ---------------
2020-09-22 22:57:50 +07:00
// Validate the request to ensure all required parameters are present and valid
2020-05-24 17:07:34 +07:00
// ---------------
2020-09-22 22:57:50 +07:00
// state (REQUIRED)
if ( ! StringUtils . hasText ( userConsentRequestContext . getState ( ) ) | |
userConsentRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . STATE ) . size ( ) ! = 1 ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . STATE ) ) ;
2020-05-24 17:07:34 +07:00
return ;
}
2020-09-22 22:57:50 +07:00
OAuth2Authorization authorization = this . authorizationService . findByToken (
userConsentRequestContext . getState ( ) , new TokenType ( OAuth2AuthorizationAttributeNames . STATE ) ) ;
if ( authorization = = null ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . STATE ) ) ;
return ;
}
userConsentRequestContext . setAuthorization ( authorization ) ;
2020-05-24 17:07:34 +07:00
2020-09-22 22:57:50 +07:00
// The 'in-flight' authorization must be associated to the current principal
Authentication principal = SecurityContextHolder . getContext ( ) . getAuthentication ( ) ;
if ( ! isPrincipalAuthenticated ( principal ) | | ! principal . getName ( ) . equals ( authorization . getPrincipalName ( ) ) ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . STATE ) ) ;
return ;
}
2020-04-30 10:29:20 +07:00
2020-09-22 22:57:50 +07:00
// client_id (REQUIRED)
if ( ! StringUtils . hasText ( userConsentRequestContext . getClientId ( ) ) | |
userConsentRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . CLIENT_ID ) . size ( ) ! = 1 ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
}
RegisteredClient registeredClient = this . registeredClientRepository . findByClientId (
userConsentRequestContext . getClientId ( ) ) ;
if ( registeredClient = = null | | ! registeredClient . getId ( ) . equals ( authorization . getRegisteredClientId ( ) ) ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
}
userConsentRequestContext . setRegisteredClient ( registeredClient ) ;
userConsentRequestContext . setRedirectOnError ( true ) ;
// scope (OPTIONAL)
Set < String > requestedScopes = userConsentRequestContext . getAuthorizationRequest ( ) . getScopes ( ) ;
Set < String > authorizedScopes = userConsentRequestContext . getScopes ( ) ;
if ( ! authorizedScopes . isEmpty ( ) & & ! requestedScopes . containsAll ( authorizedScopes ) ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_SCOPE , OAuth2ParameterNames . SCOPE ) ) ;
return ;
}
2020-04-30 10:29:20 +07:00
}
2020-05-20 03:55:50 +07:00
private void sendAuthorizationResponse ( HttpServletRequest request , HttpServletResponse response ,
2020-10-23 01:03:24 +07:00
String redirectUri , OAuth2AuthorizationCode authorizationCode , String state ) throws IOException {
2020-04-30 10:29:20 +07:00
2020-05-20 03:55:50 +07:00
UriComponentsBuilder uriBuilder = UriComponentsBuilder
. fromUriString ( redirectUri )
2020-10-23 01:03:24 +07:00
. queryParam ( OAuth2ParameterNames . CODE , authorizationCode . getTokenValue ( ) ) ;
2020-09-22 22:57:50 +07:00
if ( StringUtils . hasText ( state ) ) {
uriBuilder . queryParam ( OAuth2ParameterNames . STATE , state ) ;
2020-05-20 03:55:50 +07:00
}
this . redirectStrategy . sendRedirect ( request , response , uriBuilder . toUriString ( ) ) ;
2020-04-30 10:29:20 +07:00
}
2020-05-20 03:55:50 +07:00
private void sendErrorResponse ( HttpServletRequest request , HttpServletResponse response ,
2020-09-22 22:57:50 +07:00
String redirectUri , OAuth2Error error , String state ) throws IOException {
2020-05-20 03:55:50 +07:00
UriComponentsBuilder uriBuilder = UriComponentsBuilder
. fromUriString ( redirectUri )
. queryParam ( OAuth2ParameterNames . ERROR , error . getErrorCode ( ) ) ;
if ( StringUtils . hasText ( error . getDescription ( ) ) ) {
uriBuilder . queryParam ( OAuth2ParameterNames . ERROR_DESCRIPTION , error . getDescription ( ) ) ;
2020-04-30 10:29:20 +07:00
}
2020-05-20 03:55:50 +07:00
if ( StringUtils . hasText ( error . getUri ( ) ) ) {
uriBuilder . queryParam ( OAuth2ParameterNames . ERROR_URI , error . getUri ( ) ) ;
}
if ( StringUtils . hasText ( state ) ) {
uriBuilder . queryParam ( OAuth2ParameterNames . STATE , state ) ;
}
this . redirectStrategy . sendRedirect ( request , response , uriBuilder . toUriString ( ) ) ;
2020-04-30 10:29:20 +07:00
}
2020-09-22 22:57:50 +07:00
private void sendErrorResponse ( HttpServletResponse response , OAuth2Error error ) throws IOException {
// TODO Send default html error response
response . sendError ( HttpStatus . BAD_REQUEST . value ( ) , error . toString ( ) ) ;
}
2020-05-24 17:07:34 +07:00
private static OAuth2Error createError ( String errorCode , String parameterName ) {
2020-06-23 02:35:01 +07:00
return createError ( errorCode , parameterName , " https://tools.ietf.org/html/rfc6749#section-4.1.2.1 " ) ;
}
private static OAuth2Error createError ( String errorCode , String parameterName , String errorUri ) {
return new OAuth2Error ( errorCode , " OAuth 2.0 Parameter: " + parameterName , errorUri ) ;
2020-04-30 10:29:20 +07:00
}
2020-05-20 03:55:50 +07:00
private static boolean isPrincipalAuthenticated ( Authentication principal ) {
return principal ! = null & &
! AnonymousAuthenticationToken . class . isAssignableFrom ( principal . getClass ( ) ) & &
principal . isAuthenticated ( ) ;
2020-04-30 10:29:20 +07:00
}
2020-09-22 22:57:50 +07:00
private static class OAuth2AuthorizationRequestContext extends AbstractRequestContext {
private final String responseType ;
private final String redirectUri ;
private OAuth2AuthorizationRequestContext (
String authorizationUri , MultiValueMap < String , String > parameters ) {
super ( authorizationUri , parameters ,
parameters . getFirst ( OAuth2ParameterNames . CLIENT_ID ) ,
parameters . getFirst ( OAuth2ParameterNames . STATE ) ,
extractScopes ( parameters ) ) ;
this . responseType = parameters . getFirst ( OAuth2ParameterNames . RESPONSE_TYPE ) ;
this . redirectUri = parameters . getFirst ( OAuth2ParameterNames . REDIRECT_URI ) ;
}
2020-04-30 10:29:20 +07:00
2020-09-22 22:57:50 +07:00
private static Set < String > extractScopes ( MultiValueMap < String , String > parameters ) {
2020-05-20 03:55:50 +07:00
String scope = parameters . getFirst ( OAuth2ParameterNames . SCOPE ) ;
2020-09-22 22:57:50 +07:00
return StringUtils . hasText ( scope ) ?
new HashSet < > ( Arrays . asList ( StringUtils . delimitedListToStringArray ( scope , " " ) ) ) :
Collections . emptySet ( ) ;
}
private String getResponseType ( ) {
return this . responseType ;
}
private String getRedirectUri ( ) {
return this . redirectUri ;
}
2020-12-04 22:02:07 +07:00
private boolean isAuthenticationRequest ( ) {
return getScopes ( ) . contains ( OidcScopes . OPENID ) ;
}
2020-09-22 22:57:50 +07:00
protected String resolveRedirectUri ( ) {
return StringUtils . hasText ( getRedirectUri ( ) ) ?
getRedirectUri ( ) :
getRegisteredClient ( ) . getRedirectUris ( ) . iterator ( ) . next ( ) ;
}
private OAuth2AuthorizationRequest buildAuthorizationRequest ( ) {
return OAuth2AuthorizationRequest . authorizationCode ( )
. authorizationUri ( getAuthorizationUri ( ) )
. clientId ( getClientId ( ) )
. redirectUri ( getRedirectUri ( ) )
. scopes ( getScopes ( ) )
. state ( getState ( ) )
. additionalParameters ( additionalParameters - >
getParameters ( ) . entrySet ( ) . stream ( )
. filter ( e - > ! e . getKey ( ) . equals ( OAuth2ParameterNames . RESPONSE_TYPE ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . CLIENT_ID ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . REDIRECT_URI ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . SCOPE ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . STATE ) )
. forEach ( e - > additionalParameters . put ( e . getKey ( ) , e . getValue ( ) . get ( 0 ) ) ) )
. build ( ) ;
}
}
private static class UserConsentRequestContext extends AbstractRequestContext {
private OAuth2Authorization authorization ;
private UserConsentRequestContext (
String authorizationUri , MultiValueMap < String , String > parameters ) {
super ( authorizationUri , parameters ,
parameters . getFirst ( OAuth2ParameterNames . CLIENT_ID ) ,
parameters . getFirst ( OAuth2ParameterNames . STATE ) ,
extractScopes ( parameters ) ) ;
}
private static Set < String > extractScopes ( MultiValueMap < String , String > parameters ) {
List < String > scope = parameters . get ( OAuth2ParameterNames . SCOPE ) ;
return ! CollectionUtils . isEmpty ( scope ) ? new HashSet < > ( scope ) : Collections . emptySet ( ) ;
}
private OAuth2Authorization getAuthorization ( ) {
return this . authorization ;
}
private void setAuthorization ( OAuth2Authorization authorization ) {
this . authorization = authorization ;
}
protected String resolveRedirectUri ( ) {
OAuth2AuthorizationRequest authorizationRequest = getAuthorizationRequest ( ) ;
return StringUtils . hasText ( authorizationRequest . getRedirectUri ( ) ) ?
authorizationRequest . getRedirectUri ( ) :
getRegisteredClient ( ) . getRedirectUris ( ) . iterator ( ) . next ( ) ;
}
private OAuth2AuthorizationRequest getAuthorizationRequest ( ) {
return getAuthorization ( ) . getAttribute ( OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST ) ;
}
}
private abstract static class AbstractRequestContext {
private final String authorizationUri ;
private final MultiValueMap < String , String > parameters ;
private final String clientId ;
private final String state ;
private final Set < String > scopes ;
private RegisteredClient registeredClient ;
private OAuth2Error error ;
private boolean redirectOnError ;
protected AbstractRequestContext ( String authorizationUri , MultiValueMap < String , String > parameters ,
String clientId , String state , Set < String > scopes ) {
this . authorizationUri = authorizationUri ;
this . parameters = parameters ;
this . clientId = clientId ;
this . state = state ;
this . scopes = scopes ;
}
protected String getAuthorizationUri ( ) {
return this . authorizationUri ;
}
protected MultiValueMap < String , String > getParameters ( ) {
return this . parameters ;
}
protected String getClientId ( ) {
return this . clientId ;
}
protected String getState ( ) {
return this . state ;
}
protected Set < String > getScopes ( ) {
return this . scopes ;
}
protected RegisteredClient getRegisteredClient ( ) {
return this . registeredClient ;
}
protected void setRegisteredClient ( RegisteredClient registeredClient ) {
this . registeredClient = registeredClient ;
}
protected OAuth2Error getError ( ) {
return this . error ;
}
protected void setError ( OAuth2Error error ) {
this . error = error ;
}
protected boolean hasError ( ) {
return getError ( ) ! = null ;
}
protected boolean isRedirectOnError ( ) {
return this . redirectOnError ;
}
protected void setRedirectOnError ( boolean redirectOnError ) {
this . redirectOnError = redirectOnError ;
}
protected abstract String resolveRedirectUri ( ) ;
}
private static class UserConsentPage {
private static final MediaType TEXT_HTML_UTF8 = new MediaType ( " text " , " html " , StandardCharsets . UTF_8 ) ;
private static final String CONSENT_ACTION_PARAMETER_NAME = " consent_action " ;
private static final String CONSENT_ACTION_APPROVE = " approve " ;
private static final String CONSENT_ACTION_CANCEL = " cancel " ;
private static void displayConsent ( HttpServletRequest request , HttpServletResponse response ,
RegisteredClient registeredClient , OAuth2Authorization authorization ) throws IOException {
String consentPage = generateConsentPage ( request , registeredClient , authorization ) ;
response . setContentType ( TEXT_HTML_UTF8 . toString ( ) ) ;
response . setContentLength ( consentPage . getBytes ( StandardCharsets . UTF_8 ) . length ) ;
response . getWriter ( ) . write ( consentPage ) ;
}
private static boolean isConsentApproved ( HttpServletRequest request ) {
return CONSENT_ACTION_APPROVE . equalsIgnoreCase ( request . getParameter ( CONSENT_ACTION_PARAMETER_NAME ) ) ;
}
private static boolean isConsentCancelled ( HttpServletRequest request ) {
return CONSENT_ACTION_CANCEL . equalsIgnoreCase ( request . getParameter ( CONSENT_ACTION_PARAMETER_NAME ) ) ;
}
private static String generateConsentPage ( HttpServletRequest request ,
RegisteredClient registeredClient , OAuth2Authorization authorization ) {
OAuth2AuthorizationRequest authorizationRequest = authorization . getAttribute (
OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST ) ;
String state = authorization . getAttribute (
OAuth2AuthorizationAttributeNames . STATE ) ;
StringBuilder builder = new StringBuilder ( ) ;
builder . append ( " <!DOCTYPE html> " ) ;
builder . append ( " <html lang= \" en \" > " ) ;
builder . append ( " <head> " ) ;
builder . append ( " <meta charset= \" utf-8 \" > " ) ;
builder . append ( " <meta name= \" viewport \" content= \" width=device-width, initial-scale=1, shrink-to-fit=no \" > " ) ;
builder . append ( " <link rel= \" stylesheet \" href= \" https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css \" integrity= \" sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z \" crossorigin= \" anonymous \" > " ) ;
builder . append ( " <title>Consent required</title> " ) ;
builder . append ( " </head> " ) ;
builder . append ( " <body> " ) ;
builder . append ( " <div class= \" container \" > " ) ;
builder . append ( " <div class= \" py-5 \" > " ) ;
builder . append ( " <h1 class= \" text-center \" >Consent required</h1> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " <div class= \" row \" > " ) ;
builder . append ( " <div class= \" col text-center \" > " ) ;
builder . append ( " <p><span class= \" font-weight-bold text-primary \" > " + registeredClient . getClientId ( ) + " </span> wants to access your account <span class= \" font-weight-bold \" > " + authorization . getPrincipalName ( ) + " </span></p> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " <div class= \" row pb-3 \" > " ) ;
builder . append ( " <div class= \" col text-center \" > " ) ;
builder . append ( " <p>The following permissions are requested by the above app.<br/>Please review these and consent if you approve.</p> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " <div class= \" row \" > " ) ;
builder . append ( " <div class= \" col text-center \" > " ) ;
builder . append ( " <form method= \" post \" action= \" " + request . getRequestURI ( ) + " \" > " ) ;
builder . append ( " <input type= \" hidden \" name= \" client_id \" value= \" " + registeredClient . getClientId ( ) + " \" > " ) ;
builder . append ( " <input type= \" hidden \" name= \" state \" value= \" " + state + " \" > " ) ;
for ( String scope : authorizationRequest . getScopes ( ) ) {
builder . append ( " <div class= \" form-group form-check py-1 \" > " ) ;
builder . append ( " <input class= \" form-check-input \" type= \" checkbox \" name= \" scope \" value= \" " + scope + " \" id= \" " + scope + " \" checked> " ) ;
builder . append ( " <label class= \" form-check-label \" for= \" " + scope + " \" > " + scope + " </label> " ) ;
builder . append ( " </div> " ) ;
}
builder . append ( " <div class= \" form-group pt-3 \" > " ) ;
builder . append ( " <button class= \" btn btn-primary btn-lg \" type= \" submit \" name= \" consent_action \" value= \" approve \" >Submit Consent</button> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " <div class= \" form-group \" > " ) ;
builder . append ( " <button class= \" btn btn-link regular \" type= \" submit \" name= \" consent_action \" value= \" cancel \" >Cancel</button> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </form> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " <div class= \" row pt-4 \" > " ) ;
builder . append ( " <div class= \" col text-center \" > " ) ;
builder . append ( " <p><small>Your consent to provide access is required.<br/>If you do not approve, click Cancel, in which case no information will be shared with the app.</small></p> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </div> " ) ;
builder . append ( " </body> " ) ;
builder . append ( " </html> " ) ;
return builder . toString ( ) ;
}
2020-05-20 03:55:50 +07:00
}
2020-04-24 03:45:34 +07:00
}