Remove OAuth2AuthorizationAttributeNames.STATE

Issue gh-213
This commit is contained in:
Joe Grandja 2021-02-08 20:47:14 -05:00
parent fd9df9e2e7
commit cee5aacc15
5 changed files with 21 additions and 24 deletions

View File

@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -72,7 +73,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
matchesAuthorizationCode(authorization, token) || matchesAuthorizationCode(authorization, token) ||
matchesAccessToken(authorization, token) || matchesAccessToken(authorization, token) ||
matchesRefreshToken(authorization, token); matchesRefreshToken(authorization, token);
} else if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) { } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
return matchesState(authorization, token); return matchesState(authorization, token);
} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
return matchesAuthorizationCode(authorization, token); return matchesAuthorizationCode(authorization, token);
@ -85,7 +86,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
} }
private static boolean matchesState(OAuth2Authorization authorization, String token) { private static boolean matchesState(OAuth2Authorization authorization, String token) {
return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)); return token.equals(authorization.getAttribute(OAuth2ParameterNames.STATE));
} }
private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) { private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) {

View File

@ -28,11 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
*/ */
public interface OAuth2AuthorizationAttributeNames { public interface OAuth2AuthorizationAttributeNames {
/**
* The name of the attribute used for correlating the user consent request/response.
*/
String STATE = OAuth2Authorization.class.getName().concat(".STATE");
/** /**
* The name of the attribute used for the {@link OAuth2AuthorizationRequest}. * The name of the attribute used for the {@link OAuth2AuthorizationRequest}.
*/ */

View File

@ -200,7 +200,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
if (registeredClient.getClientSettings().requireUserConsent()) { if (registeredClient.getClientSettings().requireUserConsent()) {
String state = this.stateGenerator.generateKey(); String state = this.stateGenerator.generateKey();
OAuth2Authorization authorization = builder OAuth2Authorization authorization = builder
.attribute(OAuth2AuthorizationAttributeNames.STATE, state) .attribute(OAuth2ParameterNames.STATE, state)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
@ -266,7 +266,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization())
.token(authorizationCode) .token(authorizationCode)
.attributes(attrs -> { .attributes(attrs -> {
attrs.remove(OAuth2AuthorizationAttributeNames.STATE); attrs.remove(OAuth2ParameterNames.STATE);
attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes()); attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes());
}) })
.build(); .build();
@ -376,7 +376,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
return; return;
} }
OAuth2Authorization authorization = this.authorizationService.findByToken( OAuth2Authorization authorization = this.authorizationService.findByToken(
userConsentRequestContext.getState(), new TokenType(OAuth2AuthorizationAttributeNames.STATE)); userConsentRequestContext.getState(), new TokenType(OAuth2ParameterNames.STATE));
if (authorization == null) { if (authorization == null) {
userConsentRequestContext.setError( userConsentRequestContext.setError(
createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE)); createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE));
@ -661,7 +661,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
String state = authorization.getAttribute( String state = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.STATE); OAuth2ParameterNames.STATE);
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();

View File

@ -24,6 +24,7 @@ import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
@ -110,12 +111,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.attribute(OAuth2AuthorizationAttributeNames.STATE, state) .attribute(OAuth2ParameterNames.STATE, state)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
OAuth2Authorization result = this.authorizationService.findByToken( OAuth2Authorization result = this.authorizationService.findByToken(
state, new TokenType(OAuth2AuthorizationAttributeNames.STATE)); state, new TokenType(OAuth2ParameterNames.STATE));
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);
result = this.authorizationService.findByToken(state, null); result = this.authorizationService.findByToken(state, null);
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);

View File

@ -569,7 +569,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication); .isEqualTo(this.authentication);
String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE); String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
assertThat(state).isNotNull(); assertThat(state).isNotNull();
Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES); Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES);
@ -620,7 +620,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient); .thenReturn(registeredClient);
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
this.authentication.setAuthenticated(false); this.authentication.setAuthenticated(false);
@ -638,7 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient); .thenReturn(registeredClient);
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
this.authentication = new TestingAuthenticationToken("other-principal", "password"); this.authentication = new TestingAuthenticationToken("other-principal", "password");
@ -662,7 +662,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
doFilterWhenUserConsentRequestInvalidParameterThenError( doFilterWhenUserConsentRequestInvalidParameterThenError(
@ -680,7 +680,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
doFilterWhenUserConsentRequestInvalidParameterThenError( doFilterWhenUserConsentRequestInvalidParameterThenError(
@ -698,7 +698,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
doFilterWhenUserConsentRequestInvalidParameterThenError( doFilterWhenUserConsentRequestInvalidParameterThenError(
@ -717,7 +717,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(otherRegisteredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(otherRegisteredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
doFilterWhenUserConsentRequestInvalidParameterThenError( doFilterWhenUserConsentRequestInvalidParameterThenError(
@ -735,7 +735,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
doFilterWhenUserConsentRequestInvalidParameterThenRedirect( doFilterWhenUserConsentRequestInvalidParameterThenRedirect(
@ -756,7 +756,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
doFilterWhenUserConsentRequestInvalidParameterThenRedirect( doFilterWhenUserConsentRequestInvalidParameterThenRedirect(
@ -777,7 +777,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.principalName(this.authentication.getName()) .principalName(this.authentication.getName())
.build(); .build();
when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE)))) when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
.thenReturn(authorization); .thenReturn(authorization);
MockHttpServletRequest request = createUserConsentRequest(registeredClient); MockHttpServletRequest request = createUserConsentRequest(registeredClient);
@ -800,7 +800,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull(); assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull();
assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); assertThat(updatedAuthorization.<String>getAttribute(OAuth2ParameterNames.STATE)).isNull();
assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)) assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))
.isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)); .isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST));
assertThat(updatedAuthorization.<Set<String>>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES)) assertThat(updatedAuthorization.<Set<String>>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES))