Add Authorization Endpoint filter

Fixes gh-66
This commit is contained in:
Paurav Munshi 2020-04-29 23:29:20 -04:00 committed by Joe Grandja
parent 26c3941a20
commit 54e219a397
4 changed files with 669 additions and 9 deletions

View File

@ -15,33 +15,231 @@
*/ */
package org.springframework.security.oauth2.server.authorization.web; package org.springframework.security.oauth2.server.authorization.web;
import org.springframework.core.convert.converter.Converter; import java.io.IOException;
import org.springframework.security.crypto.keygen.StringKeyGenerator; import java.util.stream.Stream;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder;
/** /**
* @author Joe Grandja * @author Joe Grandja
* @author Paurav Munshi
* @since 0.0.1
*/ */
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter;
private static final String DEFAULT_ENDPOINT = "/oauth2/authorize";
private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter = new OAuth2AuthorizationRequestConverter();
private RegisteredClientRepository registeredClientRepository; private RegisteredClientRepository registeredClientRepository;
private OAuth2AuthorizationService authorizationService; private OAuth2AuthorizationService authorizationService;
private StringKeyGenerator codeGenerator; private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
OAuth2AuthorizationService authorizationService) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null.");
Assert.notNull(authorizationService, "authorizationService cannot be null.");
this.registeredClientRepository = registeredClientRepository;
this.authorizationService = authorizationService;
}
public final void setAuthorizationRequestConverter(
Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter) {
Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
this.authorizationRequestConverter = authorizationRequestConverter;
}
public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
this.codeGenerator = codeGenerator;
}
public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
this.authorizationRedirectStrategy = authorizationRedirectStrategy;
}
public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
this.authorizationEndpointMatcher = authorizationEndpointMatcher;
}
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
if (pathMatch && responseTypeMatch) {
return false;
}else {
return true;
}
}
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain) HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
RegisteredClient client = null;
OAuth2AuthorizationRequest authorizationRequest = null;
OAuth2Authorization authorization = null;
try {
checkUserAuthenticated();
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
client = fetchRegisteredClient(request);
authorizationRequest = this.authorizationRequestConverter.convert(request);
validateAuthorizationRequest(authorizationRequest, client);
String code = this.codeGenerator.generateKey();
authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code);
this.authorizationService.save(authorization);
String redirectUri = getRedirectUri(authorizationRequest, client);
sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code);
}
catch(OAuth2AuthorizationException authorizationException) {
OAuth2Error authorizationError = authorizationException.getError();
if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
sendErrorInResponse(response, authorizationError);
}
else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
String redirectUri = getRedirectUri(authorizationRequest, client);
sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
}
else {
throw new ServletException(authorizationException);
}
}
} }
private void checkUserAuthenticated() {
Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
if (currentAuth==null || !currentAuth.isAuthenticated()) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
}
}
private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
if (StringUtils.isEmpty(clientId)) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}
RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
if (client==null) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
}
boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
.anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
if (!isAuthorizationGrantAllowed) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
}
return client;
}
private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
OAuth2AuthorizationRequest authorizationRequest, String code) {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
.principalName(auth.getPrincipal().toString())
.attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
.attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
.build();
return authorization;
}
private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
String redirectUri = authorizationRequest.getRedirectUri();
if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}
if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}
}
private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
? authorizationRequest.getRedirectUri()
: client.getRedirectUris().stream().findFirst().get();
}
private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
.queryParam(OAuth2ParameterNames.CODE, code);
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
}
String finalRedirectUri = redirectUriBuilder.toUriString();
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri);
}
private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException {
int errorStatus = -1;
String errorCode = authorizationError.getErrorCode();
if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
errorStatus=HttpStatus.FORBIDDEN.value();
}
else {
errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value();
}
response.sendError(errorStatus, authorizationError.getErrorCode());
}
private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response,
OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError,
String redirectUri) throws IOException {
UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
.queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
if (!StringUtils.isEmpty(authorizationRequest.getState())) {
redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
}
String finalRedirectURI = redirectUriBuilder.toUriString();
this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI);
}
} }

View File

@ -0,0 +1,55 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.http.HttpServletRequest;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.StringUtils;
/**
* @author Paurav Munshi
* @since 0.0.1
* @see Converter
*/
public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> {
@Override
public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
Set<String> scopes = !StringUtils.isEmpty(scope)
? new LinkedHashSet<String>(Arrays.asList(scope.split(" ")))
: Collections.emptySet();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
.scopes(scopes)
.state(request.getParameter(OAuth2ParameterNames.STATE))
.authorizationUri(request.getServletPath())
.build();
return authorizationRequest;
}
}

View File

@ -46,4 +46,40 @@ public class TestRegisteredClients {
.scope("profile") .scope("profile")
.scope("email"); .scope("email");
} }
public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
return RegisteredClient.withId("valid_client_id")
.clientId("valid_client")
.clientSecret("valid_secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.redirectUri("http://localhost:8080/test-application/callback")
.scope("openid")
.scope("profile")
.scope("email");
}
public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
return RegisteredClient.withId("valid_client_multi_uri_id")
.clientId("valid_client_multi_uri")
.clientSecret("valid_secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.redirectUri("http://localhost:8080/test-application/callback")
.redirectUri("http://localhost:8080/another-test-application/callback")
.scope("openid")
.scope("profile")
.scope("email");
}
public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
return RegisteredClient.withId("valid_cc_client_id")
.clientId("valid_cc_client")
.clientSecret("valid_secret")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.scope("openid")
.scope("profile")
.scope("email");
}
} }

View File

@ -0,0 +1,371 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
/**
* Tests for {@link OAuth2AuthorizationEndpointFilter}.
*
* @author Paurav Munshi
* @since 0.0.1
*/
public class OAuth2AuthorizationEndpointFilterTest {
private static final String VALID_CLIENT = "valid_client";
private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
private static final String VALID_CC_CLIENT = "valid_cc_client";
private OAuth2AuthorizationEndpointFilter filter;
private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
private Authentication authentication = mock(Authentication.class);
@Before
public void setUp() {
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
this.filter.setCodeGenerator(this.codeGenerator);
SecurityContextHolder.getContext().setAuthentication(this.authentication);
}
@Test
public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
when(this.authentication.getPrincipal()).thenReturn("test-user");
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
verify(this.authorizationService).save(any(OAuth2Authorization.class));
verify(this.codeGenerator).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
}
@Test
public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
when(this.authentication.getPrincipal()).thenReturn("test-user");
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
verify(this.authorizationService).save(any(OAuth2Authorization.class));
verify(this.codeGenerator).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
}
@Test
public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication, times(1)).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication, times(1)).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication, times(1)).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
}
@Test
public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getContentAsString()).isEmpty();
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
when(this.codeGenerator.generateKey()).thenReturn("sample_code");
when(this.authentication.isAuthenticated()).thenReturn(true);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
assertThat(response.getContentAsString()).isEmpty();
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
}
@Test
public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
when(authentication.isAuthenticated()).thenReturn(false);
this.filter.doFilter(request, response, filterChain);
verify(this.authentication).isAuthenticated();
verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
verify(this.codeGenerator, times(0)).generateKey();
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
assertThat(response.getContentAsString()).isEmpty();
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
}
@Test
public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setServletPath("/custom/authorize");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
spyFilter.doFilter(request, response, filterChain);
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
}
@Test
public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
spyFilter.doFilter(request, response, filterChain);
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
MockHttpServletRequest request = getValidMockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
spyFilter.doFilter(request, response, filterChain);
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
private MockHttpServletRequest getValidMockHttpServletRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
request.setParameter(OAuth2ParameterNames.STATE, "teststate");
request.setServletPath("/oauth2/authorize");
return request;
}
}