Add Authorization Endpoint filter
Fixes gh-66
This commit is contained in:
parent
26c3941a20
commit
54e219a397
@ -15,33 +15,231 @@
|
|||||||
*/
|
*/
|
||||||
package org.springframework.security.oauth2.server.authorization.web;
|
package org.springframework.security.oauth2.server.authorization.web;
|
||||||
|
|
||||||
import org.springframework.core.convert.converter.Converter;
|
import java.io.IOException;
|
||||||
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
import java.util.stream.Stream;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
||||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
|
||||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
|
||||||
import org.springframework.web.filter.OncePerRequestFilter;
|
|
||||||
|
|
||||||
import javax.servlet.FilterChain;
|
import javax.servlet.FilterChain;
|
||||||
import javax.servlet.ServletException;
|
import javax.servlet.ServletException;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import java.io.IOException;
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
|
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
|
||||||
|
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
||||||
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.TokenType;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||||
|
import org.springframework.security.web.DefaultRedirectStrategy;
|
||||||
|
import org.springframework.security.web.RedirectStrategy;
|
||||||
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||||
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Joe Grandja
|
* @author Joe Grandja
|
||||||
|
* @author Paurav Munshi
|
||||||
|
* @since 0.0.1
|
||||||
*/
|
*/
|
||||||
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
|
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
|
||||||
private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter;
|
|
||||||
|
private static final String DEFAULT_ENDPOINT = "/oauth2/authorize";
|
||||||
|
|
||||||
|
private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter = new OAuth2AuthorizationRequestConverter();
|
||||||
private RegisteredClientRepository registeredClientRepository;
|
private RegisteredClientRepository registeredClientRepository;
|
||||||
private OAuth2AuthorizationService authorizationService;
|
private OAuth2AuthorizationService authorizationService;
|
||||||
private StringKeyGenerator codeGenerator;
|
private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
|
||||||
|
private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
|
||||||
|
private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
|
||||||
|
|
||||||
|
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
|
||||||
|
OAuth2AuthorizationService authorizationService) {
|
||||||
|
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null.");
|
||||||
|
Assert.notNull(authorizationService, "authorizationService cannot be null.");
|
||||||
|
this.registeredClientRepository = registeredClientRepository;
|
||||||
|
this.authorizationService = authorizationService;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void setAuthorizationRequestConverter(
|
||||||
|
Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter) {
|
||||||
|
Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
|
||||||
|
this.authorizationRequestConverter = authorizationRequestConverter;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
|
||||||
|
Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
|
||||||
|
this.codeGenerator = codeGenerator;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
|
||||||
|
Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
|
||||||
|
this.authorizationRedirectStrategy = authorizationRedirectStrategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
|
||||||
|
Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
|
||||||
|
this.authorizationEndpointMatcher = authorizationEndpointMatcher;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
||||||
|
boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
|
||||||
|
String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
|
||||||
|
boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
|
||||||
|
if (pathMatch && responseTypeMatch) {
|
||||||
|
return false;
|
||||||
|
}else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doFilterInternal(HttpServletRequest request,
|
protected void doFilterInternal(HttpServletRequest request,
|
||||||
HttpServletResponse response, FilterChain filterChain)
|
HttpServletResponse response, FilterChain filterChain)
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
|
|
||||||
|
RegisteredClient client = null;
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = null;
|
||||||
|
OAuth2Authorization authorization = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
checkUserAuthenticated();
|
||||||
|
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
|
||||||
|
client = fetchRegisteredClient(request);
|
||||||
|
|
||||||
|
authorizationRequest = this.authorizationRequestConverter.convert(request);
|
||||||
|
validateAuthorizationRequest(authorizationRequest, client);
|
||||||
|
|
||||||
|
String code = this.codeGenerator.generateKey();
|
||||||
|
authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code);
|
||||||
|
this.authorizationService.save(authorization);
|
||||||
|
|
||||||
|
String redirectUri = getRedirectUri(authorizationRequest, client);
|
||||||
|
sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code);
|
||||||
|
}
|
||||||
|
catch(OAuth2AuthorizationException authorizationException) {
|
||||||
|
OAuth2Error authorizationError = authorizationException.getError();
|
||||||
|
|
||||||
|
if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
|
||||||
|
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
|
||||||
|
sendErrorInResponse(response, authorizationError);
|
||||||
|
}
|
||||||
|
else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
|
||||||
|
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
|
||||||
|
String redirectUri = getRedirectUri(authorizationRequest, client);
|
||||||
|
sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw new ServletException(authorizationException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void checkUserAuthenticated() {
|
||||||
|
Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
|
||||||
|
if (currentAuth==null || !currentAuth.isAuthenticated()) {
|
||||||
|
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
|
||||||
|
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
|
||||||
|
if (StringUtils.isEmpty(clientId)) {
|
||||||
|
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
|
||||||
|
RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
|
||||||
|
if (client==null) {
|
||||||
|
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
|
||||||
|
.anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
|
||||||
|
if (!isAuthorizationGrantAllowed) {
|
||||||
|
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
|
||||||
|
}
|
||||||
|
|
||||||
|
return client;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest, String code) {
|
||||||
|
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
|
||||||
|
.principalName(auth.getPrincipal().toString())
|
||||||
|
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
|
||||||
|
.attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return authorization;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
|
||||||
|
String redirectUri = authorizationRequest.getRedirectUri();
|
||||||
|
if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
|
||||||
|
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
|
||||||
|
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
|
||||||
|
return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
|
||||||
|
? authorizationRequest.getRedirectUri()
|
||||||
|
: client.getRedirectUris().stream().findFirst().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
|
||||||
|
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
|
||||||
|
.queryParam(OAuth2ParameterNames.CODE, code);
|
||||||
|
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
|
||||||
|
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
|
||||||
|
}
|
||||||
|
|
||||||
|
String finalRedirectUri = redirectUriBuilder.toUriString();
|
||||||
|
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException {
|
||||||
|
int errorStatus = -1;
|
||||||
|
String errorCode = authorizationError.getErrorCode();
|
||||||
|
if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
|
||||||
|
errorStatus=HttpStatus.FORBIDDEN.value();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value();
|
||||||
|
}
|
||||||
|
response.sendError(errorStatus, authorizationError.getErrorCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response,
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError,
|
||||||
|
String redirectUri) throws IOException {
|
||||||
|
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
|
||||||
|
.queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
|
||||||
|
|
||||||
|
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
|
||||||
|
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
|
||||||
|
}
|
||||||
|
|
||||||
|
String finalRedirectURI = redirectUriBuilder.toUriString();
|
||||||
|
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.springframework.security.oauth2.server.authorization.web;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Paurav Munshi
|
||||||
|
* @since 0.0.1
|
||||||
|
* @see Converter
|
||||||
|
*/
|
||||||
|
public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
|
||||||
|
String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
|
||||||
|
Set<String> scopes = !StringUtils.isEmpty(scope)
|
||||||
|
? new LinkedHashSet<String>(Arrays.asList(scope.split(" ")))
|
||||||
|
: Collections.emptySet();
|
||||||
|
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
|
||||||
|
.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
|
||||||
|
.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
|
||||||
|
.scopes(scopes)
|
||||||
|
.state(request.getParameter(OAuth2ParameterNames.STATE))
|
||||||
|
.authorizationUri(request.getServletPath())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return authorizationRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -46,4 +46,40 @@ public class TestRegisteredClients {
|
|||||||
.scope("profile")
|
.scope("profile")
|
||||||
.scope("email");
|
.scope("email");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
|
||||||
|
return RegisteredClient.withId("valid_client_id")
|
||||||
|
.clientId("valid_client")
|
||||||
|
.clientSecret("valid_secret")
|
||||||
|
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||||
|
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||||
|
.redirectUri("http://localhost:8080/test-application/callback")
|
||||||
|
.scope("openid")
|
||||||
|
.scope("profile")
|
||||||
|
.scope("email");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
|
||||||
|
return RegisteredClient.withId("valid_client_multi_uri_id")
|
||||||
|
.clientId("valid_client_multi_uri")
|
||||||
|
.clientSecret("valid_secret")
|
||||||
|
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||||
|
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||||
|
.redirectUri("http://localhost:8080/test-application/callback")
|
||||||
|
.redirectUri("http://localhost:8080/another-test-application/callback")
|
||||||
|
.scope("openid")
|
||||||
|
.scope("profile")
|
||||||
|
.scope("email");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
|
||||||
|
return RegisteredClient.withId("valid_cc_client_id")
|
||||||
|
.clientId("valid_cc_client")
|
||||||
|
.clientSecret("valid_secret")
|
||||||
|
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
|
||||||
|
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||||
|
.scope("openid")
|
||||||
|
.scope("profile")
|
||||||
|
.scope("email");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,371 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2020 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.springframework.security.oauth2.server.authorization.web;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import javax.servlet.FilterChain;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
|
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||||
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests for {@link OAuth2AuthorizationEndpointFilter}.
|
||||||
|
*
|
||||||
|
* @author Paurav Munshi
|
||||||
|
* @since 0.0.1
|
||||||
|
*/
|
||||||
|
|
||||||
|
public class OAuth2AuthorizationEndpointFilterTest {
|
||||||
|
|
||||||
|
private static final String VALID_CLIENT = "valid_client";
|
||||||
|
private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
|
||||||
|
private static final String VALID_CC_CLIENT = "valid_cc_client";
|
||||||
|
|
||||||
|
private OAuth2AuthorizationEndpointFilter filter;
|
||||||
|
|
||||||
|
private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
|
||||||
|
private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
|
||||||
|
private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
|
||||||
|
private Authentication authentication = mock(Authentication.class);
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
|
||||||
|
this.filter.setCodeGenerator(this.codeGenerator);
|
||||||
|
|
||||||
|
SecurityContextHolder.getContext().setAuthentication(this.authentication);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
|
||||||
|
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
|
||||||
|
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||||
|
assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||||
|
assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||||
|
assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
|
||||||
|
assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
|
||||||
|
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
|
||||||
|
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
|
||||||
|
when(this.authentication.getPrincipal()).thenReturn("test-user");
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
|
||||||
|
verify(this.authorizationService).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||||
|
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
|
||||||
|
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
|
||||||
|
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
|
||||||
|
when(this.authentication.getPrincipal()).thenReturn("test-user");
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
|
||||||
|
verify(this.authorizationService).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||||
|
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
|
||||||
|
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
|
||||||
|
when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication, times(1)).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
|
||||||
|
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator, times(0)).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||||
|
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
|
||||||
|
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication, times(1)).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
|
||||||
|
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator, times(0)).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||||
|
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
|
||||||
|
when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication, times(1)).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
|
||||||
|
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator, times(0)).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
|
||||||
|
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
|
||||||
|
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator, times(0)).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
|
||||||
|
assertThat(response.getContentAsString()).isEmpty();
|
||||||
|
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
|
||||||
|
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
|
||||||
|
when(this.authentication.isAuthenticated()).thenReturn(true);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
|
||||||
|
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator, times(0)).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
|
||||||
|
assertThat(response.getContentAsString()).isEmpty();
|
||||||
|
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
when(authentication.isAuthenticated()).thenReturn(false);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(this.authentication).isAuthenticated();
|
||||||
|
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
|
||||||
|
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
|
||||||
|
verify(this.codeGenerator, times(0)).generateKey();
|
||||||
|
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
|
||||||
|
assertThat(response.getContentAsString()).isEmpty();
|
||||||
|
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setServletPath("/custom/authorize");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
|
||||||
|
spyFilter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
|
||||||
|
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
|
||||||
|
spyFilter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
|
||||||
|
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
|
||||||
|
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
|
||||||
|
MockHttpServletRequest request = getValidMockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
|
||||||
|
spyFilter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
|
||||||
|
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
|
||||||
|
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
private MockHttpServletRequest getValidMockHttpServletRequest() {
|
||||||
|
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
|
||||||
|
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
|
||||||
|
request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
|
||||||
|
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
|
||||||
|
request.setParameter(OAuth2ParameterNames.STATE, "teststate");
|
||||||
|
request.setServletPath("/oauth2/authorize");
|
||||||
|
|
||||||
|
return request;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user