Polish OAuth2AuthorizationEndpointFilterTests

Issue gh-77
This commit is contained in:
Joe Grandja 2020-06-27 20:49:31 -04:00
parent 2a3dfd953d
commit 02b64f0ef0

View File

@ -26,6 +26,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -41,6 +42,7 @@ import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -130,53 +132,29 @@ public class OAuth2AuthorizationEndpointFilterTests {
@Test @Test
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception { public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); doFilterWhenAuthorizationRequestInvalidParameterThenError(
TestRegisteredClients.registeredClient().build(),
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); OAuth2ParameterNames.CLIENT_ID,
request.removeParameter(OAuth2ParameterNames.CLIENT_ID); OAuth2ErrorCodes.INVALID_REQUEST,
MockHttpServletResponse response = new MockHttpServletResponse(); request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
} }
@Test @Test
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception { public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); doFilterWhenAuthorizationRequestInvalidParameterThenError(
TestRegisteredClients.registeredClient().build(),
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); OAuth2ParameterNames.CLIENT_ID,
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); OAuth2ErrorCodes.INVALID_REQUEST,
MockHttpServletResponse response = new MockHttpServletResponse(); request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
} }
@Test @Test
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception { public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); doFilterWhenAuthorizationRequestInvalidParameterThenError(
TestRegisteredClients.registeredClient().build(),
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); OAuth2ParameterNames.CLIENT_ID,
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"); OAuth2ErrorCodes.INVALID_REQUEST,
MockHttpServletResponse response = new MockHttpServletResponse(); request -> request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"));
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
} }
@Test @Test
@ -188,16 +166,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient); .thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); doFilterWhenAuthorizationRequestInvalidParameterThenError(
MockHttpServletResponse response = new MockHttpServletResponse(); registeredClient,
FilterChain filterChain = mock(FilterChain.class); OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.UNAUTHORIZED_CLIENT);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
} }
@Test @Test
@ -206,17 +178,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient); .thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); doFilterWhenAuthorizationRequestInvalidParameterThenError(
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"); registeredClient,
MockHttpServletResponse response = new MockHttpServletResponse(); OAuth2ParameterNames.REDIRECT_URI,
FilterChain filterChain = mock(FilterChain.class); OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"));
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
} }
@Test @Test
@ -225,17 +191,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient); .thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); doFilterWhenAuthorizationRequestInvalidParameterThenError(
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); registeredClient,
MockHttpServletResponse response = new MockHttpServletResponse(); OAuth2ParameterNames.REDIRECT_URI,
FilterChain filterChain = mock(FilterChain.class); OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
} }
@Test @Test
@ -244,17 +204,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient); .thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); doFilterWhenAuthorizationRequestInvalidParameterThenError(
request.removeParameter(OAuth2ParameterNames.REDIRECT_URI); registeredClient,
MockHttpServletResponse response = new MockHttpServletResponse(); OAuth2ParameterNames.REDIRECT_URI,
FilterChain filterChain = mock(FilterChain.class); OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.REDIRECT_URI));
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
} }
@Test @Test
@ -383,6 +337,27 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
} }
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
String parameterName, String errorCode) throws Exception {
doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {});
}
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
String parameterName, String errorCode, Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
requestConsumer.accept(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName);
}
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);