diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 9c84a74..4e77698 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -16,7 +16,6 @@ package org.springframework.security.oauth2.server.authorization; import org.springframework.lang.Nullable; -import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.util.Assert; import java.io.Serializable; @@ -66,8 +65,8 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza } else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) { return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)); } else if (TokenType.ACCESS_TOKEN.equals(tokenType)) { - return authorization.getAccessToken() != null && - authorization.getAccessToken().getTokenValue().equals(token); + return authorization.getTokens().getAccessToken() != null && + authorization.getTokens().getAccessToken().getTokenValue().equals(token); } return false; } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index 78d1997..a39e90a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -15,9 +15,9 @@ */ package org.springframework.security.oauth2.server.authorization; -import org.springframework.security.oauth2.server.authorization.Version; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; import java.io.Serializable; @@ -36,13 +36,17 @@ import java.util.function.Consumer; * @author Krisztian Toth * @since 0.0.1 * @see RegisteredClient - * @see OAuth2AccessToken + * @see OAuth2Tokens */ public class OAuth2Authorization implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String registeredClientId; private String principalName; + private OAuth2Tokens tokens; + + @Deprecated private OAuth2AccessToken accessToken; + private Map attributes; protected OAuth2Authorization() { @@ -66,13 +70,23 @@ public class OAuth2Authorization implements Serializable { return this.principalName; } + /** + * Returns the {@link OAuth2Tokens}. + * + * @return the {@link OAuth2Tokens} + */ + public OAuth2Tokens getTokens() { + return this.tokens; + } + /** * Returns the {@link OAuth2AccessToken access token} credential. * * @return the {@link OAuth2AccessToken} */ + @Deprecated public OAuth2AccessToken getAccessToken() { - return this.accessToken; + return getTokens().getAccessToken(); } /** @@ -108,13 +122,13 @@ public class OAuth2Authorization implements Serializable { OAuth2Authorization that = (OAuth2Authorization) obj; return Objects.equals(this.registeredClientId, that.registeredClientId) && Objects.equals(this.principalName, that.principalName) && - Objects.equals(this.accessToken, that.accessToken) && + Objects.equals(this.tokens, that.tokens) && Objects.equals(this.attributes, that.attributes); } @Override public int hashCode() { - return Objects.hash(this.registeredClientId, this.principalName, this.accessToken, this.attributes); + return Objects.hash(this.registeredClientId, this.principalName, this.tokens, this.attributes); } /** @@ -138,7 +152,7 @@ public class OAuth2Authorization implements Serializable { Assert.notNull(authorization, "authorization cannot be null"); return new Builder(authorization.getRegisteredClientId()) .principalName(authorization.getPrincipalName()) - .accessToken(authorization.getAccessToken()) + .tokens(authorization.getTokens()) .attributes(attrs -> attrs.putAll(authorization.getAttributes())); } @@ -149,7 +163,11 @@ public class OAuth2Authorization implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String registeredClientId; private String principalName; + private OAuth2Tokens tokens; + + @Deprecated private OAuth2AccessToken accessToken; + private Map attributes = new HashMap<>(); protected Builder(String registeredClientId) { @@ -167,12 +185,24 @@ public class OAuth2Authorization implements Serializable { return this; } + /** + * Sets the {@link OAuth2Tokens}. + * + * @param tokens the {@link OAuth2Tokens} + * @return the {@link Builder} + */ + public Builder tokens(OAuth2Tokens tokens) { + this.tokens = tokens; + return this; + } + /** * Sets the {@link OAuth2AccessToken access token} credential. * * @param accessToken the {@link OAuth2AccessToken} * @return the {@link Builder} */ + @Deprecated public Builder accessToken(OAuth2AccessToken accessToken) { this.accessToken = accessToken; return this; @@ -215,7 +245,14 @@ public class OAuth2Authorization implements Serializable { OAuth2Authorization authorization = new OAuth2Authorization(); authorization.registeredClientId = this.registeredClientId; authorization.principalName = this.principalName; - authorization.accessToken = this.accessToken; + if (this.tokens == null) { + OAuth2Tokens.Builder builder = OAuth2Tokens.builder(); + if (this.accessToken != null) { + builder.accessToken(this.accessToken); + } + this.tokens = builder.build(); + } + authorization.tokens = this.tokens; authorization.attributes = Collections.unmodifiableMap(this.attributes); return authorization; } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index c7a2053..e5153e2 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -35,6 +35,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -143,7 +144,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica authorization = OAuth2Authorization.from(authorization) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) - .accessToken(accessToken) + .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index 622475b..f07cb46 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -32,6 +32,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -129,7 +130,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) .attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt) .principalName(clientPrincipal.getName()) - .accessToken(accessToken) + .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java new file mode 100644 index 0000000..e98ac7d --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java @@ -0,0 +1,169 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import org.springframework.security.oauth2.server.authorization.Version; +import org.springframework.util.Assert; + +import java.io.Serializable; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; + +/** + * Holds metadata associated to an OAuth 2.0 Token. + * + * @author Joe Grandja + * @since 0.0.3 + * @see OAuth2Tokens + */ +public class OAuth2TokenMetadata implements Serializable { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + protected static final String TOKEN_METADATA_BASE = "token.metadata."; + + /** + * The name of the metadata that indicates if the token has been invalidated. + */ + public static final String INVALIDATED = TOKEN_METADATA_BASE.concat("invalidated"); + + private final Map metadata; + + protected OAuth2TokenMetadata(Map metadata) { + this.metadata = Collections.unmodifiableMap(new HashMap<>(metadata)); + } + + /** + * Returns {@code true} if the token has been invalidated (e.g. revoked). + * The default is {@code false}. + * + * @return {@code true} if the token has been invalidated, {@code false} otherwise + */ + public boolean isInvalidated() { + return getMetadata(INVALIDATED); + } + + /** + * Returns the value of the metadata associated to the token. + * + * @param name the name of the metadata + * @param the type of the metadata + * @return the value of the metadata, or {@code null} if not available + */ + @SuppressWarnings("unchecked") + public T getMetadata(String name) { + Assert.hasText(name, "name cannot be empty"); + return (T) this.metadata.get(name); + } + + /** + * Returns the metadata associated to the token. + * + * @return a {@code Map} of the metadata + */ + public Map getMetadata() { + return this.metadata; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OAuth2TokenMetadata that = (OAuth2TokenMetadata) obj; + return Objects.equals(this.metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(this.metadata); + } + + /** + * Returns a new {@link Builder}. + * + * @return the {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link OAuth2TokenMetadata}. + */ + public static class Builder implements Serializable { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final Map metadata = defaultMetadata(); + + protected Builder() { + } + + /** + * Set the token as invalidated (e.g. revoked). + * + * @return the {@link Builder} + */ + public Builder invalidated() { + metadata(INVALIDATED, true); + return this; + } + + /** + * Adds a metadata associated to the token. + * + * @param name the name of the metadata + * @param value the value of the metadata + * @return the {@link Builder} + */ + public Builder metadata(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.metadata.put(name, value); + return this; + } + + /** + * A {@code Consumer} of the metadata {@code Map} + * allowing the ability to add, replace, or remove. + * + * @param metadataConsumer a {@link Consumer} of the metadata {@code Map} + * @return the {@link Builder} + */ + public Builder metadata(Consumer> metadataConsumer) { + metadataConsumer.accept(this.metadata); + return this; + } + + /** + * Builds a new {@link OAuth2TokenMetadata}. + * + * @return the {@link OAuth2TokenMetadata} + */ + public OAuth2TokenMetadata build() { + return new OAuth2TokenMetadata(this.metadata); + } + + protected static Map defaultMetadata() { + Map metadata = new HashMap<>(); + metadata.put(INVALIDATED, false); + return metadata; + } + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java new file mode 100644 index 0000000..3da46e2 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java @@ -0,0 +1,279 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.Version; +import org.springframework.util.Assert; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * A container for OAuth 2.0 Tokens. + * + * @author Joe Grandja + * @since 0.0.3 + * @see OAuth2Authorization + * @see OAuth2TokenMetadata + * @see AbstractOAuth2Token + * @see OAuth2AccessToken + * @see OAuth2RefreshToken + */ +public class OAuth2Tokens implements Serializable { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final Map, OAuth2TokenHolder> tokens; + + protected OAuth2Tokens(Map, OAuth2TokenHolder> tokens) { + this.tokens = new HashMap<>(tokens); + } + + /** + * Returns the {@link OAuth2AccessToken access token}. + * + * @return the {@link OAuth2AccessToken}, or {@code null} if not available + */ + @Nullable + public OAuth2AccessToken getAccessToken() { + return getToken(OAuth2AccessToken.class); + } + + /** + * Returns the {@link OAuth2RefreshToken refresh token}. + * + * @return the {@link OAuth2RefreshToken}, or {@code null} if not available + */ + @Nullable + public OAuth2RefreshToken getRefreshToken() { + return getToken(OAuth2RefreshToken.class); + } + + /** + * Returns the token specified by {@code tokenType}. + * + * @param tokenType the token type + * @param the type of the token + * @return the token, or {@code null} if not available + */ + @Nullable + @SuppressWarnings("unchecked") + public T getToken(Class tokenType) { + Assert.notNull(tokenType, "tokenType cannot be null"); + OAuth2TokenHolder tokenHolder = this.tokens.get(tokenType); + return tokenHolder != null ? (T) tokenHolder.getToken() : null; + } + + /** + * Returns the token metadata associated to the provided {@code token}. + * + * @param token the token + * @param the type of the token + * @return the token metadata, or {@code null} if not available + */ + @Nullable + public OAuth2TokenMetadata getTokenMetadata(T token) { + Assert.notNull(token, "token cannot be null"); + OAuth2TokenHolder tokenHolder = this.tokens.get(token.getClass()); + return (tokenHolder != null && tokenHolder.getToken().equals(token)) ? + tokenHolder.getTokenMetadata() : null; + } + + /** + * Invalidates all tokens. + */ + public void invalidate() { + this.tokens.values().forEach(tokenHolder -> invalidate(tokenHolder.getToken())); + } + + /** + * Invalidates the token matching the provided {@code token}. + * + * @param token the token + * @param the type of the token + */ + public void invalidate(T token) { + Assert.notNull(token, "token cannot be null"); + this.tokens.computeIfPresent(token.getClass(), + (tokenType, tokenHolder) -> + new OAuth2TokenHolder( + tokenHolder.getToken(), + OAuth2TokenMetadata.builder().invalidated().build()) + ); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OAuth2Tokens that = (OAuth2Tokens) obj; + return Objects.equals(this.tokens, that.tokens); + } + + @Override + public int hashCode() { + return Objects.hash(this.tokens); + } + + /** + * Returns a new {@link Builder}. + * + * @return the {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link OAuth2Tokens}. + */ + public static class Builder implements Serializable { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final Map, OAuth2TokenHolder> tokens = new HashMap<>(); + + protected Builder() { + } + + /** + * Sets the {@link OAuth2AccessToken access token}. + * + * @param accessToken the {@link OAuth2AccessToken} + * @return the {@link Builder} + */ + public Builder accessToken(OAuth2AccessToken accessToken) { + return addToken(accessToken, null); + } + + /** + * Sets the {@link OAuth2AccessToken access token} and associated {@link OAuth2TokenMetadata token metadata}. + * + * @param accessToken the {@link OAuth2AccessToken} + * @param tokenMetadata the {@link OAuth2TokenMetadata} + * @return the {@link Builder} + */ + public Builder accessToken(OAuth2AccessToken accessToken, OAuth2TokenMetadata tokenMetadata) { + return addToken(accessToken, tokenMetadata); + } + + /** + * Sets the {@link OAuth2RefreshToken refresh token}. + * + * @param refreshToken the {@link OAuth2RefreshToken} + * @return the {@link Builder} + */ + public Builder refreshToken(OAuth2RefreshToken refreshToken) { + return addToken(refreshToken, null); + } + + /** + * Sets the {@link OAuth2RefreshToken refresh token} and associated {@link OAuth2TokenMetadata token metadata}. + * + * @param refreshToken the {@link OAuth2RefreshToken} + * @param tokenMetadata the {@link OAuth2TokenMetadata} + * @return the {@link Builder} + */ + public Builder refreshToken(OAuth2RefreshToken refreshToken, OAuth2TokenMetadata tokenMetadata) { + return addToken(refreshToken, tokenMetadata); + } + + /** + * Sets the token. + * + * @param token the token + * @param the type of the token + * @return the {@link Builder} + */ + public Builder token(T token) { + return addToken(token, null); + } + + /** + * Sets the token and associated {@link OAuth2TokenMetadata token metadata}. + * + * @param token the token + * @param tokenMetadata the {@link OAuth2TokenMetadata} + * @param the type of the token + * @return the {@link Builder} + */ + public Builder token(T token, OAuth2TokenMetadata tokenMetadata) { + return addToken(token, tokenMetadata); + } + + protected Builder addToken(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) { + Assert.notNull(token, "token cannot be null"); + if (tokenMetadata == null) { + tokenMetadata = OAuth2TokenMetadata.builder().build(); + } + this.tokens.put(token.getClass(), new OAuth2TokenHolder(token, tokenMetadata)); + return this; + } + + /** + * Builds a new {@link OAuth2Tokens}. + * + * @return the {@link OAuth2Tokens} + */ + public OAuth2Tokens build() { + return new OAuth2Tokens(this.tokens); + } + } + + protected static class OAuth2TokenHolder implements Serializable { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final AbstractOAuth2Token token; + private final OAuth2TokenMetadata tokenMetadata; + + protected OAuth2TokenHolder(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) { + this.token = token; + this.tokenMetadata = tokenMetadata; + } + + protected AbstractOAuth2Token getToken() { + return this.token; + } + + protected OAuth2TokenMetadata getTokenMetadata() { + return this.tokenMetadata; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OAuth2TokenHolder that = (OAuth2TokenHolder) obj; + return Objects.equals(this.token, that.token) && + Objects.equals(this.tokenMetadata, that.tokenMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(this.token, this.tokenMetadata); + } + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index 9a05ba8..603c7f8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; @@ -129,7 +130,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) - .accessToken(accessToken) + .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .build(); this.authorizationService.save(authorization); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index 6d9a041..2e5a78f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -19,6 +19,7 @@ import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; @@ -57,14 +58,14 @@ public class OAuth2AuthorizationTests { public void fromWhenAuthorizationProvidedThenCopied() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .accessToken(ACCESS_TOKEN) + .tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build()) .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .build(); OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build(); assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); - assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken()); + assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes()); } @@ -97,13 +98,13 @@ public class OAuth2AuthorizationTests { public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) .principalName(PRINCIPAL_NAME) - .accessToken(ACCESS_TOKEN) + .tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build()) .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .build(); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); - assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN); + assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN); assertThat(authorization.getAttributes()).containsExactly( entry(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 1b623cb..48a85b8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -19,6 +19,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens; import java.time.Instant; import java.util.Collections; @@ -52,7 +53,7 @@ public class TestOAuth2Authorizations { .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) .principalName("principal") - .accessToken(accessToken) + .tokens(OAuth2Tokens.builder().accessToken(accessToken).build()) .attribute(OAuth2AuthorizationAttributeNames.CODE, "code") .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest) .attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes()); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 28f1ebe..22be631 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -203,8 +203,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); - assertThat(updatedAuthorization.getAccessToken()).isNotNull(); - assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()); + assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken()); } private static Jwt createJwt() { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index 3957b63..da2f58d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -156,10 +156,10 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId()); assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName()); - assertThat(authorization.getAccessToken()).isNotNull(); - assertThat(authorization.getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes()); + assertThat(authorization.getTokens().getAccessToken()).isNotNull(); + assertThat(authorization.getTokens().getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes()); assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); - assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getAccessToken()); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken()); } private static Jwt createJwt() { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java new file mode 100644 index 0000000..5d891e9 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2TokenMetadata}. + * + * @author Joe Grandja + */ +public class OAuth2TokenMetadataTests { + + @Test + public void metadataWhenNameNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2TokenMetadata.builder() + .metadata(null, "value")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be empty"); + } + + @Test + public void metadataWhenValueNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2TokenMetadata.builder() + .metadata("name", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + + @Test + public void getMetadataWhenNameNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2TokenMetadata.builder().build().getMetadata(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be empty"); + } + + @Test + public void buildWhenDefaultThenDefaultsAreSet() { + OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder().build(); + assertThat(tokenMetadata.getMetadata()).hasSize(1); + assertThat(tokenMetadata.isInvalidated()).isFalse(); + } + + @Test + public void buildWhenMetadataProvidedThenMetadataIsSet() { + OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder() + .invalidated() + .metadata("name1", "value1") + .metadata(metadata -> metadata.put("name2", "value2")) + .build(); + assertThat(tokenMetadata.getMetadata()).hasSize(3); + assertThat(tokenMetadata.isInvalidated()).isTrue(); + assertThat(tokenMetadata.getMetadata("name1")).isEqualTo("value1"); + assertThat(tokenMetadata.getMetadata("name2")).isEqualTo("value2"); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java new file mode 100644 index 0000000..5a94065 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java @@ -0,0 +1,187 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.token; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2Tokens}. + * + * @author Joe Grandja + */ +public class OAuth2TokensTests { + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private OidcIdToken idToken; + + @Before + public void setUp() { + Instant issuedAt = Instant.now(); + this.accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, + "access-token", + issuedAt, + issuedAt.plus(Duration.ofMinutes(5)), + new HashSet<>(Arrays.asList("read", "write"))); + this.refreshToken = new OAuth2RefreshToken( + "refresh-token", + issuedAt); + this.idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject("subject") + .issuedAt(issuedAt) + .expiresAt(issuedAt.plus(Duration.ofMinutes(30))) + .build(); + } + + @Test + public void accessTokenWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.builder().accessToken(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be null"); + } + + @Test + public void refreshTokenWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.builder().refreshToken(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be null"); + } + + @Test + public void tokenWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.builder().token(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be null"); + } + + @Test + public void getTokenWhenTokenTypeNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokenType cannot be null"); + } + + @Test + public void getTokenMetadataWhenTokenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2Tokens.builder().build().getTokenMetadata(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be null"); + } + + @Test + public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() { + OAuth2Tokens tokens = OAuth2Tokens.builder() + .accessToken(this.accessToken) + .refreshToken(this.refreshToken) + .token(this.idToken) + .build(); + + assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken); + OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken()); + assertThat(tokenMetadata.isInvalidated()).isFalse(); + + assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken); + tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken()); + assertThat(tokenMetadata.isInvalidated()).isFalse(); + + assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken); + tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)); + assertThat(tokenMetadata.isInvalidated()).isFalse(); + } + + @Test + public void buildWhenTokenMetadataProvidedThenTokenMetadataIsSet() { + OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build(); + OAuth2Tokens tokens = OAuth2Tokens.builder() + .accessToken(this.accessToken, expectedTokenMetadata) + .refreshToken(this.refreshToken, expectedTokenMetadata) + .token(this.idToken, expectedTokenMetadata) + .build(); + + assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken); + OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken()); + assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata); + + assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken); + tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken()); + assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata); + + assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken); + tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)); + assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata); + } + + @Test + public void getTokenMetadataWhenTokenNotFoundThenNull() { + OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build(); + OAuth2Tokens tokens = OAuth2Tokens.builder() + .accessToken(this.accessToken, expectedTokenMetadata) + .build(); + + assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken); + OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken()); + assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata); + + OAuth2AccessToken otherAccessToken = new OAuth2AccessToken( + this.accessToken.getTokenType(), + "other-access-token", + this.accessToken.getIssuedAt(), + this.accessToken.getExpiresAt(), + this.accessToken.getScopes()); + assertThat(tokens.getTokenMetadata(otherAccessToken)).isNull(); + } + + @Test + public void invalidateWhenAllTokensThenAllInvalidated() { + OAuth2Tokens tokens = OAuth2Tokens.builder() + .accessToken(this.accessToken) + .refreshToken(this.refreshToken) + .token(this.idToken) + .build(); + tokens.invalidate(); + + assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue(); + assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isTrue(); + assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isTrue(); + } + + @Test + public void invalidateWhenTokenProvidedThenInvalidated() { + OAuth2Tokens tokens = OAuth2Tokens.builder() + .accessToken(this.accessToken) + .refreshToken(this.refreshToken) + .token(this.idToken) + .build(); + tokens.invalidate(this.accessToken); + + assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue(); + assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isFalse(); + assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isFalse(); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 1808be2..1e64dda 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -755,7 +755,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); - assertThat(updatedAuthorization.getAccessToken()).isNotNull(); + assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)).isNotNull(); assertThat(updatedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))