Polish OAuth2AuthorizationEndpointFilterTests

Issue gh-77
This commit is contained in:
Joe Grandja 2020-06-27 20:49:31 -04:00
parent 2a3dfd953d
commit 02b64f0ef0

View File

@ -26,6 +26,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -41,6 +42,7 @@ import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Set;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -130,53 +132,29 @@ public class OAuth2AuthorizationEndpointFilterTests {
@Test
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
}
@Test
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
}
@Test
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"));
}
@Test
@ -188,16 +166,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
registeredClient,
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.UNAUTHORIZED_CLIENT);
}
@Test
@ -206,17 +178,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
registeredClient,
OAuth2ParameterNames.REDIRECT_URI,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"));
}
@Test
@ -225,17 +191,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
registeredClient,
OAuth2ParameterNames.REDIRECT_URI,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
}
@Test
@ -244,17 +204,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
.thenReturn(registeredClient);
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.removeParameter(OAuth2ParameterNames.REDIRECT_URI);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
doFilterWhenAuthorizationRequestInvalidParameterThenError(
registeredClient,
OAuth2ParameterNames.REDIRECT_URI,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.REDIRECT_URI));
}
@Test
@ -383,6 +337,27 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
}
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
String parameterName, String errorCode) throws Exception {
doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {});
}
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
String parameterName, String errorCode, Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
requestConsumer.accept(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName);
}
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);