From cee5aacc152f76928e78bd0ca2a7eee794d51270 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 8 Feb 2021 20:47:14 -0500 Subject: [PATCH] Remove OAuth2AuthorizationAttributeNames.STATE Issue gh-213 --- .../InMemoryOAuth2AuthorizationService.java | 5 +++-- .../OAuth2AuthorizationAttributeNames.java | 5 ----- .../OAuth2AuthorizationEndpointFilter.java | 8 +++---- ...MemoryOAuth2AuthorizationServiceTests.java | 5 +++-- ...Auth2AuthorizationEndpointFilterTests.java | 22 +++++++++---------- 5 files changed, 21 insertions(+), 24 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 0018886..d5ea75b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap; import org.springframework.lang.Nullable; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.util.Assert; @@ -72,7 +73,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza matchesAuthorizationCode(authorization, token) || matchesAccessToken(authorization, token) || matchesRefreshToken(authorization, token); - } else if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) { + } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { return matchesState(authorization, token); } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { return matchesAuthorizationCode(authorization, token); @@ -85,7 +86,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza } private static boolean matchesState(OAuth2Authorization authorization, String token) { - return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)); + return token.equals(authorization.getAttribute(OAuth2ParameterNames.STATE)); } private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java index 08c1da7..6e3f533 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java @@ -28,11 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ */ public interface OAuth2AuthorizationAttributeNames { - /** - * The name of the attribute used for correlating the user consent request/response. - */ - String STATE = OAuth2Authorization.class.getName().concat(".STATE"); - /** * The name of the attribute used for the {@link OAuth2AuthorizationRequest}. */ diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index c3a5a2e..3103820 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -200,7 +200,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { if (registeredClient.getClientSettings().requireUserConsent()) { String state = this.stateGenerator.generateKey(); OAuth2Authorization authorization = builder - .attribute(OAuth2AuthorizationAttributeNames.STATE, state) + .attribute(OAuth2ParameterNames.STATE, state) .build(); this.authorizationService.save(authorization); @@ -266,7 +266,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) .token(authorizationCode) .attributes(attrs -> { - attrs.remove(OAuth2AuthorizationAttributeNames.STATE); + attrs.remove(OAuth2ParameterNames.STATE); attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes()); }) .build(); @@ -376,7 +376,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { return; } OAuth2Authorization authorization = this.authorizationService.findByToken( - userConsentRequestContext.getState(), new TokenType(OAuth2AuthorizationAttributeNames.STATE)); + userConsentRequestContext.getState(), new TokenType(OAuth2ParameterNames.STATE)); if (authorization == null) { userConsentRequestContext.setError( createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE)); @@ -661,7 +661,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); String state = authorization.getAttribute( - OAuth2AuthorizationAttributeNames.STATE); + OAuth2ParameterNames.STATE); StringBuilder builder = new StringBuilder(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index b263fd7..ff058f5 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -24,6 +24,7 @@ import org.junit.Test; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; @@ -110,12 +111,12 @@ public class InMemoryOAuth2AuthorizationServiceTests { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) - .attribute(OAuth2AuthorizationAttributeNames.STATE, state) + .attribute(OAuth2ParameterNames.STATE, state) .build(); this.authorizationService.save(authorization); OAuth2Authorization result = this.authorizationService.findByToken( - state, new TokenType(OAuth2AuthorizationAttributeNames.STATE)); + state, new TokenType(OAuth2ParameterNames.STATE)); assertThat(authorization).isEqualTo(result); result = this.authorizationService.findByToken(state, null); assertThat(authorization).isEqualTo(result); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index d75dca7..211b4d8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -569,7 +569,7 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) .isEqualTo(this.authentication); - String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE); + String state = authorization.getAttribute(OAuth2ParameterNames.STATE); assertThat(state).isNotNull(); Set authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); @@ -620,7 +620,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); this.authentication.setAuthenticated(false); @@ -638,7 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests { when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); this.authentication = new TestingAuthenticationToken("other-principal", "password"); @@ -662,7 +662,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); doFilterWhenUserConsentRequestInvalidParameterThenError( @@ -680,7 +680,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); doFilterWhenUserConsentRequestInvalidParameterThenError( @@ -698,7 +698,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); doFilterWhenUserConsentRequestInvalidParameterThenError( @@ -717,7 +717,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(otherRegisteredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); doFilterWhenUserConsentRequestInvalidParameterThenError( @@ -735,7 +735,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); doFilterWhenUserConsentRequestInvalidParameterThenRedirect( @@ -756,7 +756,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); doFilterWhenUserConsentRequestInvalidParameterThenRedirect( @@ -777,7 +777,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName(this.authentication.getName()) .build(); - when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) + when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE)))) .thenReturn(authorization); MockHttpServletRequest request = createUserConsentRequest(registeredClient); @@ -800,7 +800,7 @@ public class OAuth2AuthorizationEndpointFilterTests { assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull(); - assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); + assertThat(updatedAuthorization.getAttribute(OAuth2ParameterNames.STATE)).isNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)) .isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)); assertThat(updatedAuthorization.>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES))