Polish gh-79
This commit is contained in:
parent
8dbdd6640c
commit
4c8f89af5c
@ -33,7 +33,6 @@ import java.util.function.Consumer;
|
|||||||
*
|
*
|
||||||
* @author Joe Grandja
|
* @author Joe Grandja
|
||||||
* @author Krisztian Toth
|
* @author Krisztian Toth
|
||||||
* @author Madhu Bhat
|
|
||||||
* @since 0.0.1
|
* @since 0.0.1
|
||||||
* @see RegisteredClient
|
* @see RegisteredClient
|
||||||
* @see OAuth2AccessToken
|
* @see OAuth2AccessToken
|
||||||
@ -75,15 +74,6 @@ public class OAuth2Authorization implements Serializable {
|
|||||||
return this.accessToken;
|
return this.accessToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the access token {@link OAuth2AccessToken} in the {@link OAuth2Authorization}.
|
|
||||||
*
|
|
||||||
* @param accessToken the access token
|
|
||||||
*/
|
|
||||||
public final void setAccessToken(OAuth2AccessToken accessToken) {
|
|
||||||
this.accessToken = accessToken;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the attribute(s) associated to the authorization.
|
* Returns the attribute(s) associated to the authorization.
|
||||||
*
|
*
|
||||||
|
@ -17,8 +17,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
|
|||||||
|
|
||||||
import org.springframework.security.authentication.AbstractAuthenticationToken;
|
import org.springframework.security.authentication.AbstractAuthenticationToken;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.core.SpringSecurityCoreVersion;
|
|
||||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.Version;
|
||||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@ -28,7 +28,7 @@ import java.util.Collections;
|
|||||||
* @author Madhu Bhat
|
* @author Madhu Bhat
|
||||||
*/
|
*/
|
||||||
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
|
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
|
||||||
private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
|
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
|
||||||
private RegisteredClient registeredClient;
|
private RegisteredClient registeredClient;
|
||||||
private Authentication clientPrincipal;
|
private Authentication clientPrincipal;
|
||||||
private OAuth2AccessToken accessToken;
|
private OAuth2AccessToken accessToken;
|
||||||
@ -52,9 +52,9 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the access token {@link OAuth2AccessToken}.
|
* Returns the {@link OAuth2AccessToken access token}.
|
||||||
*
|
*
|
||||||
* @return the access token
|
* @return the {@link OAuth2AccessToken}
|
||||||
*/
|
*/
|
||||||
public OAuth2AccessToken getAccessToken() {
|
public OAuth2AccessToken getAccessToken() {
|
||||||
return this.accessToken;
|
return this.accessToken;
|
||||||
|
@ -18,7 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
|
|||||||
import org.springframework.lang.Nullable;
|
import org.springframework.lang.Nullable;
|
||||||
import org.springframework.security.authentication.AbstractAuthenticationToken;
|
import org.springframework.security.authentication.AbstractAuthenticationToken;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.core.SpringSecurityCoreVersion;
|
import org.springframework.security.oauth2.server.authorization.Version;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ import java.util.Collections;
|
|||||||
* @author Madhu Bhat
|
* @author Madhu Bhat
|
||||||
*/
|
*/
|
||||||
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
|
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
|
||||||
private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
|
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
|
||||||
private String code;
|
private String code;
|
||||||
private Authentication clientPrincipal;
|
private Authentication clientPrincipal;
|
||||||
private String clientId;
|
private String clientId;
|
||||||
@ -37,26 +37,26 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
|
|||||||
Authentication clientPrincipal, @Nullable String redirectUri) {
|
Authentication clientPrincipal, @Nullable String redirectUri) {
|
||||||
super(Collections.emptyList());
|
super(Collections.emptyList());
|
||||||
this.code = code;
|
this.code = code;
|
||||||
this.redirectUri = redirectUri;
|
|
||||||
this.clientPrincipal = clientPrincipal;
|
this.clientPrincipal = clientPrincipal;
|
||||||
|
this.redirectUri = redirectUri;
|
||||||
}
|
}
|
||||||
|
|
||||||
public OAuth2AuthorizationCodeAuthenticationToken(String code,
|
public OAuth2AuthorizationCodeAuthenticationToken(String code,
|
||||||
String clientId, @Nullable String redirectUri) {
|
String clientId, @Nullable String redirectUri) {
|
||||||
super(Collections.emptyList());
|
super(Collections.emptyList());
|
||||||
this.code = code;
|
this.code = code;
|
||||||
this.redirectUri = redirectUri;
|
|
||||||
this.clientId = clientId;
|
this.clientId = clientId;
|
||||||
}
|
this.redirectUri = redirectUri;
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object getCredentials() {
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Object getPrincipal() {
|
public Object getPrincipal() {
|
||||||
return null;
|
return this.clientPrincipal != null ? this.clientPrincipal : this.clientId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object getCredentials() {
|
||||||
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -67,4 +67,13 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
|
|||||||
public String getCode() {
|
public String getCode() {
|
||||||
return this.code;
|
return this.code;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the redirectUri.
|
||||||
|
*
|
||||||
|
* @return the redirectUri
|
||||||
|
*/
|
||||||
|
public String getRedirectUri() {
|
||||||
|
return this.redirectUri;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,6 @@ import org.springframework.security.web.RedirectStrategy;
|
|||||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.util.LinkedMultiValueMap;
|
|
||||||
import org.springframework.util.MultiValueMap;
|
import org.springframework.util.MultiValueMap;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
import org.springframework.web.filter.OncePerRequestFilter;
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
@ -53,7 +52,6 @@ import java.util.Arrays;
|
|||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -123,7 +121,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
|
|||||||
// Validate the request to ensure that all required parameters are present and valid
|
// Validate the request to ensure that all required parameters are present and valid
|
||||||
// ---------------
|
// ---------------
|
||||||
|
|
||||||
MultiValueMap<String, String> parameters = getParameters(request);
|
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
|
||||||
String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
|
String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
|
||||||
|
|
||||||
// client_id (REQUIRED)
|
// client_id (REQUIRED)
|
||||||
@ -258,7 +256,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
|
private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
|
||||||
MultiValueMap<String, String> parameters = getParameters(request);
|
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
|
||||||
|
|
||||||
Set<String> scopes = Collections.emptySet();
|
Set<String> scopes = Collections.emptySet();
|
||||||
if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
|
if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
|
||||||
@ -282,17 +280,4 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
|
|||||||
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
|
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
|
|
||||||
Map<String, String[]> parameterMap = request.getParameterMap();
|
|
||||||
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
|
|
||||||
parameterMap.forEach((key, values) -> {
|
|
||||||
if (values.length > 0) {
|
|
||||||
for (String value : values) {
|
|
||||||
parameters.add(key, value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return parameters;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.springframework.security.oauth2.server.authorization.web;
|
||||||
|
|
||||||
|
import org.springframework.util.LinkedMultiValueMap;
|
||||||
|
import org.springframework.util.MultiValueMap;
|
||||||
|
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility methods for the OAuth 2.0 Protocol Endpoints.
|
||||||
|
*
|
||||||
|
* @author Joe Grandja
|
||||||
|
* @since 0.0.1
|
||||||
|
* @see OAuth2AuthorizationEndpointFilter
|
||||||
|
* @see OAuth2TokenEndpointFilter
|
||||||
|
*/
|
||||||
|
final class OAuth2EndpointUtils {
|
||||||
|
|
||||||
|
private OAuth2EndpointUtils() {
|
||||||
|
}
|
||||||
|
|
||||||
|
static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
|
||||||
|
Map<String, String[]> parameterMap = request.getParameterMap();
|
||||||
|
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
|
||||||
|
parameterMap.forEach((key, values) -> {
|
||||||
|
if (values.length > 0) {
|
||||||
|
for (String value : values) {
|
||||||
|
parameters.add(key, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return parameters;
|
||||||
|
}
|
||||||
|
}
|
@ -15,13 +15,11 @@
|
|||||||
*/
|
*/
|
||||||
package org.springframework.security.oauth2.server.authorization.web;
|
package org.springframework.security.oauth2.server.authorization.web;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import org.springframework.core.convert.converter.Converter;
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.http.HttpHeaders;
|
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.converter.HttpMessageConverter;
|
||||||
|
import org.springframework.http.server.ServletServerHttpResponse;
|
||||||
import org.springframework.security.authentication.AuthenticationManager;
|
import org.springframework.security.authentication.AuthenticationManager;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.core.context.SecurityContextHolder;
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
@ -30,15 +28,17 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|||||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
|
||||||
|
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
|
||||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||||
import org.springframework.security.oauth2.server.authorization.TokenType;
|
|
||||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
|
||||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
|
||||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.MultiValueMap;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
import org.springframework.web.filter.OncePerRequestFilter;
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
|
|
||||||
@ -47,145 +47,171 @@ import javax.servlet.ServletException;
|
|||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.Writer;
|
import java.time.temporal.ChronoUnit;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This {@code Filter} is used by the client to obtain an access token by presenting
|
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
|
||||||
* its authorization grant.
|
* which handles the processing of the OAuth 2.0 Access Token Request.
|
||||||
*
|
*
|
||||||
* <p>
|
* <p>
|
||||||
* It converts the OAuth 2.0 Access Token Request to {@link OAuth2AuthorizationCodeAuthenticationToken},
|
* It converts the OAuth 2.0 Access Token Request to an {@link OAuth2AuthorizationCodeAuthenticationToken},
|
||||||
* which is then authenticated by the {@link AuthenticationManager} and gets back
|
* which is then authenticated by the {@link AuthenticationManager}.
|
||||||
* {@link OAuth2AccessTokenAuthenticationToken} which has the {@link OAuth2AccessToken} if the request
|
* If the authentication succeeds, the {@link AuthenticationManager} returns an
|
||||||
* was successfully authenticated. The {@link OAuth2AccessToken} is then updated in the in-flight {@link OAuth2Authorization}
|
* {@link OAuth2AccessTokenAuthenticationToken}, which contains
|
||||||
* and sent back to the client. In case the authentication fails, an HTTP 401 (Unauthorized) response is returned.
|
* the {@link OAuth2AccessToken} that is returned in the response.
|
||||||
|
* In case of any error, an {@link OAuth2Error} is returned in the response.
|
||||||
*
|
*
|
||||||
* <p>
|
* <p>
|
||||||
* By default, this {@code Filter} responds to access token requests
|
* By default, this {@code Filter} responds to access token requests
|
||||||
* at the {@code URI} {@code /oauth2/token} and {@code HttpMethod} {@code POST}
|
* at the {@code URI} {@code /oauth2/token} and {@code HttpMethod} {@code POST}.
|
||||||
* using the default {@link AntPathRequestMatcher}.
|
|
||||||
*
|
*
|
||||||
* <p>
|
* <p>
|
||||||
* The default base {@code URI} {@code /oauth2/token} may be overridden
|
* The default endpoint {@code URI} {@code /oauth2/token} may be overridden
|
||||||
* via the constructor {@link #OAuth2TokenEndpointFilter(OAuth2AuthorizationService, AuthenticationManager, String)}.
|
* via the constructor {@link #OAuth2TokenEndpointFilter(AuthenticationManager, OAuth2AuthorizationService, String)}.
|
||||||
*
|
*
|
||||||
* @author Joe Grandja
|
* @author Joe Grandja
|
||||||
* @author Madhu Bhat
|
* @author Madhu Bhat
|
||||||
|
* @since 0.0.1
|
||||||
|
* @see AuthenticationManager
|
||||||
|
* @see OAuth2AuthorizationService
|
||||||
|
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
|
||||||
|
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
|
||||||
*/
|
*/
|
||||||
public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
|
public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
|
||||||
/**
|
/**
|
||||||
* The default endpoint {@code URI} for access token requests.
|
* The default endpoint {@code URI} for access token requests.
|
||||||
*/
|
*/
|
||||||
private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token";
|
public static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token";
|
||||||
|
|
||||||
private Converter<HttpServletRequest, Authentication> authorizationGrantConverter = this::convert;
|
private final AuthenticationManager authenticationManager;
|
||||||
private AuthenticationManager authenticationManager;
|
private final OAuth2AuthorizationService authorizationService;
|
||||||
private OAuth2AuthorizationService authorizationService;
|
private final RequestMatcher tokenEndpointMatcher;
|
||||||
private RequestMatcher uriMatcher;
|
private final Converter<HttpServletRequest, Authentication> authorizationGrantAuthenticationConverter =
|
||||||
private ObjectMapper objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
new AuthorizationCodeAuthenticationConverter();
|
||||||
|
private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
|
||||||
|
new OAuth2AccessTokenResponseHttpMessageConverter();
|
||||||
|
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
|
||||||
|
new OAuth2ErrorHttpMessageConverter();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
|
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
|
||||||
*
|
*
|
||||||
* @param authorizationService the authorization service implementation
|
* @param authenticationManager the authentication manager
|
||||||
* @param authenticationManager the authentication manager implementation
|
* @param authorizationService the authorization service
|
||||||
*/
|
*/
|
||||||
public OAuth2TokenEndpointFilter(OAuth2AuthorizationService authorizationService, AuthenticationManager authenticationManager) {
|
public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager,
|
||||||
Assert.notNull(authorizationService, "authorizationService cannot be null");
|
OAuth2AuthorizationService authorizationService) {
|
||||||
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
|
this(authenticationManager, authorizationService, DEFAULT_TOKEN_ENDPOINT_URI);
|
||||||
this.authenticationManager = authenticationManager;
|
|
||||||
this.authorizationService = authorizationService;
|
|
||||||
this.uriMatcher = new AntPathRequestMatcher(DEFAULT_TOKEN_ENDPOINT_URI, HttpMethod.POST.name());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
|
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
|
||||||
*
|
*
|
||||||
* @param authorizationService the authorization service implementation
|
* @param authenticationManager the authentication manager
|
||||||
* @param authenticationManager the authentication manager implementation
|
* @param authorizationService the authorization service
|
||||||
* @param tokenEndpointUri the token endpoint's uri
|
* @param tokenEndpointUri the endpoint {@code URI} for access token requests
|
||||||
*/
|
*/
|
||||||
public OAuth2TokenEndpointFilter(OAuth2AuthorizationService authorizationService, AuthenticationManager authenticationManager,
|
public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager,
|
||||||
String tokenEndpointUri) {
|
OAuth2AuthorizationService authorizationService, String tokenEndpointUri) {
|
||||||
Assert.notNull(authorizationService, "authorizationService cannot be null");
|
|
||||||
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
|
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
|
||||||
|
Assert.notNull(authorizationService, "authorizationService cannot be null");
|
||||||
Assert.hasText(tokenEndpointUri, "tokenEndpointUri cannot be empty");
|
Assert.hasText(tokenEndpointUri, "tokenEndpointUri cannot be empty");
|
||||||
this.authenticationManager = authenticationManager;
|
this.authenticationManager = authenticationManager;
|
||||||
this.authorizationService = authorizationService;
|
this.authorizationService = authorizationService;
|
||||||
this.uriMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name());
|
this.tokenEndpointMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doFilterInternal(HttpServletRequest request,
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||||
HttpServletResponse response, FilterChain filterChain)
|
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
if (uriMatcher.matches(request)) {
|
|
||||||
try {
|
if (!this.tokenEndpointMatcher.matches(request)) {
|
||||||
if (validateAccessTokenRequest(request)) {
|
|
||||||
OAuth2AuthorizationCodeAuthenticationToken authCodeAuthToken =
|
|
||||||
(OAuth2AuthorizationCodeAuthenticationToken) authorizationGrantConverter.convert(request);
|
|
||||||
OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken =
|
|
||||||
(OAuth2AccessTokenAuthenticationToken) authenticationManager.authenticate(authCodeAuthToken);
|
|
||||||
if (accessTokenAuthenticationToken.isAuthenticated()) {
|
|
||||||
OAuth2Authorization authorization = authorizationService
|
|
||||||
.findByTokenAndTokenType(authCodeAuthToken.getCode(), TokenType.AUTHORIZATION_CODE);
|
|
||||||
authorization.setAccessToken(accessTokenAuthenticationToken.getAccessToken());
|
|
||||||
authorizationService.save(authorization);
|
|
||||||
writeSuccessResponse(response, accessTokenAuthenticationToken.getAccessToken());
|
|
||||||
} else {
|
|
||||||
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (OAuth2AuthenticationException exception) {
|
|
||||||
SecurityContextHolder.clearContext();
|
|
||||||
writeFailureResponse(response, exception.getError());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
filterChain.doFilter(request, response);
|
filterChain.doFilter(request, response);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
Authentication authorizationGrantAuthentication =
|
||||||
|
this.authorizationGrantAuthenticationConverter.convert(request);
|
||||||
|
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
|
||||||
|
(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
|
||||||
|
sendAccessTokenResponse(response, accessTokenAuthentication.getAccessToken());
|
||||||
|
} catch (OAuth2AuthenticationException ex) {
|
||||||
|
SecurityContextHolder.clearContext();
|
||||||
|
sendErrorResponse(response, ex.getError());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean validateAccessTokenRequest(HttpServletRequest request) {
|
private void sendAccessTokenResponse(HttpServletResponse response, OAuth2AccessToken accessToken) throws IOException {
|
||||||
if (StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.CODE))
|
OAuth2AccessTokenResponse.Builder builder =
|
||||||
|| StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
|
OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
|
||||||
|| StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.GRANT_TYPE))) {
|
.tokenType(accessToken.getTokenType())
|
||||||
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
.scopes(accessToken.getScopes());
|
||||||
} else if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(request.getParameter(OAuth2ParameterNames.GRANT_TYPE))) {
|
if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
|
||||||
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE));
|
builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
|
||||||
}
|
}
|
||||||
return true;
|
OAuth2AccessTokenResponse accessTokenResponse = builder.build();
|
||||||
|
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
|
||||||
|
this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
private OAuth2AuthorizationCodeAuthenticationToken convert(HttpServletRequest request) {
|
private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
|
||||||
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
|
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
|
||||||
return new OAuth2AuthorizationCodeAuthenticationToken(
|
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
|
||||||
request.getParameter(OAuth2ParameterNames.CODE),
|
this.errorHttpResponseConverter.write(error, null, httpResponse);
|
||||||
clientPrincipal,
|
|
||||||
request.getParameter(OAuth2ParameterNames.REDIRECT_URI)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeSuccessResponse(HttpServletResponse response, OAuth2AccessToken body) throws IOException {
|
private static OAuth2AuthenticationException throwError(String errorCode, String parameterName) {
|
||||||
try (Writer out = response.getWriter()) {
|
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
|
||||||
response.setStatus(HttpStatus.OK.value());
|
"https://tools.ietf.org/html/rfc6749#section-5.2");
|
||||||
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
|
throw new OAuth2AuthenticationException(error);
|
||||||
response.setCharacterEncoding("UTF-8");
|
|
||||||
response.setHeader(HttpHeaders.CACHE_CONTROL, "no-store");
|
|
||||||
response.setHeader(HttpHeaders.PRAGMA, "no-cache");
|
|
||||||
out.write(objectMapper.writeValueAsString(body));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeFailureResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
|
private static class AuthorizationCodeAuthenticationConverter implements Converter<HttpServletRequest, Authentication> {
|
||||||
try (Writer out = response.getWriter()) {
|
|
||||||
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_CLIENT)) {
|
@Override
|
||||||
response.setStatus(HttpStatus.UNAUTHORIZED.value());
|
public Authentication convert(HttpServletRequest request) {
|
||||||
|
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
|
||||||
|
|
||||||
|
// grant_type (REQUIRED)
|
||||||
|
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
|
||||||
|
if (!StringUtils.hasText(grantType) ||
|
||||||
|
parameters.get(OAuth2ParameterNames.GRANT_TYPE).size() != 1) {
|
||||||
|
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE);
|
||||||
|
}
|
||||||
|
if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {
|
||||||
|
throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// client_id (REQUIRED)
|
||||||
|
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
|
||||||
|
Authentication clientPrincipal = null;
|
||||||
|
if (StringUtils.hasText(clientId)) {
|
||||||
|
if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
|
||||||
|
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
response.setStatus(HttpStatus.BAD_REQUEST.value());
|
clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
|
||||||
}
|
}
|
||||||
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
|
|
||||||
response.setCharacterEncoding("UTF-8");
|
// code (REQUIRED)
|
||||||
out.write(objectMapper.writeValueAsString(error));
|
String code = parameters.getFirst(OAuth2ParameterNames.CODE);
|
||||||
|
if (!StringUtils.hasText(code) ||
|
||||||
|
parameters.get(OAuth2ParameterNames.CODE).size() != 1) {
|
||||||
|
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CODE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirect_uri (REQUIRED)
|
||||||
|
// Required only if the "redirect_uri" parameter was included in the authorization request
|
||||||
|
String redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
|
||||||
|
if (StringUtils.hasText(redirectUri) &&
|
||||||
|
parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
|
||||||
|
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientPrincipal != null ?
|
||||||
|
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) :
|
||||||
|
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,36 +15,47 @@
|
|||||||
*/
|
*/
|
||||||
package org.springframework.security.oauth2.server.authorization.web;
|
package org.springframework.security.oauth2.server.authorization.web;
|
||||||
|
|
||||||
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.mockito.ArgumentCaptor;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.http.converter.HttpMessageConverter;
|
||||||
|
import org.springframework.mock.http.client.MockClientHttpResponse;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
import org.springframework.security.authentication.AuthenticationManager;
|
import org.springframework.security.authentication.AuthenticationManager;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.context.SecurityContext;
|
||||||
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
|
||||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
|
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
|
||||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||||
import org.springframework.security.oauth2.server.authorization.TokenType;
|
|
||||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
|
||||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
||||||
|
|
||||||
import javax.servlet.FilterChain;
|
import javax.servlet.FilterChain;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.Mockito.anyString;
|
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.times;
|
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.verifyNoInteractions;
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
@ -53,178 +64,222 @@ import static org.mockito.Mockito.when;
|
|||||||
* Tests for {@link OAuth2TokenEndpointFilter}.
|
* Tests for {@link OAuth2TokenEndpointFilter}.
|
||||||
*
|
*
|
||||||
* @author Madhu Bhat
|
* @author Madhu Bhat
|
||||||
|
* @author Joe Grandja
|
||||||
*/
|
*/
|
||||||
public class OAuth2TokenEndpointFilterTests {
|
public class OAuth2TokenEndpointFilterTests {
|
||||||
|
private AuthenticationManager authenticationManager;
|
||||||
|
private OAuth2AuthorizationService authorizationService;
|
||||||
private OAuth2TokenEndpointFilter filter;
|
private OAuth2TokenEndpointFilter filter;
|
||||||
private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
|
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
|
||||||
private AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
|
new OAuth2ErrorHttpMessageConverter();
|
||||||
private FilterChain filterChain = mock(FilterChain.class);
|
private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
|
||||||
private String requestUri;
|
new OAuth2AccessTokenResponseHttpMessageConverter();
|
||||||
private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
|
|
||||||
private static final String PRINCIPAL_NAME = "principal";
|
|
||||||
private static final String AUTHORIZATION_CODE = "code";
|
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
this.filter = new OAuth2TokenEndpointFilter(this.authorizationService, this.authenticationManager);
|
this.authenticationManager = mock(AuthenticationManager.class);
|
||||||
this.requestUri = "/oauth2/token";
|
this.authorizationService = mock(OAuth2AuthorizationService.class);
|
||||||
|
this.filter = new OAuth2TokenEndpointFilter(this.authenticationManager, this.authorizationService);
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void cleanup() {
|
||||||
|
SecurityContextHolder.clearContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void constructorServiceAndManagerWhenNullThenThrowIllegalArgumentException() {
|
public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() {
|
||||||
assertThatThrownBy(() -> {
|
assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(null, this.authorizationService))
|
||||||
new OAuth2TokenEndpointFilter(null, null);
|
.isInstanceOf(IllegalArgumentException.class)
|
||||||
}).isInstanceOf(IllegalArgumentException.class);
|
.hasMessage("authenticationManager cannot be null");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void constructorServiceAndManagerAndEndpointWhenNullThenThrowIllegalArgumentException() {
|
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
|
||||||
assertThatThrownBy(() -> {
|
assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(this.authenticationManager, null))
|
||||||
new OAuth2TokenEndpointFilter(null, null, null);
|
.isInstanceOf(IllegalArgumentException.class)
|
||||||
}).isInstanceOf(IllegalArgumentException.class);
|
.hasMessage("authorizationService cannot be null");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenNotTokenRequestThenNextFilter() throws Exception {
|
public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() {
|
||||||
this.requestUri = "/path";
|
assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(this.authenticationManager, this.authorizationService, null))
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", this.requestUri);
|
.isInstanceOf(IllegalArgumentException.class)
|
||||||
request.setServletPath(this.requestUri);
|
.hasMessage("tokenEndpointUri cannot be empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception {
|
||||||
|
String requestUri = "/path";
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
||||||
|
request.setServletPath(requestUri);
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAccessTokenRequestWithoutGrantTypeThenRespondWithBadRequest() throws Exception {
|
public void doFilterWhenTokenRequestGetThenNotProcessed() throws Exception {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
|
String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
|
request.setServletPath(requestUri);
|
||||||
request.setServletPath(this.requestUri);
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
verifyNoInteractions(this.filterChain);
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
|
||||||
assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAccessTokenRequestWithoutCodeThenRespondWithBadRequest() throws Exception {
|
public void doFilterWhenTokenRequestMissingGrantTypeThenInvalidRequestError() throws Exception {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType");
|
OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST,
|
||||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
|
request -> request.removeParameter(OAuth2ParameterNames.GRANT_TYPE));
|
||||||
request.setServletPath(this.requestUri);
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
|
||||||
|
|
||||||
verifyNoInteractions(this.filterChain);
|
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
|
||||||
assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAccessTokenRequestWithoutRedirectUriThenRespondWithBadRequest() throws Exception {
|
public void doFilterWhenTokenRequestMultipleGrantTypeThenInvalidRequestError() throws Exception {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType");
|
OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST,
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
|
request -> request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()));
|
||||||
request.setServletPath(this.requestUri);
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
|
||||||
|
|
||||||
verifyNoInteractions(this.filterChain);
|
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
|
||||||
assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAccessTokenRequestWithoutAuthCodeGrantTypeThenRespondWithBadRequest() throws Exception {
|
public void doFilterWhenTokenRequestInvalidGrantTypeThenUnsupportedGrantTypeError() throws Exception {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType");
|
OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE,
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
|
request -> request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type"));
|
||||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
|
|
||||||
request.setServletPath(this.requestUri);
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
|
||||||
|
|
||||||
verifyNoInteractions(this.filterChain);
|
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
|
||||||
assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"unsupported_grant_type\"}");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAccessTokenRequestIsNotAuthenticatedThenRespondWithUnauthorized() throws Exception {
|
public void doFilterWhenTokenRequestMultipleClientIdThenInvalidRequestError() throws Exception {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
|
OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST,
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
|
request -> {
|
||||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
|
||||||
request.setServletPath(this.requestUri);
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
});
|
||||||
Authentication clientPrincipal = mock(Authentication.class);
|
}
|
||||||
RegisteredClient registeredClient = mock(RegisteredClient.class);
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception {
|
||||||
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
|
OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST,
|
||||||
|
request -> request.removeParameter(OAuth2ParameterNames.CODE));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception {
|
||||||
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
|
OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST,
|
||||||
|
request -> request.addParameter(OAuth2ParameterNames.CODE, "code-2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
|
||||||
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
||||||
|
OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST,
|
||||||
|
request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception {
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||||
|
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
|
||||||
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
||||||
OAuth2AccessToken.TokenType.BEARER, "testToken", Instant.now().minusSeconds(60), Instant.now());
|
OAuth2AccessToken.TokenType.BEARER, "token",
|
||||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
Instant.now(), Instant.now().plus(Duration.ofHours(1)),
|
||||||
.principalName(PRINCIPAL_NAME)
|
new HashSet<>(Arrays.asList("scope1", "scope2")));
|
||||||
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
|
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
|
||||||
.build();
|
new OAuth2AccessTokenAuthenticationToken(
|
||||||
OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken =
|
registeredClient, clientPrincipal, accessToken);
|
||||||
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
|
|
||||||
accessTokenAuthenticationToken.setAuthenticated(false);
|
|
||||||
|
|
||||||
when(this.authorizationService.findByTokenAndTokenType(anyString(), any(TokenType.class))).thenReturn(authorization);
|
when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
|
||||||
when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(accessTokenAuthenticationToken);
|
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
||||||
|
securityContext.setAuthentication(clientPrincipal);
|
||||||
|
SecurityContextHolder.setContext(securityContext);
|
||||||
|
|
||||||
verifyNoInteractions(this.filterChain);
|
MockHttpServletRequest request = createTokenRequest(registeredClient);
|
||||||
verify(this.authorizationService, times(0)).save(authorization);
|
|
||||||
verify(this.authenticationManager, times(1)).authenticate(any(Authentication.class));
|
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
|
|
||||||
assertThat(response.getContentAsString())
|
|
||||||
.isEqualTo("{\"errorCode\":\"invalid_client\"}");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void doFilterWhenValidAccessTokenRequestThenRespondWithAccessToken() throws Exception {
|
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
|
|
||||||
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
|
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
|
|
||||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
|
|
||||||
request.setServletPath(this.requestUri);
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
Authentication clientPrincipal = mock(Authentication.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
RegisteredClient registeredClient = mock(RegisteredClient.class);
|
|
||||||
|
|
||||||
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
this.filter.doFilter(request, response, filterChain);
|
||||||
OAuth2AccessToken.TokenType.BEARER, "testToken", Instant.now().minusSeconds(60), Instant.now());
|
|
||||||
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
|
||||||
.principalName(PRINCIPAL_NAME)
|
|
||||||
.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
|
|
||||||
.build();
|
|
||||||
OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken =
|
|
||||||
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
|
|
||||||
accessTokenAuthenticationToken.setAuthenticated(true);
|
|
||||||
|
|
||||||
when(this.authorizationService.findByTokenAndTokenType(anyString(), any(TokenType.class))).thenReturn(authorization);
|
verifyNoInteractions(filterChain);
|
||||||
when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(accessTokenAuthenticationToken);
|
|
||||||
|
|
||||||
this.filter.doFilter(request, response, this.filterChain);
|
ArgumentCaptor<OAuth2AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticationCaptor =
|
||||||
|
ArgumentCaptor.forClass(OAuth2AuthorizationCodeAuthenticationToken.class);
|
||||||
|
verify(this.authenticationManager).authenticate(authorizationCodeAuthenticationCaptor.capture());
|
||||||
|
|
||||||
|
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
|
||||||
|
authorizationCodeAuthenticationCaptor.getValue();
|
||||||
|
assertThat(authorizationCodeAuthentication.getCode()).isEqualTo(
|
||||||
|
request.getParameter(OAuth2ParameterNames.CODE));
|
||||||
|
assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
|
||||||
|
assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo(
|
||||||
|
request.getParameter(OAuth2ParameterNames.REDIRECT_URI));
|
||||||
|
|
||||||
verifyNoInteractions(this.filterChain);
|
|
||||||
verify(this.authorizationService, times(1)).save(authorization);
|
|
||||||
verify(this.authenticationManager, times(1)).authenticate(any(Authentication.class));
|
|
||||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
|
||||||
assertThat(response.getContentAsString()).contains("\"tokenValue\":\"testToken\"");
|
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
|
||||||
assertThat(response.getContentAsString()).contains("\"tokenType\":{\"value\":\"Bearer\"}");
|
|
||||||
assertThat(response.getHeader(HttpHeaders.CACHE_CONTROL)).isEqualTo("no-store");
|
OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken();
|
||||||
assertThat(response.getHeader(HttpHeaders.PRAGMA)).isEqualTo("no-cache");
|
assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType());
|
||||||
|
assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue());
|
||||||
|
assertThat(accessTokenResult.getIssuedAt()).isBetween(
|
||||||
|
accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1));
|
||||||
|
assertThat(accessTokenResult.getExpiresAt()).isBetween(
|
||||||
|
accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1));
|
||||||
|
assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode,
|
||||||
|
Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
|
||||||
|
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||||
|
|
||||||
|
MockHttpServletRequest request = createTokenRequest(registeredClient);
|
||||||
|
requestConsumer.accept(request);
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verifyNoInteractions(filterChain);
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||||
|
OAuth2Error error = readError(response);
|
||||||
|
assertThat(error.getErrorCode()).isEqualTo(errorCode);
|
||||||
|
assertThat(error.getDescription()).isEqualTo("OAuth 2.0 Parameter: " + parameterName);
|
||||||
|
}
|
||||||
|
|
||||||
|
private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
|
||||||
|
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
|
||||||
|
response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
|
||||||
|
return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
private OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response) throws Exception {
|
||||||
|
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
|
||||||
|
response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
|
||||||
|
return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static MockHttpServletRequest createTokenRequest(RegisteredClient registeredClient) {
|
||||||
|
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
|
||||||
|
|
||||||
|
String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
||||||
|
request.setServletPath(requestUri);
|
||||||
|
|
||||||
|
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
|
||||||
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
|
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
|
||||||
|
|
||||||
|
return request;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user