Ignore unknown token_type_hint

Closes gh-174
This commit is contained in:
Joe Grandja 2020-12-08 07:57:23 -05:00
parent f077337e43
commit 7f8aff7982
4 changed files with 43 additions and 32 deletions

View File

@ -63,23 +63,43 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
.orElse(null); .orElse(null);
} }
private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) { private static boolean hasToken(OAuth2Authorization authorization, String token, @Nullable TokenType tokenType) {
if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) { if (tokenType == null) {
return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)); return matchesState(authorization, token) ||
matchesAuthorizationCode(authorization, token) ||
matchesAccessToken(authorization, token) ||
matchesRefreshToken(authorization, token);
} else if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) {
return matchesState(authorization, token);
} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); return matchesAuthorizationCode(authorization, token);
return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
return authorization.getTokens().getAccessToken() != null && return matchesAccessToken(authorization, token);
authorization.getTokens().getAccessToken().getTokenValue().equals(token);
} else if (TokenType.REFRESH_TOKEN.equals(tokenType)) { } else if (TokenType.REFRESH_TOKEN.equals(tokenType)) {
return authorization.getTokens().getRefreshToken() != null && return matchesRefreshToken(authorization, token);
authorization.getTokens().getRefreshToken().getTokenValue().equals(token);
} }
return false; return false;
} }
private static boolean matchesState(OAuth2Authorization authorization, String token) {
return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE));
}
private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) {
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
}
private static boolean matchesAccessToken(OAuth2Authorization authorization, String token) {
return authorization.getTokens().getAccessToken() != null &&
authorization.getTokens().getAccessToken().getTokenValue().equals(token);
}
private static boolean matchesRefreshToken(OAuth2Authorization authorization, String token) {
return authorization.getTokens().getRefreshToken() != null &&
authorization.getTokens().getRefreshToken().getTokenValue().equals(token);
}
private static class OAuth2AuthorizationId implements Serializable { private static class OAuth2AuthorizationId implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private final String registeredClientId; private final String registeredClientId;

View File

@ -22,7 +22,6 @@ import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes2;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
@ -71,8 +70,6 @@ public class OAuth2TokenRevocationAuthenticationProvider implements Authenticati
tokenType = TokenType.REFRESH_TOKEN; tokenType = TokenType.REFRESH_TOKEN;
} else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) { } else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) {
tokenType = TokenType.ACCESS_TOKEN; tokenType = TokenType.ACCESS_TOKEN;
} else {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes2.UNSUPPORTED_TOKEN_TYPE));
} }
} }

View File

@ -101,7 +101,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
} }
@Test @Test
public void findByTokenWhenTokenTypeStateThenFound() { public void findByTokenWhenStateExistsThenFound() {
String state = "state"; String state = "state";
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
@ -112,10 +112,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
OAuth2Authorization result = this.authorizationService.findByToken( OAuth2Authorization result = this.authorizationService.findByToken(
state, new TokenType(OAuth2AuthorizationAttributeNames.STATE)); state, new TokenType(OAuth2AuthorizationAttributeNames.STATE));
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);
result = this.authorizationService.findByToken(state, null);
assertThat(authorization).isEqualTo(result);
} }
@Test @Test
public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() { public void findByTokenWhenAuthorizationCodeExistsThenFound() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
@ -125,10 +127,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
OAuth2Authorization result = this.authorizationService.findByToken( OAuth2Authorization result = this.authorizationService.findByToken(
AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE); AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);
result = this.authorizationService.findByToken(AUTHORIZATION_CODE.getTokenValue(), null);
assertThat(authorization).isEqualTo(result);
} }
@Test @Test
public void findByTokenWhenTokenTypeAccessTokenThenFound() { public void findByTokenWhenAccessTokenExistsThenFound() {
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token", Instant.now().minusSeconds(60), Instant.now()); "access-token", Instant.now().minusSeconds(60), Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
@ -138,12 +142,14 @@ public class InMemoryOAuth2AuthorizationServiceTests {
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
OAuth2Authorization result = this.authorizationService.findByToken( OAuth2Authorization result = this.authorizationService.findByToken(
"access-token", TokenType.ACCESS_TOKEN); accessToken.getTokenValue(), TokenType.ACCESS_TOKEN);
assertThat(authorization).isEqualTo(result);
result = this.authorizationService.findByToken(accessToken.getTokenValue(), null);
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);
} }
@Test @Test
public void findByTokenWhenTokenTypeRefreshTokenThenFound() { public void findByTokenWhenRefreshTokenExistsThenFound() {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
@ -154,6 +160,8 @@ public class InMemoryOAuth2AuthorizationServiceTests {
OAuth2Authorization result = this.authorizationService.findByToken( OAuth2Authorization result = this.authorizationService.findByToken(
refreshToken.getTokenValue(), TokenType.REFRESH_TOKEN); refreshToken.getTokenValue(), TokenType.REFRESH_TOKEN);
assertThat(authorization).isEqualTo(result); assertThat(authorization).isEqualTo(result);
result = this.authorizationService.findByToken(refreshToken.getTokenValue(), null);
assertThat(authorization).isEqualTo(result);
} }
@Test @Test

View File

@ -23,7 +23,6 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes2;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@ -97,19 +96,6 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
} }
@Test
public void authenticateWhenInvalidTokenTypeThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
"token", clientPrincipal, OAuth2ErrorCodes2.UNSUPPORTED_TOKEN_TYPE);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes2.UNSUPPORTED_TOKEN_TYPE);
}
@Test @Test
public void authenticateWhenInvalidTokenThenNotRevoked() { public void authenticateWhenInvalidTokenThenNotRevoked() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();