diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index c8a143c..270e92d 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -15,33 +15,231 @@ */ package org.springframework.security.oauth2.server.authorization.web; -import org.springframework.core.convert.converter.Converter; -import org.springframework.security.crypto.keygen.StringKeyGenerator; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.web.filter.OncePerRequestFilter; +import java.io.IOException; +import java.util.stream.Stream; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.IOException; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.UriComponentsBuilder; /** * @author Joe Grandja + * @author Paurav Munshi + * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { - private Converter authorizationRequestConverter; + + private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; + + private Converter authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; - private StringKeyGenerator codeGenerator; + private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(); + private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); + private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT); + + public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null."); + Assert.notNull(authorizationService, "authorizationService cannot be null."); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + } + + public final void setAuthorizationRequestConverter( + Converter authorizationRequestConverter) { + Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null"); + this.authorizationRequestConverter = authorizationRequestConverter; + } + + public final void setCodeGenerator(StringKeyGenerator codeGenerator) { + Assert.notNull(codeGenerator, "codeGenerator cannot be set to null"); + this.codeGenerator = codeGenerator; + } + + public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) { + Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null"); + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + } + + public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) { + Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null"); + this.authorizationEndpointMatcher = authorizationEndpointMatcher; + } + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + boolean pathMatch = this.authorizationEndpointMatcher.matches(request); + String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); + boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType); + if (pathMatch && responseTypeMatch) { + return false; + }else { + return true; + } + } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + RegisteredClient client = null; + OAuth2AuthorizationRequest authorizationRequest = null; + OAuth2Authorization authorization = null; + + try { + checkUserAuthenticated(); + Authentication auth = SecurityContextHolder.getContext().getAuthentication(); + client = fetchRegisteredClient(request); + + authorizationRequest = this.authorizationRequestConverter.convert(request); + validateAuthorizationRequest(authorizationRequest, client); + + String code = this.codeGenerator.generateKey(); + authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code); + this.authorizationService.save(authorization); + + String redirectUri = getRedirectUri(authorizationRequest, client); + sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code); + } + catch(OAuth2AuthorizationException authorizationException) { + OAuth2Error authorizationError = authorizationException.getError(); + + if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) { + sendErrorInResponse(response, authorizationError); + } + else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) { + String redirectUri = getRedirectUri(authorizationRequest, client); + sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri); + } + else { + throw new ServletException(authorizationException); + } + } + } + private void checkUserAuthenticated() { + Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication(); + if (currentAuth==null || !currentAuth.isAuthenticated()) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); + } + } + + private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { + String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); + if (StringUtils.isEmpty(clientId)) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } + + RegisteredClient client = this.registeredClientRepository.findByClientId(clientId); + if (client==null) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); + } + + boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) + .anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE)); + if (!isAuthorizationGrantAllowed) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); + } + + return client; + + } + + private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client, + OAuth2AuthorizationRequest authorizationRequest, String code) { + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client) + .principalName(auth.getPrincipal().toString()) + .attribute(TokenType.AUTHORIZATION_CODE.getValue(), code) + .attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes())) + .build(); + + return authorization; + } + + + private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { + String redirectUri = authorizationRequest.getRedirectUri(); + if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } + if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } + } + + private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { + return !StringUtils.isEmpty(authorizationRequest.getRedirectUri()) + ? authorizationRequest.getRedirectUri() + : client.getRedirectUris().stream().findFirst().get(); + } + + private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response, + OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException { + UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) + .queryParam(OAuth2ParameterNames.CODE, code); + if (!StringUtils.isEmpty(authorizationRequest.getState())) { + redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + } + + String finalRedirectUri = redirectUriBuilder.toUriString(); + this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri); + } + + private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { + int errorStatus = -1; + String errorCode = authorizationError.getErrorCode(); + if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) { + errorStatus=HttpStatus.FORBIDDEN.value(); + } + else { + errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); + } + response.sendError(errorStatus, authorizationError.getErrorCode()); + } + + private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, + OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError, + String redirectUri) throws IOException { + UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) + .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode()); + + if (!StringUtils.isEmpty(authorizationRequest.getState())) { + redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + } + + String finalRedirectURI = redirectUriBuilder.toUriString(); + this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); + } } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java new file mode 100644 index 0000000..619a742 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.StringUtils; + +/** + * @author Paurav Munshi + * @since 0.0.1 + * @see Converter + */ +public class OAuth2AuthorizationRequestConverter implements Converter { + + @Override + public OAuth2AuthorizationRequest convert(HttpServletRequest request) { + String scope = request.getParameter(OAuth2ParameterNames.SCOPE); + Set scopes = !StringUtils.isEmpty(scope) + ? new LinkedHashSet(Arrays.asList(scope.split(" "))) + : Collections.emptySet(); + + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID)) + .redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)) + .scopes(scopes) + .state(request.getParameter(OAuth2ParameterNames.STATE)) + .authorizationUri(request.getServletPath()) + .build(); + + return authorizationRequest; + } + +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index 5aa7bc0..502fa48 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -46,4 +46,40 @@ public class TestRegisteredClients { .scope("profile") .scope("email"); } + + public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() { + return RegisteredClient.withId("valid_client_id") + .clientId("valid_client") + .clientSecret("valid_secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("http://localhost:8080/test-application/callback") + .scope("openid") + .scope("profile") + .scope("email"); + } + + public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() { + return RegisteredClient.withId("valid_client_multi_uri_id") + .clientId("valid_client_multi_uri") + .clientSecret("valid_secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("http://localhost:8080/test-application/callback") + .redirectUri("http://localhost:8080/another-test-application/callback") + .scope("openid") + .scope("profile") + .scope("email"); + } + + public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() { + return RegisteredClient.withId("valid_cc_client_id") + .clientId("valid_cc_client") + .clientSecret("valid_secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .scope("openid") + .scope("profile") + .scope("email"); + } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java new file mode 100644 index 0000000..4b2e86a --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -0,0 +1,371 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + + +/** + * Tests for {@link OAuth2AuthorizationEndpointFilter}. + * + * @author Paurav Munshi + * @since 0.0.1 + */ + +public class OAuth2AuthorizationEndpointFilterTest { + + private static final String VALID_CLIENT = "valid_client"; + private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri"; + private static final String VALID_CC_CLIENT = "valid_cc_client"; + + private OAuth2AuthorizationEndpointFilter filter; + + private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); + private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); + private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); + private Authentication authentication = mock(Authentication.class); + + @Before + public void setUp() { + this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService); + this.filter.setCodeGenerator(this.codeGenerator); + + SecurityContextHolder.getContext().setAuthentication(this.authentication); + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setCodeGenerator(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.getPrincipal()).thenReturn("test-user"); + when(this.authentication.isAuthenticated()).thenReturn(true); + + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); + + } + + @Test + public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.getPrincipal()).thenReturn("test-user"); + when(this.authentication.isAuthenticated()).thenReturn(true); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); + + } + + @Test + public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build(); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); + when(this.authentication.isAuthenticated()).thenReturn(true); + + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication, times(1)).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + + } + + @Test + public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.authentication.isAuthenticated()).thenReturn(true); + + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication, times(1)).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + + } + + @Test + public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build(); + when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); + when(this.authentication.isAuthenticated()).thenReturn(true); + + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication, times(1)).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); + + } + + @Test + public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + when(this.authentication.isAuthenticated()).thenReturn(true); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertThat(response.getContentAsString()).isEmpty(); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + + } + + @Test + public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.isAuthenticated()).thenReturn(true); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client"); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); + assertThat(response.getContentAsString()).isEmpty(); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); + + } + + @Test + public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + when(authentication.isAuthenticated()).thenReturn(false); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); + assertThat(response.getContentAsString()).isEmpty(); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); + + } + + @Test + public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setServletPath("/custom/authorize"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); + spyFilter.doFilter(request, response, filterChain); + + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); + verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + } + + @Test + public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); + spyFilter.doFilter(request, response, filterChain); + + verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); + verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); + spyFilter.doFilter(request, response, filterChain); + + verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); + verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + private MockHttpServletRequest getValidMockHttpServletRequest() { + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code"); + request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email"); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback"); + request.setParameter(OAuth2ParameterNames.STATE, "teststate"); + request.setServletPath("/oauth2/authorize"); + + return request; + + + } + +}