Improve OAuth2Authorization model

This commit removes OAuth2Tokens and OAuth2TokenMetadata and consolidates the code into OAuth2Authorization.

Closes gh-213
This commit is contained in:
Joe Grandja 2021-02-05 13:20:17 -05:00
parent 218d49b134
commit bffcbc5440
24 changed files with 394 additions and 972 deletions

View File

@ -15,15 +15,17 @@
*/ */
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.util.Assert;
import java.io.Serializable; import java.io.Serializable;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.util.Assert;
/** /**
* An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory. * An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory.
* *
@ -87,18 +89,21 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
} }
private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) { private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) {
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
return authorizationCode != null && authorizationCode.getTokenValue().equals(token); authorization.getToken(OAuth2AuthorizationCode.class);
return authorizationCode != null && authorizationCode.getToken().getTokenValue().equals(token);
} }
private static boolean matchesAccessToken(OAuth2Authorization authorization, String token) { private static boolean matchesAccessToken(OAuth2Authorization authorization, String token) {
return authorization.getTokens().getAccessToken() != null && OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
authorization.getTokens().getAccessToken().getTokenValue().equals(token); authorization.getToken(OAuth2AccessToken.class);
return accessToken != null && accessToken.getToken().getTokenValue().equals(token);
} }
private static boolean matchesRefreshToken(OAuth2Authorization authorization, String token) { private static boolean matchesRefreshToken(OAuth2Authorization authorization, String token) {
return authorization.getTokens().getRefreshToken() != null && OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken =
authorization.getTokens().getRefreshToken().getTokenValue().equals(token); authorization.getToken(OAuth2RefreshToken.class);
return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token);
} }
private static class OAuth2AuthorizationId implements Serializable { private static class OAuth2AuthorizationId implements Serializable {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,11 +15,6 @@
*/ */
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -27,26 +22,32 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
/** /**
* A representation of an OAuth 2.0 Authorization, * A representation of an OAuth 2.0 Authorization, which holds state related to the authorization granted
* which holds state related to the authorization granted to the {@link #getRegisteredClientId() client} * to a {@link #getRegisteredClientId() client}, by the {@link #getPrincipalName() resource owner}
* by the {@link #getPrincipalName() resource owner}. * or itself in the case of the {@code client_credentials} grant type.
* *
* @author Joe Grandja * @author Joe Grandja
* @author Krisztian Toth * @author Krisztian Toth
* @since 0.0.1 * @since 0.0.1
* @see RegisteredClient * @see RegisteredClient
* @see OAuth2Tokens * @see AbstractOAuth2Token
* @see OAuth2AccessToken
* @see OAuth2RefreshToken
*/ */
public class OAuth2Authorization implements Serializable { public class OAuth2Authorization implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private String registeredClientId; private String registeredClientId;
private String principalName; private String principalName;
private OAuth2Tokens tokens; private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens;
@Deprecated
private OAuth2AccessToken accessToken;
private Map<String, Object> attributes; private Map<String, Object> attributes;
protected OAuth2Authorization() { protected OAuth2Authorization() {
@ -62,31 +63,64 @@ public class OAuth2Authorization implements Serializable {
} }
/** /**
* Returns the resource owner's {@code Principal} name. * Returns the {@code Principal} name of the resource owner (or client).
* *
* @return the resource owner's {@code Principal} name * @return the {@code Principal} name of the resource owner (or client)
*/ */
public String getPrincipalName() { public String getPrincipalName() {
return this.principalName; return this.principalName;
} }
/** /**
* Returns the {@link OAuth2Tokens}. * Returns the {@link Token} of type {@link OAuth2AccessToken}.
* *
* @return the {@link OAuth2Tokens} * @return the {@link Token} of type {@link OAuth2AccessToken}
*/ */
public OAuth2Tokens getTokens() { public Token<OAuth2AccessToken> getAccessToken() {
return this.tokens; return getToken(OAuth2AccessToken.class);
} }
/** /**
* Returns the {@link OAuth2AccessToken access token} credential. * Returns the {@link Token} of type {@link OAuth2RefreshToken}.
* *
* @return the {@link OAuth2AccessToken} * @return the {@link Token} of type {@link OAuth2RefreshToken}, or {@code null} if not available
*/ */
@Deprecated @Nullable
public OAuth2AccessToken getAccessToken() { public Token<OAuth2RefreshToken> getRefreshToken() {
return getTokens().getAccessToken(); return getToken(OAuth2RefreshToken.class);
}
/**
* Returns the {@link Token} of type {@code tokenType}.
*
* @param tokenType the token type
* @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends AbstractOAuth2Token> Token<T> getToken(Class<T> tokenType) {
Assert.notNull(tokenType, "tokenType cannot be null");
Token<?> token = this.tokens.get(tokenType);
return token != null ? (Token<T>) token : null;
}
/**
* Returns the {@link Token} matching the {@code tokenValue}.
*
* @param tokenValue the token value
* @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends AbstractOAuth2Token> Token<T> getToken(String tokenValue) {
Assert.hasText(tokenValue, "tokenValue cannot be empty");
Token<?> token = this.tokens.values().stream()
.filter(t -> t.getToken().getTokenValue().equals(tokenValue))
.findFirst()
.orElse(null);
return token != null ? (Token<T>) token : null;
} }
/** /**
@ -103,8 +137,9 @@ public class OAuth2Authorization implements Serializable {
* *
* @param name the name of the attribute * @param name the name of the attribute
* @param <T> the type of the attribute * @param <T> the type of the attribute
* @return the value of the attribute associated to the authorization, or {@code null} if not available * @return the value of an attribute associated to the authorization, or {@code null} if not available
*/ */
@Nullable
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T getAttribute(String name) { public <T> T getAttribute(String name) {
Assert.hasText(name, "name cannot be empty"); Assert.hasText(name, "name cannot be empty");
@ -143,41 +178,131 @@ public class OAuth2Authorization implements Serializable {
} }
/** /**
* Returns a new {@link Builder}, initialized with the values from the provided {@code authorization}. * Returns a new {@link Builder}, initialized with the values from the provided {@code OAuth2Authorization}.
* *
* @param authorization the authorization used for initializing the {@link Builder} * @param authorization the {@code OAuth2Authorization} used for initializing the {@link Builder}
* @return the {@link Builder} * @return the {@link Builder}
*/ */
public static Builder from(OAuth2Authorization authorization) { public static Builder from(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null"); Assert.notNull(authorization, "authorization cannot be null");
return new Builder(authorization.getRegisteredClientId()) return new Builder(authorization.getRegisteredClientId())
.principalName(authorization.getPrincipalName()) .principalName(authorization.getPrincipalName())
.tokens(OAuth2Tokens.from(authorization.getTokens()).build()) .tokens(authorization.tokens)
.attributes(attrs -> attrs.putAll(authorization.getAttributes())); .attributes(attrs -> attrs.putAll(authorization.getAttributes()));
} }
/**
* A holder of an OAuth 2.0 Token and it's associated metadata.
*
* @author Joe Grandja
* @since 0.1.0
*/
public static class Token<T extends AbstractOAuth2Token> implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
protected static final String TOKEN_METADATA_BASE = "metadata.token.";
/**
* The name of the metadata that indicates if the token has been invalidated.
*/
public static final String INVALIDATED_METADATA_NAME = TOKEN_METADATA_BASE.concat("invalidated");
private final T token;
private final Map<String, Object> metadata;
protected Token(T token) {
this(token, defaultMetadata());
}
protected Token(T token, Map<String, Object> metadata) {
this.token = token;
this.metadata = Collections.unmodifiableMap(metadata);
}
/**
* Returns the token of type {@link AbstractOAuth2Token}.
*
* @return the token of type {@link AbstractOAuth2Token}
*/
public T getToken() {
return this.token;
}
/**
* Returns {@code true} if the token has been invalidated (e.g. revoked).
* The default is {@code false}.
*
* @return {@code true} if the token has been invalidated, {@code false} otherwise
*/
public boolean isInvalidated() {
return Boolean.TRUE.equals(getMetadata(INVALIDATED_METADATA_NAME));
}
/**
* Returns the value of the metadata associated to the token.
*
* @param name the name of the metadata
* @param <V> the value type of the metadata
* @return the value of the metadata, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <V> V getMetadata(String name) {
Assert.hasText(name, "name cannot be empty");
return (V) this.metadata.get(name);
}
/**
* Returns the metadata associated to the token.
*
* @return a {@code Map} of the metadata
*/
public Map<String, Object> getMetadata() {
return this.metadata;
}
protected static Map<String, Object> defaultMetadata() {
Map<String, Object> metadata = new HashMap<>();
metadata.put(INVALIDATED_METADATA_NAME, false);
return metadata;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Token<?> that = (Token<?>) obj;
return Objects.equals(this.token, that.token) &&
Objects.equals(this.metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(this.token, this.metadata);
}
}
/** /**
* A builder for {@link OAuth2Authorization}. * A builder for {@link OAuth2Authorization}.
*/ */
public static class Builder implements Serializable { public static class Builder implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private String registeredClientId; private final String registeredClientId;
private String principalName; private String principalName;
private OAuth2Tokens tokens; private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens = new HashMap<>();
private final Map<String, Object> attributes = new HashMap<>();
@Deprecated
private OAuth2AccessToken accessToken;
private Map<String, Object> attributes = new HashMap<>();
protected Builder(String registeredClientId) { protected Builder(String registeredClientId) {
this.registeredClientId = registeredClientId; this.registeredClientId = registeredClientId;
} }
/** /**
* Sets the resource owner's {@code Principal} name. * Sets the {@code Principal} name of the resource owner (or client).
* *
* @param principalName the resource owner's {@code Principal} name * @param principalName the {@code Principal} name of the resource owner (or client)
* @return the {@link Builder} * @return the {@link Builder}
*/ */
public Builder principalName(String principalName) { public Builder principalName(String principalName) {
@ -186,25 +311,60 @@ public class OAuth2Authorization implements Serializable {
} }
/** /**
* Sets the {@link OAuth2Tokens}. * Sets the {@link OAuth2AccessToken access token}.
*
* @param tokens the {@link OAuth2Tokens}
* @return the {@link Builder}
*/
public Builder tokens(OAuth2Tokens tokens) {
this.tokens = tokens;
return this;
}
/**
* Sets the {@link OAuth2AccessToken access token} credential.
* *
* @param accessToken the {@link OAuth2AccessToken} * @param accessToken the {@link OAuth2AccessToken}
* @return the {@link Builder} * @return the {@link Builder}
*/ */
@Deprecated
public Builder accessToken(OAuth2AccessToken accessToken) { public Builder accessToken(OAuth2AccessToken accessToken) {
this.accessToken = accessToken; return token(accessToken);
}
/**
* Sets the {@link OAuth2RefreshToken refresh token}.
*
* @param refreshToken the {@link OAuth2RefreshToken}
* @return the {@link Builder}
*/
public Builder refreshToken(OAuth2RefreshToken refreshToken) {
return token(refreshToken);
}
/**
* Sets the {@link AbstractOAuth2Token token}.
*
* @param token the token
* @param <T> the type of the token
* @return the {@link Builder}
*/
public <T extends AbstractOAuth2Token> Builder token(T token) {
return token(token, (metadata) -> {});
}
/**
* Sets the {@link AbstractOAuth2Token token} and associated metadata.
*
* @param token the token
* @param metadataConsumer a {@code Consumer} of the metadata {@code Map}
* @param <T> the type of the token
* @return the {@link Builder}
*/
public <T extends AbstractOAuth2Token> Builder token(T token,
Consumer<Map<String, Object>> metadataConsumer) {
Assert.notNull(token, "token cannot be null");
Map<String, Object> metadata = Token.defaultMetadata();
metadataConsumer.accept(metadata);
Class<? extends AbstractOAuth2Token> tokenClass = token.getClass();
if (tokenClass.equals(OAuth2RefreshToken2.class)) {
tokenClass = OAuth2RefreshToken.class;
}
this.tokens.put(tokenClass, new Token<>(token, metadata));
return this;
}
protected final Builder tokens(Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens) {
this.tokens = new HashMap<>(tokens);
return this; return this;
} }
@ -245,14 +405,7 @@ public class OAuth2Authorization implements Serializable {
OAuth2Authorization authorization = new OAuth2Authorization(); OAuth2Authorization authorization = new OAuth2Authorization();
authorization.registeredClientId = this.registeredClientId; authorization.registeredClientId = this.registeredClientId;
authorization.principalName = this.principalName; authorization.principalName = this.principalName;
if (this.tokens == null) { authorization.tokens = Collections.unmodifiableMap(this.tokens);
OAuth2Tokens.Builder builder = OAuth2Tokens.builder();
if (this.accessToken != null) {
builder.accessToken(this.accessToken);
}
this.tokens = builder.build();
}
authorization.tokens = this.tokens;
authorization.attributes = Collections.unmodifiableMap(this.attributes); authorization.attributes = Collections.unmodifiableMap(this.attributes);
return authorization; return authorization;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,8 +24,6 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
/** /**
* Utility methods for the OAuth 2.0 {@link AuthenticationProvider}'s. * Utility methods for the OAuth 2.0 {@link AuthenticationProvider}'s.
@ -52,25 +50,29 @@ final class OAuth2AuthenticationProviderUtils {
static <T extends AbstractOAuth2Token> OAuth2Authorization invalidate( static <T extends AbstractOAuth2Token> OAuth2Authorization invalidate(
OAuth2Authorization authorization, T token) { OAuth2Authorization authorization, T token) {
OAuth2Tokens.Builder builder = OAuth2Tokens.from(authorization.getTokens()) // @formatter:off
.token(token, OAuth2TokenMetadata.builder().invalidated().build()); OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
.token(token,
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) { if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
builder.token( authorizationBuilder.token(
authorization.getTokens().getAccessToken(), authorization.getAccessToken().getToken(),
OAuth2TokenMetadata.builder().invalidated().build()); (metadata) ->
OAuth2AuthorizationCode authorizationCode = metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
if (authorizationCode != null && OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
!authorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()) { authorization.getToken(OAuth2AuthorizationCode.class);
builder.token( if (authorizationCode != null && !authorizationCode.isInvalidated()) {
authorizationCode, authorizationBuilder.token(
OAuth2TokenMetadata.builder().invalidated().build()); authorizationCode.getToken(),
(metadata) ->
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
} }
} }
// @formatter:on
return OAuth2Authorization.from(authorization) return authorizationBuilder.build();
.tokens(builder.build())
.build();
} }
} }

View File

@ -37,16 +37,14 @@ import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -104,16 +102,16 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
if (authorization == null) { if (authorization == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode); authorization.getToken(OAuth2AuthorizationCode.class);
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
if (!authorizationCodeMetadata.isInvalidated()) { if (!authorizationCode.isInvalidated()) {
// Invalidate the authorization code given that a different client is attempting to use it // Invalidate the authorization code given that a different client is attempting to use it
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode); authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
} }
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
@ -124,7 +122,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
if (authorizationCodeMetadata.isInvalidated()) { if (authorizationCode.isInvalidated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
@ -143,14 +141,11 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens())
.accessToken(accessToken);
OAuth2RefreshToken refreshToken = null; OAuth2RefreshToken refreshToken = null;
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) { if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken( refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken(
registeredClient.getTokenSettings().refreshTokenTimeToLive()); registeredClient.getTokenSettings().refreshTokenTimeToLive());
tokensBuilder.refreshToken(refreshToken);
} }
OidcIdToken idToken = null; OidcIdToken idToken = null;
@ -170,17 +165,21 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(), idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
jwtIdToken.getExpiresAt(), jwtIdToken.getClaims()); jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
tokensBuilder.token(idToken);
} }
OAuth2Tokens tokens = tokensBuilder.build(); OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
authorization = OAuth2Authorization.from(authorization) .accessToken(accessToken)
.tokens(tokens) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt);
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) if (refreshToken != null) {
.build(); authorizationBuilder.refreshToken(refreshToken);
}
if (idToken != null) {
authorizationBuilder.token(idToken);
}
authorization = authorizationBuilder.build();
// Invalidate the authorization code as it can only be used once // Invalidate the authorization code as it can only be used once
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode); authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());
this.authorizationService.save(authorization); this.authorizationService.save(authorization);

View File

@ -32,13 +32,12 @@ import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@ -125,7 +124,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName(clientPrincipal.getName()) .principalName(clientPrincipal.getName())
.tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .token(accessToken)
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);

View File

@ -37,16 +37,14 @@ import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings; import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
@ -114,7 +112,8 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT));
} }
Instant refreshTokenExpiresAt = authorization.getTokens().getRefreshToken().getExpiresAt(); OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
Instant refreshTokenExpiresAt = refreshToken.getToken().getExpiresAt();
if (refreshTokenExpiresAt.isBefore(Instant.now())) { if (refreshTokenExpiresAt.isBefore(Instant.now())) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2 // As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code, // invalid_grant: The provided authorization grant (e.g., authorization code,
@ -134,10 +133,7 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
scopes = authorizedScopes; scopes = authorizedScopes;
} }
OAuth2RefreshToken refreshToken = authorization.getTokens().getRefreshToken(); if (refreshToken.isInvalidated()) {
OAuth2TokenMetadata refreshTokenMetadata = authorization.getTokens().getTokenMetadata(refreshToken);
if (refreshTokenMetadata.isInvalidated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
} }
@ -159,18 +155,20 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
TokenSettings tokenSettings = registeredClient.getTokenSettings(); TokenSettings tokenSettings = registeredClient.getTokenSettings();
OAuth2RefreshToken currentRefreshToken = refreshToken.getToken();
if (!tokenSettings.reuseRefreshTokens()) { if (!tokenSettings.reuseRefreshTokens()) {
refreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive()); currentRefreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
} }
authorization = OAuth2Authorization.from(authorization) authorization = OAuth2Authorization.from(authorization)
.tokens(OAuth2Tokens.from(authorization.getTokens()).accessToken(accessToken).refreshToken(refreshToken).build()) .accessToken(accessToken)
.refreshToken(currentRefreshToken)
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
return new OAuth2AccessTokenAuthenticationToken( return new OAuth2AccessTokenAuthenticationToken(
registeredClient, clientPrincipal, accessToken, refreshToken); registeredClient, clientPrincipal, accessToken, currentRefreshToken);
} }
@Override @Override

View File

@ -72,11 +72,11 @@ public class OAuth2TokenRevocationAuthenticationProvider implements Authenticati
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT)); throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
} }
AbstractOAuth2Token token = authorization.getTokens().getToken(tokenRevocationAuthentication.getToken()); OAuth2Authorization.Token<AbstractOAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken());
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token); authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal); return new OAuth2TokenRevocationAuthenticationToken(token.getToken(), clientPrincipal);
} }
@Override @Override

View File

@ -1,169 +0,0 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.util.Assert;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
/**
* Holds metadata associated to an OAuth 2.0 Token.
*
* @author Joe Grandja
* @since 0.0.3
* @see OAuth2Tokens
*/
public class OAuth2TokenMetadata implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
protected static final String TOKEN_METADATA_BASE = "metadata.token.";
/**
* The name of the metadata that indicates if the token has been invalidated.
*/
public static final String INVALIDATED = TOKEN_METADATA_BASE.concat("invalidated");
private final Map<String, Object> metadata;
protected OAuth2TokenMetadata(Map<String, Object> metadata) {
this.metadata = Collections.unmodifiableMap(new HashMap<>(metadata));
}
/**
* Returns {@code true} if the token has been invalidated (e.g. revoked).
* The default is {@code false}.
*
* @return {@code true} if the token has been invalidated, {@code false} otherwise
*/
public boolean isInvalidated() {
return getMetadata(INVALIDATED);
}
/**
* Returns the value of the metadata associated to the token.
*
* @param name the name of the metadata
* @param <T> the type of the metadata
* @return the value of the metadata, or {@code null} if not available
*/
@SuppressWarnings("unchecked")
public <T> T getMetadata(String name) {
Assert.hasText(name, "name cannot be empty");
return (T) this.metadata.get(name);
}
/**
* Returns the metadata associated to the token.
*
* @return a {@code Map} of the metadata
*/
public Map<String, Object> getMetadata() {
return this.metadata;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
OAuth2TokenMetadata that = (OAuth2TokenMetadata) obj;
return Objects.equals(this.metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(this.metadata);
}
/**
* Returns a new {@link Builder}.
*
* @return the {@link Builder}
*/
public static Builder builder() {
return new Builder();
}
/**
* A builder for {@link OAuth2TokenMetadata}.
*/
public static class Builder implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private final Map<String, Object> metadata = defaultMetadata();
protected Builder() {
}
/**
* Set the token as invalidated (e.g. revoked).
*
* @return the {@link Builder}
*/
public Builder invalidated() {
metadata(INVALIDATED, true);
return this;
}
/**
* Adds a metadata associated to the token.
*
* @param name the name of the metadata
* @param value the value of the metadata
* @return the {@link Builder}
*/
public Builder metadata(String name, Object value) {
Assert.hasText(name, "name cannot be empty");
Assert.notNull(value, "value cannot be null");
this.metadata.put(name, value);
return this;
}
/**
* A {@code Consumer} of the metadata {@code Map}
* allowing the ability to add, replace, or remove.
*
* @param metadataConsumer a {@link Consumer} of the metadata {@code Map}
* @return the {@link Builder}
*/
public Builder metadata(Consumer<Map<String, Object>> metadataConsumer) {
metadataConsumer.accept(this.metadata);
return this;
}
/**
* Builds a new {@link OAuth2TokenMetadata}.
*
* @return the {@link OAuth2TokenMetadata}
*/
public OAuth2TokenMetadata build() {
return new OAuth2TokenMetadata(this.metadata);
}
protected static Map<String, Object> defaultMetadata() {
Map<String, Object> metadata = new HashMap<>();
metadata.put(INVALIDATED, false);
return metadata;
}
}
}

View File

@ -1,292 +0,0 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.util.Assert;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* A container for OAuth 2.0 Tokens.
*
* @author Joe Grandja
* @since 0.0.3
* @see OAuth2Authorization
* @see OAuth2TokenMetadata
* @see AbstractOAuth2Token
* @see OAuth2AccessToken
* @see OAuth2RefreshToken
*/
public class OAuth2Tokens implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private final Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
protected OAuth2Tokens(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
this.tokens = new HashMap<>(tokens);
}
/**
* Returns the {@link OAuth2AccessToken access token}.
*
* @return the {@link OAuth2AccessToken}, or {@code null} if not available
*/
@Nullable
public OAuth2AccessToken getAccessToken() {
return getToken(OAuth2AccessToken.class);
}
/**
* Returns the {@link OAuth2RefreshToken refresh token}.
*
* @return the {@link OAuth2RefreshToken}, or {@code null} if not available
*/
@Nullable
public OAuth2RefreshToken getRefreshToken() {
OAuth2RefreshToken refreshToken = getToken(OAuth2RefreshToken.class);
return refreshToken != null ? refreshToken : getToken(OAuth2RefreshToken2.class);
}
/**
* Returns the token specified by {@code tokenType}.
*
* @param tokenType the token type
* @param <T> the type of the token
* @return the token, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends AbstractOAuth2Token> T getToken(Class<T> tokenType) {
Assert.notNull(tokenType, "tokenType cannot be null");
OAuth2TokenHolder tokenHolder = this.tokens.get(tokenType);
return tokenHolder != null ? (T) tokenHolder.getToken() : null;
}
/**
* Returns the token specified by {@code token}.
*
* @param token the token
* @param <T> the type of the token
* @return the token, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends AbstractOAuth2Token> T getToken(String token) {
Assert.hasText(token, "token cannot be empty");
OAuth2TokenHolder tokenHolder = this.tokens.values().stream()
.filter(holder -> holder.getToken().getTokenValue().equals(token))
.findFirst()
.orElse(null);
return tokenHolder != null ? (T) tokenHolder.getToken() : null;
}
/**
* Returns the token metadata associated to the provided {@code token}.
*
* @param token the token
* @param <T> the type of the token
* @return the token metadata, or {@code null} if not available
*/
@Nullable
public <T extends AbstractOAuth2Token> OAuth2TokenMetadata getTokenMetadata(T token) {
Assert.notNull(token, "token cannot be null");
OAuth2TokenHolder tokenHolder = this.tokens.get(token.getClass());
return (tokenHolder != null && tokenHolder.getToken().equals(token)) ?
tokenHolder.getTokenMetadata() : null;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
OAuth2Tokens that = (OAuth2Tokens) obj;
return Objects.equals(this.tokens, that.tokens);
}
@Override
public int hashCode() {
return Objects.hash(this.tokens);
}
/**
* Returns a new {@link Builder}.
*
* @return the {@link Builder}
*/
public static Builder builder() {
return new Builder();
}
/**
* Returns a new {@link Builder}, initialized with the values from the provided {@code tokens}.
*
* @param tokens the tokens used for initializing the {@link Builder}
* @return the {@link Builder}
*/
public static Builder from(OAuth2Tokens tokens) {
Assert.notNull(tokens, "tokens cannot be null");
return new Builder(tokens.tokens);
}
/**
* A builder for {@link OAuth2Tokens}.
*/
public static class Builder implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
protected Builder() {
this.tokens = new HashMap<>();
}
protected Builder(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
this.tokens = new HashMap<>(tokens);
}
/**
* Sets the {@link OAuth2AccessToken access token}.
*
* @param accessToken the {@link OAuth2AccessToken}
* @return the {@link Builder}
*/
public Builder accessToken(OAuth2AccessToken accessToken) {
return addToken(accessToken, null);
}
/**
* Sets the {@link OAuth2AccessToken access token} and associated {@link OAuth2TokenMetadata token metadata}.
*
* @param accessToken the {@link OAuth2AccessToken}
* @param tokenMetadata the {@link OAuth2TokenMetadata}
* @return the {@link Builder}
*/
public Builder accessToken(OAuth2AccessToken accessToken, OAuth2TokenMetadata tokenMetadata) {
return addToken(accessToken, tokenMetadata);
}
/**
* Sets the {@link OAuth2RefreshToken refresh token}.
*
* @param refreshToken the {@link OAuth2RefreshToken}
* @return the {@link Builder}
*/
public Builder refreshToken(OAuth2RefreshToken refreshToken) {
return addToken(refreshToken, null);
}
/**
* Sets the {@link OAuth2RefreshToken refresh token} and associated {@link OAuth2TokenMetadata token metadata}.
*
* @param refreshToken the {@link OAuth2RefreshToken}
* @param tokenMetadata the {@link OAuth2TokenMetadata}
* @return the {@link Builder}
*/
public Builder refreshToken(OAuth2RefreshToken refreshToken, OAuth2TokenMetadata tokenMetadata) {
return addToken(refreshToken, tokenMetadata);
}
/**
* Sets the token.
*
* @param token the token
* @param <T> the type of the token
* @return the {@link Builder}
*/
public <T extends AbstractOAuth2Token> Builder token(T token) {
return addToken(token, null);
}
/**
* Sets the token and associated {@link OAuth2TokenMetadata token metadata}.
*
* @param token the token
* @param tokenMetadata the {@link OAuth2TokenMetadata}
* @param <T> the type of the token
* @return the {@link Builder}
*/
public <T extends AbstractOAuth2Token> Builder token(T token, OAuth2TokenMetadata tokenMetadata) {
return addToken(token, tokenMetadata);
}
protected Builder addToken(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) {
Assert.notNull(token, "token cannot be null");
if (tokenMetadata == null) {
tokenMetadata = OAuth2TokenMetadata.builder().build();
}
this.tokens.put(token.getClass(), new OAuth2TokenHolder(token, tokenMetadata));
return this;
}
/**
* Builds a new {@link OAuth2Tokens}.
*
* @return the {@link OAuth2Tokens}
*/
public OAuth2Tokens build() {
return new OAuth2Tokens(this.tokens);
}
}
protected static class OAuth2TokenHolder implements Serializable {
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private final AbstractOAuth2Token token;
private final OAuth2TokenMetadata tokenMetadata;
protected OAuth2TokenHolder(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) {
this.token = token;
this.tokenMetadata = tokenMetadata;
}
protected AbstractOAuth2Token getToken() {
return this.token;
}
protected OAuth2TokenMetadata getTokenMetadata() {
return this.tokenMetadata;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
OAuth2TokenHolder that = (OAuth2TokenHolder) obj;
return Objects.equals(this.token, that.token) &&
Objects.equals(this.tokenMetadata, that.tokenMetadata);
}
@Override
public int hashCode() {
return Objects.hash(this.token, this.tokenMetadata);
}
}
}

View File

@ -54,7 +54,6 @@ import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AndRequestMatcher; import org.springframework.security.web.util.matcher.AndRequestMatcher;
@ -213,7 +212,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
this.codeGenerator.generateKey(), issuedAt, expiresAt); this.codeGenerator.generateKey(), issuedAt, expiresAt);
OAuth2Authorization authorization = builder OAuth2Authorization authorization = builder
.tokens(OAuth2Tokens.builder().token(authorizationCode).build()) .token(authorizationCode)
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes())
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
@ -264,7 +263,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
this.codeGenerator.generateKey(), issuedAt, expiresAt); this.codeGenerator.generateKey(), issuedAt, expiresAt);
OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization()) OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization())
.tokens(OAuth2Tokens.builder().token(authorizationCode).build()) .token(authorizationCode)
.attributes(attrs -> { .attributes(attrs -> {
attrs.remove(OAuth2AuthorizationAttributeNames.STATE); attrs.remove(OAuth2AuthorizationAttributeNames.STATE);
attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes()); attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes());

View File

@ -198,7 +198,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -225,7 +225,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -252,7 +252,7 @@ public class OAuth2AuthorizationCodeGrantTests {
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
@ -286,7 +286,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -303,7 +303,7 @@ public class OAuth2AuthorizationCodeGrantTests {
verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService, times(2)).findByToken( verify(authorizationService, times(2)).findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService, times(2)).save(any()); verify(authorizationService, times(2)).save(any());
} }
@ -318,7 +318,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -343,7 +343,7 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2Authorization authorization) { OAuth2Authorization authorization) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()); parameters.set(OAuth2ParameterNames.CODE, authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue());
parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
return parameters; return parameters;
} }

View File

@ -126,7 +126,7 @@ public class OAuth2RefreshTokenGrantTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
@ -146,7 +146,7 @@ public class OAuth2RefreshTokenGrantTests {
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN)); eq(TokenType.REFRESH_TOKEN));
verify(authorizationService).save(any()); verify(authorizationService).save(any());
@ -169,7 +169,7 @@ public class OAuth2RefreshTokenGrantTests {
private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) { private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue()); parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, authorization.getTokens().getRefreshToken().getTokenValue()); parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, authorization.getRefreshToken().getToken().getTokenValue());
return parameters; return parameters;
} }

View File

@ -104,7 +104,7 @@ public class OAuth2TokenRevocationTests {
.thenReturn(registeredClient); .thenReturn(registeredClient);
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
OAuth2RefreshToken token = authorization.getTokens().getRefreshToken(); OAuth2RefreshToken token = authorization.getRefreshToken().getToken();
TokenType tokenType = TokenType.REFRESH_TOKEN; TokenType tokenType = TokenType.REFRESH_TOKEN;
when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization); when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization);
@ -121,10 +121,10 @@ public class OAuth2TokenRevocationTests {
verify(authorizationService).save(authorizationCaptor.capture()); verify(authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue(); assertThat(refreshToken.isInvalidated()).isTrue();
OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); assertThat(accessToken.isInvalidated()).isTrue();
} }
@Test @Test
@ -147,7 +147,7 @@ public class OAuth2TokenRevocationTests {
.thenReturn(registeredClient); .thenReturn(registeredClient);
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
OAuth2AccessToken token = authorization.getTokens().getAccessToken(); OAuth2AccessToken token = authorization.getAccessToken().getToken();
TokenType tokenType = TokenType.ACCESS_TOKEN; TokenType tokenType = TokenType.ACCESS_TOKEN;
when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization); when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization);
@ -164,10 +164,10 @@ public class OAuth2TokenRevocationTests {
verify(authorizationService).save(authorizationCaptor.capture()); verify(authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); assertThat(accessToken.isInvalidated()).isTrue();
OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse(); assertThat(refreshToken.isInvalidated()).isFalse();
} }
private static MultiValueMap<String, String> getTokenRevocationRequestParameters(AbstractOAuth2Token token, TokenType tokenType) { private static MultiValueMap<String, String> getTokenRevocationRequestParameters(AbstractOAuth2Token token, TokenType tokenType) {

View File

@ -183,7 +183,7 @@ public class OidcTests {
OAuth2Authorization authorization = authorizationCaptor.getValue(); OAuth2Authorization authorization = authorizationCaptor.getValue();
when(authorizationService.findByToken( when(authorizationService.findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE))) eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -204,7 +204,7 @@ public class OidcTests {
verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken( verify(authorizationService).findByToken(
eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()), eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
eq(TokenType.AUTHORIZATION_CODE)); eq(TokenType.AUTHORIZATION_CODE));
verify(authorizationService, times(2)).save(any()); verify(authorizationService, times(2)).save(any());
@ -238,7 +238,7 @@ public class OidcTests {
OAuth2Authorization authorization) { OAuth2Authorization authorization) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()); parameters.set(OAuth2ParameterNames.CODE, authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue());
parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
return parameters; return parameters;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,17 +15,17 @@
*/ */
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -59,7 +59,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void saveWhenAuthorizationProvidedThenSaved() { public void saveWhenAuthorizationProvidedThenSaved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .token(AUTHORIZATION_CODE)
.build(); .build();
this.authorizationService.save(expectedAuthorization); this.authorizationService.save(expectedAuthorization);
@ -79,7 +79,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void removeWhenAuthorizationProvidedThenRemoved() { public void removeWhenAuthorizationProvidedThenRemoved() {
OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .token(AUTHORIZATION_CODE)
.build(); .build();
this.authorizationService.save(expectedAuthorization); this.authorizationService.save(expectedAuthorization);
@ -120,7 +120,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
public void findByTokenWhenAuthorizationCodeExistsThenFound() { public void findByTokenWhenAuthorizationCodeExistsThenFound() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build()) .token(AUTHORIZATION_CODE)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
@ -137,7 +137,8 @@ public class InMemoryOAuth2AuthorizationServiceTests {
"access-token", Instant.now().minusSeconds(60), Instant.now()); "access-token", Instant.now().minusSeconds(60), Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(accessToken).build()) .token(AUTHORIZATION_CODE)
.accessToken(accessToken)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
@ -153,7 +154,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now());
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().refreshToken(refreshToken).build()) .refreshToken(refreshToken)
.build(); .build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,16 +15,16 @@
*/ */
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import org.junit.Test; import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -62,15 +62,16 @@ public class OAuth2AuthorizationTests {
public void fromWhenAuthorizationProvidedThenCopied() { public void fromWhenAuthorizationProvidedThenCopied() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build()) .token(AUTHORIZATION_CODE)
.accessToken(ACCESS_TOKEN)
.build(); .build();
OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build(); OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
assertThat(authorizationResult.getTokens().getToken(OAuth2AuthorizationCode.class)) assertThat(authorizationResult.getToken(OAuth2AuthorizationCode.class))
.isEqualTo(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)); .isEqualTo(authorization.getToken(OAuth2AuthorizationCode.class));
assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes()); assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
} }
@ -103,13 +104,15 @@ public class OAuth2AuthorizationTests {
public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.principalName(PRINCIPAL_NAME) .principalName(PRINCIPAL_NAME)
.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).refreshToken(REFRESH_TOKEN).build()) .token(AUTHORIZATION_CODE)
.accessToken(ACCESS_TOKEN)
.refreshToken(REFRESH_TOKEN)
.build(); .build();
assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
assertThat(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isEqualTo(AUTHORIZATION_CODE); assertThat(authorization.getToken(OAuth2AuthorizationCode.class).getToken()).isEqualTo(AUTHORIZATION_CODE);
assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN); assertThat(authorization.getAccessToken().getToken()).isEqualTo(ACCESS_TOKEN);
assertThat(authorization.getTokens().getRefreshToken()).isEqualTo(REFRESH_TOKEN); assertThat(authorization.getRefreshToken().getToken()).isEqualTo(REFRESH_TOKEN);
} }
} }

View File

@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
/** /**
* @author Joe Grandja * @author Joe Grandja
@ -62,7 +61,9 @@ public class TestOAuth2Authorizations {
.build(); .build();
return OAuth2Authorization.withRegisteredClient(registeredClient) return OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName("principal") .principalName("principal")
.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build()) .token(authorizationCode)
.accessToken(accessToken)
.refreshToken(refreshToken)
.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, .attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL,
new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B")) new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"))

View File

@ -50,8 +50,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -172,8 +170,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode.isInvalidated()).isTrue();
} }
@Test @Test
@ -201,9 +200,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120)); AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120));
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.tokens(OAuth2Tokens.builder() .token(authorizationCode, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
.token(authorizationCode, OAuth2TokenMetadata.builder().invalidated().build())
.build())
.build(); .build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE))) when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
.thenReturn(authorization); .thenReturn(authorization);
@ -265,11 +262,11 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); assertThat(authorizationCode.isInvalidated()).isTrue();
} }
@Test @Test
@ -321,15 +318,15 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue(); assertThat(authorizationCode.isInvalidated()).isTrue();
OidcIdToken idToken = updatedAuthorization.getTokens().getToken(OidcIdToken.class); OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
assertThat(idToken).isNotNull(); assertThat(idToken).isNotNull();
assertThat(accessTokenAuthentication.getAdditionalParameters()) assertThat(accessTokenAuthentication.getAdditionalParameters())
.containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getTokenValue())); .containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
} }
@Test @Test
@ -362,12 +359,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt().plus(accessTokenTTL); Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt().plus(accessTokenTTL);
assertThat(accessTokenAuthentication.getAccessToken().getExpiresAt()).isBetween( assertThat(accessTokenAuthentication.getAccessToken().getExpiresAt()).isBetween(
expectedAccessTokenExpiresAt.minusSeconds(1), expectedAccessTokenExpiresAt.plusSeconds(1)); expectedAccessTokenExpiresAt.minusSeconds(1), expectedAccessTokenExpiresAt.plusSeconds(1));
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt().plus(refreshTokenTTL); Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt().plus(refreshTokenTTL);
assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isBetween( assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isBetween(
expectedRefreshTokenExpiresAt.minusSeconds(1), expectedRefreshTokenExpiresAt.plusSeconds(1)); expectedRefreshTokenExpiresAt.minusSeconds(1), expectedRefreshTokenExpiresAt.plusSeconds(1));

View File

@ -204,10 +204,10 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId()); assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId());
assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName()); assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName());
assertThat(authorization.getTokens().getAccessToken()).isNotNull(); assertThat(authorization.getAccessToken()).isNotNull();
assertThat(authorization.getTokens().getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes()); assertThat(authorization.getAccessToken().getToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getAccessToken().getToken());
} }
private static Jwt createJwt(Set<String> scope) { private static Jwt createJwt(Set<String> scope) {

View File

@ -47,8 +47,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
@ -120,13 +118,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -149,11 +147,11 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotEqualTo(authorization.getTokens().getAccessToken()); assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
// By default, refresh token is reused // By default, refresh token is reused
assertThat(updatedAuthorization.getTokens().getRefreshToken()).isEqualTo(authorization.getTokens().getRefreshToken()); assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
} }
@Test @Test
@ -163,13 +161,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
.build(); .build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -178,8 +176,8 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken()); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotEqualTo(authorization.getTokens().getRefreshToken()); assertThat(updatedAuthorization.getRefreshToken()).isNotEqualTo(authorization.getRefreshToken());
} }
@Test @Test
@ -187,7 +185,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
@ -196,7 +194,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
Set<String> requestedScopes = new HashSet<>(authorizedScopes); Set<String> requestedScopes = new HashSet<>(authorizedScopes);
requestedScopes.remove("email"); requestedScopes.remove("email");
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, requestedScopes); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -209,7 +207,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
@ -218,7 +216,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
Set<String> requestedScopes = new HashSet<>(authorizedScopes); Set<String> requestedScopes = new HashSet<>(authorizedScopes);
requestedScopes.add("unauthorized"); requestedScopes.add("unauthorized");
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, requestedScopes); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -276,14 +274,14 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
TestRegisteredClients.registeredClient2().build()); TestRegisteredClients.registeredClient2().build());
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -299,13 +297,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
.build(); .build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -320,16 +318,15 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
OAuth2RefreshToken expiredRefreshToken = new OAuth2RefreshToken2( OAuth2RefreshToken expiredRefreshToken = new OAuth2RefreshToken2(
"expired-refresh-token", Instant.now().minusSeconds(120), Instant.now().minusSeconds(60)); "expired-refresh-token", Instant.now().minusSeconds(120), Instant.now().minusSeconds(60));
OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens()).refreshToken(expiredRefreshToken).build(); authorization = OAuth2Authorization.from(authorization).token(expiredRefreshToken).build();
authorization = OAuth2Authorization.from(authorization).tokens(tokens).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -343,20 +340,17 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2( OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2(
"refresh-token", Instant.now().minusSeconds(120), Instant.now().plusSeconds(1000)); "refresh-token", Instant.now().minusSeconds(120), Instant.now().plusSeconds(1000));
OAuth2TokenMetadata metadata = OAuth2TokenMetadata.builder().invalidated().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
.tokens(OAuth2Tokens.builder() .token(refreshToken, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
.refreshToken(refreshToken, metadata)
.build())
.build(); .build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) eq(TokenType.REFRESH_TOKEN)))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -136,13 +137,13 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
registeredClient).build(); registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getRefreshToken().getToken().getTokenValue()),
isNull())) isNull()))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, TokenType.REFRESH_TOKEN.getValue()); authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, TokenType.REFRESH_TOKEN.getValue());
OAuth2TokenRevocationAuthenticationToken authenticationResult = OAuth2TokenRevocationAuthenticationToken authenticationResult =
(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -152,10 +153,10 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue(); assertThat(refreshToken.isInvalidated()).isTrue();
OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); assertThat(accessToken.isInvalidated()).isTrue();
} }
@Test @Test
@ -164,13 +165,13 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
registeredClient).build(); registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getAccessToken().getTokenValue()), eq(authorization.getAccessToken().getToken().getTokenValue()),
isNull())) isNull()))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken( OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
authorization.getTokens().getAccessToken().getTokenValue(), clientPrincipal, TokenType.ACCESS_TOKEN.getValue()); authorization.getAccessToken().getToken().getTokenValue(), clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
OAuth2TokenRevocationAuthenticationToken authenticationResult = OAuth2TokenRevocationAuthenticationToken authenticationResult =
(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication); (OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@ -180,9 +181,9 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
verify(this.authorizationService).save(authorizationCaptor.capture()); verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue(); assertThat(accessToken.isInvalidated()).isTrue();
OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken(); OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse(); assertThat(refreshToken.isInvalidated()).isFalse();
} }
} }

View File

@ -1,74 +0,0 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import org.junit.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link OAuth2TokenMetadata}.
*
* @author Joe Grandja
*/
public class OAuth2TokenMetadataTests {
@Test
public void metadataWhenNameNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() ->
OAuth2TokenMetadata.builder()
.metadata(null, "value"))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("name cannot be empty");
}
@Test
public void metadataWhenValueNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() ->
OAuth2TokenMetadata.builder()
.metadata("name", null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("value cannot be null");
}
@Test
public void getMetadataWhenNameNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2TokenMetadata.builder().build().getMetadata(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("name cannot be empty");
}
@Test
public void buildWhenDefaultThenDefaultsAreSet() {
OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder().build();
assertThat(tokenMetadata.getMetadata()).hasSize(1);
assertThat(tokenMetadata.isInvalidated()).isFalse();
}
@Test
public void buildWhenMetadataProvidedThenMetadataIsSet() {
OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder()
.invalidated()
.metadata("name1", "value1")
.metadata(metadata -> metadata.put("name2", "value2"))
.build();
assertThat(tokenMetadata.getMetadata()).hasSize(3);
assertThat(tokenMetadata.isInvalidated()).isTrue();
assertThat(tokenMetadata.<String>getMetadata("name1")).isEqualTo("value1");
assertThat(tokenMetadata.<String>getMetadata("name2")).isEqualTo("value2");
}
}

View File

@ -1,195 +0,0 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.token;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashSet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link OAuth2Tokens}.
*
* @author Joe Grandja
*/
public class OAuth2TokensTests {
private OAuth2AccessToken accessToken;
private OAuth2RefreshToken refreshToken;
private OidcIdToken idToken;
@Before
public void setUp() {
Instant issuedAt = Instant.now();
this.accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER,
"access-token",
issuedAt,
issuedAt.plus(Duration.ofMinutes(5)),
new HashSet<>(Arrays.asList("read", "write")));
this.refreshToken = new OAuth2RefreshToken(
"refresh-token",
issuedAt);
this.idToken = OidcIdToken.withTokenValue("id-token")
.issuer("https://provider.com")
.subject("subject")
.issuedAt(issuedAt)
.expiresAt(issuedAt.plus(Duration.ofMinutes(30)))
.build();
}
@Test
public void accessTokenWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.builder().accessToken(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("token cannot be null");
}
@Test
public void refreshTokenWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.builder().refreshToken(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("token cannot be null");
}
@Test
public void tokenWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.builder().token(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("token cannot be null");
}
@Test
public void getTokenWhenTokenTypeNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((Class<OAuth2AccessToken>) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("tokenType cannot be null");
}
@Test
public void getTokenWhenTokenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((String) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("token cannot be empty");
}
@Test
public void getTokenMetadataWhenTokenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.builder().build().getTokenMetadata(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("token cannot be null");
}
@Test
public void fromWhenTokensNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2Tokens.from(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("tokens cannot be null");
}
@Test
public void fromWhenTokensProvidedThenCopied() {
OAuth2Tokens tokens = OAuth2Tokens.builder()
.accessToken(this.accessToken)
.refreshToken(this.refreshToken)
.token(this.idToken)
.build();
OAuth2Tokens tokensResult = OAuth2Tokens.from(tokens).build();
assertThat(tokensResult.getAccessToken()).isEqualTo(tokens.getAccessToken());
assertThat(tokensResult.getTokenMetadata(tokensResult.getAccessToken()))
.isEqualTo(tokens.getTokenMetadata(tokens.getAccessToken()));
assertThat(tokensResult.getRefreshToken()).isEqualTo(tokens.getRefreshToken());
assertThat(tokensResult.getTokenMetadata(tokensResult.getRefreshToken()))
.isEqualTo(tokens.getTokenMetadata(tokens.getRefreshToken()));
assertThat(tokensResult.getToken(OidcIdToken.class)).isEqualTo(tokens.getToken(OidcIdToken.class));
assertThat(tokensResult.getTokenMetadata(tokensResult.getToken(OidcIdToken.class)))
.isEqualTo(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)));
}
@Test
public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() {
OAuth2Tokens tokens = OAuth2Tokens.builder()
.accessToken(this.accessToken)
.refreshToken(this.refreshToken)
.token(this.idToken)
.build();
assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
assertThat(tokenMetadata.isInvalidated()).isFalse();
assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken);
tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken());
assertThat(tokenMetadata.isInvalidated()).isFalse();
assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken);
tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class));
assertThat(tokenMetadata.isInvalidated()).isFalse();
}
@Test
public void buildWhenTokenMetadataProvidedThenTokenMetadataIsSet() {
OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build();
OAuth2Tokens tokens = OAuth2Tokens.builder()
.accessToken(this.accessToken, expectedTokenMetadata)
.refreshToken(this.refreshToken, expectedTokenMetadata)
.token(this.idToken, expectedTokenMetadata)
.build();
assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken);
tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken());
assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken);
tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class));
assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
}
@Test
public void getTokenMetadataWhenTokenNotFoundThenNull() {
OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build();
OAuth2Tokens tokens = OAuth2Tokens.builder()
.accessToken(this.accessToken, expectedTokenMetadata)
.build();
assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
OAuth2AccessToken otherAccessToken = new OAuth2AccessToken(
this.accessToken.getTokenType(),
"other-access-token",
this.accessToken.getIssuedAt(),
this.accessToken.getExpiresAt(),
this.accessToken.getScopes());
assertThat(tokens.getTokenMetadata(otherAccessToken)).isNull();
}
}

View File

@ -470,7 +470,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication); .isEqualTo(this.authentication);
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode).isNotNull(); assertThat(authorizationCode).isNotNull();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@ -519,7 +519,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL)) assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
.isEqualTo(this.authentication); .isEqualTo(this.authentication);
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode).isNotNull(); assertThat(authorizationCode).isNotNull();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST); OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@ -795,7 +795,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
assertThat(updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isNotNull(); assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull();
assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)) assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))
.isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST)); .isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST));