diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index d20c5d5..eada2ac 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.security.Principal; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -145,7 +146,8 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica JoseHeader.Builder headersBuilder = JwtUtils.headers(); JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims( - registeredClient, issuer, authorization.getPrincipalName(), authorizedScopes); + registeredClient, issuer, authorization.getPrincipalName(), + excludeOpenidIfNecessary(authorizedScopes)); // @formatter:off JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder) @@ -167,7 +169,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(), - jwtAccessToken.getExpiresAt(), authorizedScopes); + jwtAccessToken.getExpiresAt(), excludeOpenidIfNecessary(authorizedScopes)); OAuth2RefreshToken refreshToken = null; if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) { @@ -243,6 +245,15 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica registeredClient, clientPrincipal, accessToken, refreshToken, additionalParameters); } + private static Set excludeOpenidIfNecessary(Set scopes) { + if (!scopes.contains(OidcScopes.OPENID)) { + return scopes; + } + scopes = new HashSet<>(scopes); + scopes.remove(OidcScopes.OPENID); + return scopes; + } + @Override public boolean supports(Class authentication) { return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index b603e24..7a14f26 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -198,7 +198,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { .attribute(Principal.class.getName(), principal) .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest); - if (registeredClient.getClientSettings().requireUserConsent()) { + if (requireUserConsent(registeredClient, authorizationRequest)) { String state = this.stateGenerator.generateKey(); OAuth2Authorization authorization = builder .attribute(OAuth2ParameterNames.STATE, state) @@ -232,6 +232,15 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { } } + private static boolean requireUserConsent(RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) { + // openid scope does not require consent + if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) && + authorizationRequest.getScopes().size() == 1) { + return false; + } + return registeredClient.getClientSettings().requireUserConsent(); + } + private void processUserConsent(HttpServletRequest request, HttpServletResponse response) throws IOException { @@ -264,11 +273,16 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES); // TODO Allow configuration for authorization code time-to-live OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( this.codeGenerator.generateKey(), issuedAt, expiresAt); + Set authorizedScopes = userConsentRequestContext.getScopes(); + if (userConsentRequestContext.getAuthorizationRequest().getScopes().contains(OidcScopes.OPENID)) { + // openid scope is auto-approved as it does not require consent + authorizedScopes.add(OidcScopes.OPENID); + } OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) .token(authorizationCode) .attributes(attrs -> { attrs.remove(OAuth2ParameterNames.STATE); - attrs.put(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, userConsentRequestContext.getScopes()); + attrs.put(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes); }) .build(); this.authorizationService.save(authorization); @@ -661,6 +675,8 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationRequest.class.getName()); + Set scopes = new HashSet<>(authorizationRequest.getScopes()); + scopes.remove(OidcScopes.OPENID); // openid scope does not require consent String state = authorization.getAttribute( OAuth2ParameterNames.STATE); @@ -695,7 +711,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { builder.append(" "); builder.append(" "); - for (String scope : authorizationRequest.getScopes()) { + for (String scope : scopes) { builder.append("
"); builder.append(" "); builder.append(" "); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 5516166..25764b4 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -19,6 +19,9 @@ import java.security.Principal; import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Set; import org.junit.Before; @@ -306,6 +309,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenContext.getAuthorizationGrant()).isEqualTo(authentication); assertThat(accessTokenContext.getHeaders()).isNotNull(); assertThat(accessTokenContext.getClaims()).isNotNull(); + Map claims = new HashMap<>(); + accessTokenContext.getClaims().claims(claims::putAll); + assertThat(claims.containsKey(OidcScopes.OPENID)).isFalse(); // ID Token context JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); @@ -328,8 +334,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken()); - assertThat(accessTokenAuthentication.getAccessToken().getScopes()) - .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); + Set accessTokenScopes = new HashSet<>(updatedAuthorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); + accessTokenScopes.remove(OidcScopes.OPENID); + assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(accessTokenScopes); assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken()); OAuth2Authorization.Token authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 7cb82ad..385cc91 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -51,6 +51,7 @@ import org.springframework.security.oauth2.server.authorization.TestOAuth2Author import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.util.StringUtils; @@ -445,6 +446,19 @@ public class OAuth2AuthorizationEndpointFilterTests { doFilterWhenAuthorizationRequestThenAuthorizationResponse(registeredClient, request); } + @Test + public void doFilterWhenAuthenticationRequestIncludesOnlyOpenidScopeThenDoesNotRequireConsent() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .scopes(scopes -> { + scopes.clear(); + scopes.add(OidcScopes.OPENID); + }) + .clientSettings(ClientSettings::requireUserConsent) + .build(); + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + doFilterWhenAuthorizationRequestThenAuthorizationResponse(registeredClient, request); + } + private void doFilterWhenAuthorizationRequestThenAuthorizationResponse( RegisteredClient registeredClient, MockHttpServletRequest request) throws Exception { @@ -772,11 +786,12 @@ public class OAuth2AuthorizationEndpointFilterTests { @Test public void doFilterWhenUserConsentRequestApprovedThenAuthorizationResponse() throws Exception { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) + .attributes(attrs -> attrs.remove(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)) .build(); when(this.authorizationService.findByToken(eq("state"), eq(STATE_TOKEN_TYPE))) .thenReturn(authorization); @@ -908,7 +923,9 @@ public class OAuth2AuthorizationEndpointFilterTests { request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); request.addParameter(OAuth2ParameterNames.STATE, "state"); for (String scope : registeredClient.getScopes()) { - request.addParameter(OAuth2ParameterNames.SCOPE, scope); + if (!OidcScopes.OPENID.equals(scope)) { + request.addParameter(OAuth2ParameterNames.SCOPE, scope); + } } request.addParameter("consent_action", "approve");