diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 874a056..3f4a2a6 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -26,6 +26,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -41,6 +42,7 @@ import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.util.Set; +import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -130,53 +132,29 @@ public class OAuth2AuthorizationEndpointFilterTests { @Test public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.removeParameter(OAuth2ParameterNames.CLIENT_ID); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID)); } @Test public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2")); } @Test public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid")); } @Test @@ -188,16 +166,10 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.UNAUTHORIZED_CLIENT); } @Test @@ -206,17 +178,11 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.REDIRECT_URI, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com")); } @Test @@ -225,17 +191,11 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.REDIRECT_URI, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com")); } @Test @@ -244,17 +204,11 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) .thenReturn(registeredClient); - MockHttpServletRequest request = createAuthorizationRequest(registeredClient); - request.removeParameter(OAuth2ParameterNames.REDIRECT_URI); - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verifyNoInteractions(filterChain); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); + doFilterWhenAuthorizationRequestInvalidParameterThenError( + registeredClient, + OAuth2ParameterNames.REDIRECT_URI, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.REDIRECT_URI)); } @Test @@ -383,6 +337,27 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); } + private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, + String parameterName, String errorCode) throws Exception { + doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {}); + } + + private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, + String parameterName, String errorCode, Consumer requestConsumer) throws Exception { + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + requestConsumer.accept(request); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName); + } + private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);