diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java
index 8a13b07..b87c590 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java
@@ -15,6 +15,7 @@
*/
package org.springframework.security.oauth2.server.authorization;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import java.util.List;
@@ -65,7 +66,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) {
if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
- return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue()));
+ return token.equals(authorization.getAttributes().get(OAuth2ParameterNames.class.getName().concat(".CODE")));
} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
return authorization.getAccessToken() != null &&
authorization.getAccessToken().getTokenValue().equals(token);
diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
index 0b5713d..2391c23 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
@@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
@@ -196,7 +197,7 @@ public class OAuth2Authorization implements Serializable {
*/
public OAuth2Authorization build() {
Assert.hasText(this.principalName, "principalName cannot be empty");
- Assert.notNull(this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()), "authorization code cannot be null");
+ Assert.notNull(this.attributes.get(OAuth2ParameterNames.class.getName().concat(".CODE")), "authorization code cannot be null");
OAuth2Authorization authorization = new OAuth2Authorization();
authorization.registeredClientId = this.registeredClientId;
diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
index 270e92d..aac79d1 100644
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
+++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
@@ -15,29 +15,21 @@
*/
package org.springframework.security.oauth2.server.authorization.web;
-import java.io.IOException;
-import java.util.stream.Stream;
-
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
-import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.DefaultRedirectStrategy;
@@ -45,201 +37,257 @@ import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
/**
+ * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
+ * which handles the processing of the OAuth 2.0 Authorization Request.
+ *
* @author Joe Grandja
* @author Paurav Munshi
* @since 0.0.1
+ * @see RegisteredClientRepository
+ * @see OAuth2AuthorizationService
+ * @see OAuth2Authorization
+ * @see Section 4.1 Authorization Code Grant
+ * @see Section 4.1.1 Authorization Request
*/
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
+ /**
+ * The default endpoint {@code URI} for authorization requests.
+ */
+ public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
- private static final String DEFAULT_ENDPOINT = "/oauth2/authorize";
-
- private Converter authorizationRequestConverter = new OAuth2AuthorizationRequestConverter();
- private RegisteredClientRepository registeredClientRepository;
- private OAuth2AuthorizationService authorizationService;
- private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
- private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
- private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
+ private final RegisteredClientRepository registeredClientRepository;
+ private final OAuth2AuthorizationService authorizationService;
+ private final RequestMatcher authorizationEndpointMatcher;
+ private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+ private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
+ /**
+ * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
+ *
+ * @param registeredClientRepository the repository of registered clients
+ * @param authorizationService the authorization service
+ */
public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
OAuth2AuthorizationService authorizationService) {
- Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null.");
- Assert.notNull(authorizationService, "authorizationService cannot be null.");
+ this(registeredClientRepository, authorizationService, DEFAULT_AUTHORIZATION_ENDPOINT_URI);
+ }
+
+ /**
+ * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
+ *
+ * @param registeredClientRepository the repository of registered clients
+ * @param authorizationService the authorization service
+ * @param authorizationEndpointUri the endpoint {@code URI} for authorization requests
+ */
+ public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
+ OAuth2AuthorizationService authorizationService, String authorizationEndpointUri) {
+ Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+ Assert.notNull(authorizationService, "authorizationService cannot be null");
+ Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty");
this.registeredClientRepository = registeredClientRepository;
this.authorizationService = authorizationService;
- }
-
- public final void setAuthorizationRequestConverter(
- Converter authorizationRequestConverter) {
- Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
- this.authorizationRequestConverter = authorizationRequestConverter;
- }
-
- public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
- Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
- this.codeGenerator = codeGenerator;
- }
-
- public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
- Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
- this.authorizationRedirectStrategy = authorizationRedirectStrategy;
- }
-
- public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
- Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
- this.authorizationEndpointMatcher = authorizationEndpointMatcher;
+ this.authorizationEndpointMatcher = new AntPathRequestMatcher(
+ authorizationEndpointUri, HttpMethod.GET.name());
}
@Override
- protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
- boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
- String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
- boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
- if (pathMatch && responseTypeMatch) {
- return false;
- }else {
- return true;
- }
- }
-
- @Override
- protected void doFilterInternal(HttpServletRequest request,
- HttpServletResponse response, FilterChain filterChain)
+ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
- RegisteredClient client = null;
- OAuth2AuthorizationRequest authorizationRequest = null;
- OAuth2Authorization authorization = null;
-
- try {
- checkUserAuthenticated();
- Authentication auth = SecurityContextHolder.getContext().getAuthentication();
- client = fetchRegisteredClient(request);
-
- authorizationRequest = this.authorizationRequestConverter.convert(request);
- validateAuthorizationRequest(authorizationRequest, client);
-
- String code = this.codeGenerator.generateKey();
- authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code);
- this.authorizationService.save(authorization);
-
- String redirectUri = getRedirectUri(authorizationRequest, client);
- sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code);
+ if (!this.authorizationEndpointMatcher.matches(request) || !isPrincipalAuthenticated()) {
+ filterChain.doFilter(request, response);
+ return;
}
- catch(OAuth2AuthorizationException authorizationException) {
- OAuth2Error authorizationError = authorizationException.getError();
- if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
- || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
- sendErrorInResponse(response, authorizationError);
- }
- else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
- || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
- String redirectUri = getRedirectUri(authorizationRequest, client);
- sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
- }
- else {
- throw new ServletException(authorizationException);
+// TODO
+// The authorization server validates the request to ensure that all
+// required parameters are present and valid. If the request is valid,
+// the authorization server authenticates the resource owner and obtains
+// an authorization decision (by asking the resource owner or by
+// establishing approval via other means).
+
+ MultiValueMap parameters = getParameters(request);
+ String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
+
+ // client_id (REQUIRED)
+ String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
+ if (!StringUtils.hasText(clientId) ||
+ parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
+ sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
+ return;
+ }
+ RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
+ if (registeredClient == null) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
+ sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
+ return;
+ } else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID);
+ sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
+ return;
+ }
+
+ // redirect_uri (OPTIONAL)
+ String redirectUriParameter = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
+ if (StringUtils.hasText(redirectUriParameter)) {
+ if (!registeredClient.getRedirectUris().contains(redirectUriParameter) ||
+ parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
+ sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
+ return;
}
+ } else if (registeredClient.getRedirectUris().size() != 1) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
+ sendErrorResponse(request, response, error, stateParameter, null); // when redirectUri is null then don't redirect
+ return;
}
+ String redirectUri = StringUtils.hasText(redirectUriParameter) ?
+ redirectUriParameter : registeredClient.getRedirectUris().iterator().next();
+
+ // response_type (REQUIRED)
+ String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
+ if (!StringUtils.hasText(responseType) ||
+ parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);
+ sendErrorResponse(request, response, error, stateParameter, redirectUri);
+ return;
+ } else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) {
+ OAuth2Error error = createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE);
+ sendErrorResponse(request, response, error, stateParameter, redirectUri);
+ return;
+ }
+
+ Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+ String code = this.codeGenerator.generateKey();
+ OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request);
+
+ OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
+ .principalName(principal.getName())
+ .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), code)
+ .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
+ .build();
+
+ this.authorizationService.save(authorization);
+
+// TODO security checks for code parameter
+// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
+// A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
+// The client MUST NOT use the authorization code more than once.
+// If an authorization code is used more than once, the authorization server MUST deny the request
+// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
+// The authorization code is bound to the client identifier and redirection URI.
+
+ sendAuthorizationResponse(request, response, authorizationRequest, code, redirectUri);
}
- private void checkUserAuthenticated() {
- Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
- if (currentAuth==null || !currentAuth.isAuthenticated()) {
- throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
- }
- }
+ private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
+ OAuth2AuthorizationRequest authorizationRequest, String code, String redirectUri) throws IOException {
- private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
- String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
- if (StringUtils.isEmpty(clientId)) {
- throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
- }
-
- RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
- if (client==null) {
- throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
- }
-
- boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
- .anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
- if (!isAuthorizationGrantAllowed) {
- throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
- }
-
- return client;
-
- }
-
- private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
- OAuth2AuthorizationRequest authorizationRequest, String code) {
- OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
- .principalName(auth.getPrincipal().toString())
- .attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
- .attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
- .build();
-
- return authorization;
- }
-
-
- private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
- String redirectUri = authorizationRequest.getRedirectUri();
- if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
- throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
- }
- if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
- throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
- }
- }
-
- private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
- return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
- ? authorizationRequest.getRedirectUri()
- : client.getRedirectUris().stream().findFirst().get();
- }
-
- private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
- OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
- UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
+ UriComponentsBuilder uriBuilder = UriComponentsBuilder
+ .fromUriString(redirectUri)
.queryParam(OAuth2ParameterNames.CODE, code);
- if (!StringUtils.isEmpty(authorizationRequest.getState())) {
- redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
+ if (StringUtils.hasText(authorizationRequest.getState())) {
+ uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
}
-
- String finalRedirectUri = redirectUriBuilder.toUriString();
- this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri);
+ this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
}
- private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException {
- int errorStatus = -1;
- String errorCode = authorizationError.getErrorCode();
- if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
- errorStatus=HttpStatus.FORBIDDEN.value();
+ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
+ OAuth2Error error, String state, String redirectUri) throws IOException {
+
+ if (redirectUri == null) {
+ // TODO Send default html error response
+ response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString());
+ return;
}
- else {
- errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value();
+
+ UriComponentsBuilder uriBuilder = UriComponentsBuilder
+ .fromUriString(redirectUri)
+ .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
+ if (StringUtils.hasText(error.getDescription())) {
+ uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
}
- response.sendError(errorStatus, authorizationError.getErrorCode());
+ if (StringUtils.hasText(error.getUri())) {
+ uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
+ }
+ if (StringUtils.hasText(state)) {
+ uriBuilder.queryParam(OAuth2ParameterNames.STATE, state);
+ }
+ this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
}
- private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response,
- OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError,
- String redirectUri) throws IOException {
- UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
- .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
+ private static boolean isPrincipalAuthenticated() {
+ return isPrincipalAuthenticated(SecurityContextHolder.getContext().getAuthentication());
+ }
- if (!StringUtils.isEmpty(authorizationRequest.getState())) {
- redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
+ private static boolean isPrincipalAuthenticated(Authentication principal) {
+ return principal != null &&
+ !AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&
+ principal.isAuthenticated();
+ }
+
+ private static OAuth2Error createError(String errorCode, String parameterName) {
+ return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
+ "https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
+ }
+
+ private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
+ MultiValueMap parameters = getParameters(request);
+
+ Set scopes = Collections.emptySet();
+ if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
+ String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
+ scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
}
- String finalRedirectURI = redirectUriBuilder.toUriString();
- this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI);
+ return OAuth2AuthorizationRequest.authorizationCode()
+ .authorizationUri(request.getRequestURL().toString())
+ .clientId(parameters.getFirst(OAuth2ParameterNames.CLIENT_ID))
+ .redirectUri(parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI))
+ .scopes(scopes)
+ .state(parameters.getFirst(OAuth2ParameterNames.STATE))
+ .additionalParameters(additionalParameters ->
+ parameters.entrySet().stream()
+ .filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) &&
+ !e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
+ !e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) &&
+ !e.getKey().equals(OAuth2ParameterNames.SCOPE) &&
+ !e.getKey().equals(OAuth2ParameterNames.STATE))
+ .forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
+ .build();
+ }
+
+ private static MultiValueMap getParameters(HttpServletRequest request) {
+ Map parameterMap = request.getParameterMap();
+ MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size());
+ parameterMap.forEach((key, values) -> {
+ if (values.length > 0) {
+ for (String value : values) {
+ parameters.add(key, value);
+ }
+ }
+ });
+ return parameters;
}
}
diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java
deleted file mode 100644
index 619a742..0000000
--- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.web;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.LinkedHashSet;
-import java.util.Set;
-
-import javax.servlet.http.HttpServletRequest;
-
-import org.springframework.core.convert.converter.Converter;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.StringUtils;
-
-/**
- * @author Paurav Munshi
- * @since 0.0.1
- * @see Converter
- */
-public class OAuth2AuthorizationRequestConverter implements Converter {
-
- @Override
- public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
- String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
- Set scopes = !StringUtils.isEmpty(scope)
- ? new LinkedHashSet(Arrays.asList(scope.split(" ")))
- : Collections.emptySet();
-
- OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
- .clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
- .redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
- .scopes(scopes)
- .state(request.getParameter(OAuth2ParameterNames.STATE))
- .authorizationUri(request.getServletPath())
- .build();
-
- return authorizationRequest;
- }
-
-}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java
index d1fb762..f4bc5bd 100644
--- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@@ -61,7 +62,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void saveWhenAuthorizationProvidedThenSaved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
- .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+ .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.build();
this.authorizationService.save(expectedAuthorization);
@@ -88,7 +89,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
- .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+ .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.build();
this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
@@ -103,7 +104,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
"access-token", Instant.now().minusSeconds(60), Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
- .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+ .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.accessToken(accessToken)
.build();
this.authorizationService.save(authorization);
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
index a35efcc..dfd3d90 100644
--- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization;
import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@@ -84,13 +85,13 @@ public class OAuth2AuthorizationTests {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME)
.accessToken(ACCESS_TOKEN)
- .attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+ .attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
.build();
assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
assertThat(authorization.getAttributes()).containsExactly(
- entry(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE));
+ entry(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE));
}
}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java
index 502fa48..5aa7bc0 100644
--- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java
@@ -46,40 +46,4 @@ public class TestRegisteredClients {
.scope("profile")
.scope("email");
}
-
- public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
- return RegisteredClient.withId("valid_client_id")
- .clientId("valid_client")
- .clientSecret("valid_secret")
- .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
- .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
- .redirectUri("http://localhost:8080/test-application/callback")
- .scope("openid")
- .scope("profile")
- .scope("email");
- }
-
- public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
- return RegisteredClient.withId("valid_client_multi_uri_id")
- .clientId("valid_client_multi_uri")
- .clientSecret("valid_secret")
- .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
- .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
- .redirectUri("http://localhost:8080/test-application/callback")
- .redirectUri("http://localhost:8080/another-test-application/callback")
- .scope("openid")
- .scope("profile")
- .scope("email");
- }
-
- public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
- return RegisteredClient.withId("valid_cc_client_id")
- .clientId("valid_cc_client")
- .clientSecret("valid_secret")
- .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
- .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
- .scope("openid")
- .scope("profile")
- .scope("email");
- }
}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java
deleted file mode 100644
index 4b2e86a..0000000
--- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java
+++ /dev/null
@@ -1,371 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.web;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import javax.servlet.FilterChain;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.springframework.http.HttpStatus;
-import org.springframework.mock.web.MockHttpServletRequest;
-import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.crypto.keygen.StringKeyGenerator;
-import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
-import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
-
-
-/**
- * Tests for {@link OAuth2AuthorizationEndpointFilter}.
- *
- * @author Paurav Munshi
- * @since 0.0.1
- */
-
-public class OAuth2AuthorizationEndpointFilterTest {
-
- private static final String VALID_CLIENT = "valid_client";
- private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
- private static final String VALID_CC_CLIENT = "valid_cc_client";
-
- private OAuth2AuthorizationEndpointFilter filter;
-
- private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
- private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
- private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
- private Authentication authentication = mock(Authentication.class);
-
- @Before
- public void setUp() {
- this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
- this.filter.setCodeGenerator(this.codeGenerator);
-
- SecurityContextHolder.getContext().setAuthentication(this.authentication);
- }
-
- @Test
- public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
- assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
- .isInstanceOf(IllegalArgumentException.class);
- }
-
- @Test
- public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
- assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
- .isInstanceOf(IllegalArgumentException.class);
- }
-
- @Test
- public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
- assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
- .isInstanceOf(IllegalArgumentException.class);
- }
-
- @Test
- public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
- assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
- .isInstanceOf(IllegalArgumentException.class);
- }
-
- @Test
- public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
- assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
- .isInstanceOf(IllegalArgumentException.class);
- }
-
- @Test
- public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
- assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
- .isInstanceOf(IllegalArgumentException.class);
- }
-
- @Test
- public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
- when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
- when(this.codeGenerator.generateKey()).thenReturn("sample_code");
- when(this.authentication.getPrincipal()).thenReturn("test-user");
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication).isAuthenticated();
- verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
- verify(this.authorizationService).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
- assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
-
- }
-
- @Test
- public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
- when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
- when(this.codeGenerator.generateKey()).thenReturn("sample_code");
- when(this.authentication.getPrincipal()).thenReturn("test-user");
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication).isAuthenticated();
- verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
- verify(this.authorizationService).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
- assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
-
- }
-
- @Test
- public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
- request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
- when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication, times(1)).isAuthenticated();
- verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
- verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator, times(0)).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
- assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-
- }
-
- @Test
- public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
- when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication, times(1)).isAuthenticated();
- verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
- verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator, times(0)).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
- assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-
- }
-
- @Test
- public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
- when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication, times(1)).isAuthenticated();
- verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
- verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator, times(0)).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
- assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
-
- }
-
- @Test
- public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication).isAuthenticated();
- verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
- verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator, times(0)).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
- assertThat(response.getContentAsString()).isEmpty();
- assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-
- }
-
- @Test
- public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
- when(this.codeGenerator.generateKey()).thenReturn("sample_code");
- when(this.authentication.isAuthenticated()).thenReturn(true);
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication).isAuthenticated();
- verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
- verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator, times(0)).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
- assertThat(response.getContentAsString()).isEmpty();
- assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
-
- }
-
- @Test
- public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- when(authentication.isAuthenticated()).thenReturn(false);
-
- this.filter.doFilter(request, response, filterChain);
-
- verify(this.authentication).isAuthenticated();
- verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
- verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
- verify(this.codeGenerator, times(0)).generateKey();
- verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
- assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
- assertThat(response.getContentAsString()).isEmpty();
- assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
-
- }
-
- @Test
- public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setServletPath("/custom/authorize");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
- spyFilter.doFilter(request, response, filterChain);
-
- verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
- verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
- verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
- }
-
- @Test
- public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
- spyFilter.doFilter(request, response, filterChain);
-
- verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
- verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
- verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
- }
-
- @Test
- public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
- MockHttpServletRequest request = getValidMockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
- MockHttpServletResponse response = new MockHttpServletResponse();
- FilterChain filterChain = mock(FilterChain.class);
-
- OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
- spyFilter.doFilter(request, response, filterChain);
-
- verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
- verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
- verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
- }
-
- private MockHttpServletRequest getValidMockHttpServletRequest() {
-
- MockHttpServletRequest request = new MockHttpServletRequest();
- request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
- request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
- request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
- request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
- request.setParameter(OAuth2ParameterNames.STATE, "teststate");
- request.setServletPath("/oauth2/authorize");
-
- return request;
-
-
- }
-
-}
diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java
new file mode 100644
index 0000000..af95436
--- /dev/null
+++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java
@@ -0,0 +1,399 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.util.StringUtils;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OAuth2AuthorizationEndpointFilter}.
+ *
+ * @author Paurav Munshi
+ * @author Joe Grandja
+ * @since 0.0.1
+ */
+public class OAuth2AuthorizationEndpointFilterTests {
+ private RegisteredClientRepository registeredClientRepository;
+ private OAuth2AuthorizationService authorizationService;
+ private OAuth2AuthorizationEndpointFilter filter;
+ private TestingAuthenticationToken authentication;
+
+ @Before
+ public void setUp() {
+ this.registeredClientRepository = mock(RegisteredClientRepository.class);
+ this.authorizationService = mock(OAuth2AuthorizationService.class);
+ this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
+ this.authentication = new TestingAuthenticationToken("principalName", "password");
+ this.authentication.setAuthenticated(true);
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+ securityContext.setAuthentication(this.authentication);
+ SecurityContextHolder.setContext(securityContext);
+ }
+
+ @After
+ public void cleanup() {
+ SecurityContextHolder.clearContext();
+ }
+
+ @Test
+ public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("registeredClientRepository cannot be null");
+ }
+
+ @Test
+ public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorizationService cannot be null");
+ }
+
+ @Test
+ public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("authorizationEndpointUri cannot be empty");
+ }
+
+ @Test
+ public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception {
+ String requestUri = "/path";
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setServletPath(requestUri);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception {
+ String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
+ MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
+ request.setServletPath(requestUri);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception {
+ String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setServletPath(requestUri);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.authentication.setAuthenticated(false);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeThenUnauthorizedClientError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+ .authorizationGrantTypes(Set::clear)
+ .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+ .build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegisteredThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUri("https://example2.com").build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.removeParameter(OAuth2ParameterNames.REDIRECT_URI);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestMissingResponseTypeThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+ assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+ "error=invalid_request&" +
+ "error_description=OAuth%202.0%20Parameter:%20response_type&" +
+ "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
+ "state=state");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestMultipleResponseTypeThenInvalidRequestError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+ assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+ "error=invalid_request&" +
+ "error_description=OAuth%202.0%20Parameter:%20response_type&" +
+ "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
+ "state=state");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestInvalidResponseTypeThenUnsupportedResponseTypeError() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+ assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+ "error=unsupported_response_type&" +
+ "error_description=OAuth%202.0%20Parameter:%20response_type&" +
+ "error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
+ "state=state");
+ }
+
+ @Test
+ public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+ .thenReturn(registeredClient);
+
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ FilterChain filterChain = mock(FilterChain.class);
+
+ this.filter.doFilter(request, response, filterChain);
+
+ verifyNoInteractions(filterChain);
+
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+ assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
+
+ ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+
+ verify(this.authorizationService).save(authorizationCaptor.capture());
+
+ OAuth2Authorization authorization = authorizationCaptor.getValue();
+ assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
+ assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+
+ String code = authorization.getAttribute(OAuth2ParameterNames.class.getName().concat(".CODE"));
+ assertThat(code).isNotNull();
+
+ OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+ assertThat(authorizationRequest).isNotNull();
+ assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize");
+ assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+ assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
+ assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next());
+ assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes());
+ assertThat(authorizationRequest.getState()).isEqualTo("state");
+ assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
+ }
+
+ private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
+ String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
+
+ String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setServletPath(requestUri);
+
+ request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
+ request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+ request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
+ request.addParameter(OAuth2ParameterNames.SCOPE,
+ StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+ return request;
+ }
+}