diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 00cc180..67d87b5 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -15,9 +15,7 @@ */ package org.springframework.security.oauth2.server.authorization; -import java.io.Serializable; import java.util.Map; -import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import org.springframework.lang.Nullable; @@ -40,22 +38,27 @@ import org.springframework.util.Assert; * @see OAuth2AuthorizationService */ public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService { - private final Map authorizations = new ConcurrentHashMap<>(); + private final Map authorizations = new ConcurrentHashMap<>(); @Override public void save(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); - OAuth2AuthorizationId authorizationId = new OAuth2AuthorizationId( - authorization.getRegisteredClientId(), authorization.getPrincipalName()); - this.authorizations.put(authorizationId, authorization); + Assert.isTrue(!this.authorizations.containsKey(authorization.getId()), + "The authorization must be unique. Found duplicate identifier: " + authorization.getId()); + this.authorizations.put(authorization.getId(), authorization); } @Override public void remove(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); - OAuth2AuthorizationId authorizationId = new OAuth2AuthorizationId( - authorization.getRegisteredClientId(), authorization.getPrincipalName()); - this.authorizations.remove(authorizationId, authorization); + this.authorizations.remove(authorization.getId(), authorization); + } + + @Nullable + @Override + public OAuth2Authorization findById(String id) { + Assert.hasText(id, "id cannot be empty"); + return this.authorizations.get(id); } @Nullable @@ -107,33 +110,4 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza authorization.getToken(OAuth2RefreshToken.class); return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token); } - - private static class OAuth2AuthorizationId implements Serializable { - private static final long serialVersionUID = Version.SERIAL_VERSION_UID; - private final String registeredClientId; - private final String principalName; - - private OAuth2AuthorizationId(String registeredClientId, String principalName) { - this.registeredClientId = registeredClientId; - this.principalName = principalName; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - OAuth2AuthorizationId that = (OAuth2AuthorizationId) obj; - return Objects.equals(this.registeredClientId, that.registeredClientId) && - Objects.equals(this.principalName, that.principalName); - } - - @Override - public int hashCode() { - return Objects.hash(this.registeredClientId, this.principalName); - } - } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index 324f4a3..d157ce5 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.UUID; import java.util.function.Consumer; import org.springframework.lang.Nullable; @@ -30,6 +31,7 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * A representation of an OAuth 2.0 Authorization, which holds state related to the authorization granted @@ -55,6 +57,7 @@ public class OAuth2Authorization implements Serializable { public static final String AUTHORIZED_SCOPE_ATTRIBUTE_NAME = OAuth2Authorization.class.getName().concat(".AUTHORIZED_SCOPE"); + private String id; private String registeredClientId; private String principalName; private AuthorizationGrantType authorizationGrantType; @@ -64,6 +67,15 @@ public class OAuth2Authorization implements Serializable { protected OAuth2Authorization() { } + /** + * Returns the identifier for the authorization. + * + * @return the identifier for the authorization + */ + public String getId() { + return this.id; + } + /** * Returns the identifier for the {@link RegisteredClient#getId() registered client}. * @@ -175,7 +187,8 @@ public class OAuth2Authorization implements Serializable { return false; } OAuth2Authorization that = (OAuth2Authorization) obj; - return Objects.equals(this.registeredClientId, that.registeredClientId) && + return Objects.equals(this.id, that.id) && + Objects.equals(this.registeredClientId, that.registeredClientId) && Objects.equals(this.principalName, that.principalName) && Objects.equals(this.authorizationGrantType, that.authorizationGrantType) && Objects.equals(this.tokens, that.tokens) && @@ -184,7 +197,7 @@ public class OAuth2Authorization implements Serializable { @Override public int hashCode() { - return Objects.hash(this.registeredClientId, this.principalName, + return Objects.hash(this.id, this.registeredClientId, this.principalName, this.authorizationGrantType, this.tokens, this.attributes); } @@ -208,6 +221,7 @@ public class OAuth2Authorization implements Serializable { public static Builder from(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); return new Builder(authorization.getRegisteredClientId()) + .id(authorization.getId()) .principalName(authorization.getPrincipalName()) .authorizationGrantType(authorization.getAuthorizationGrantType()) .tokens(authorization.tokens) @@ -328,6 +342,7 @@ public class OAuth2Authorization implements Serializable { */ public static class Builder implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private String id; private final String registeredClientId; private String principalName; private AuthorizationGrantType authorizationGrantType; @@ -338,6 +353,17 @@ public class OAuth2Authorization implements Serializable { this.registeredClientId = registeredClientId; } + /** + * Sets the identifier for the authorization. + * + * @param id the identifier for the authorization + * @return the {@link Builder} + */ + public Builder id(String id) { + this.id = id; + return this; + } + /** * Sets the {@code Principal} name of the resource owner (or client). * @@ -458,6 +484,10 @@ public class OAuth2Authorization implements Serializable { Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null"); OAuth2Authorization authorization = new OAuth2Authorization(); + if (!StringUtils.hasText(this.id)) { + this.id = UUID.randomUUID().toString(); + } + authorization.id = this.id; authorization.registeredClientId = this.registeredClientId; authorization.principalName = this.principalName; authorization.authorizationGrantType = this.authorizationGrantType; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java index 969ef60..a22a696 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java @@ -43,6 +43,16 @@ public interface OAuth2AuthorizationService { */ void remove(OAuth2Authorization authorization); + /** + * Returns the {@link OAuth2Authorization} identified by the provided {@code id}, + * or {@code null} if not found. + * + * @param id the authorization identifier + * @return the {@link OAuth2Authorization} if found, otherwise {@code null} + */ + @Nullable + OAuth2Authorization findById(String id); + /** * Returns the {@link OAuth2Authorization} containing the provided {@code token}, * or {@code null} if not found. diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index d88efe2..0d4593a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -40,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class InMemoryOAuth2AuthorizationServiceTests { + private static final String ID = "id"; private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE; @@ -64,6 +65,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { @Test public void saveWhenAuthorizationProvidedThenSaved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) @@ -75,6 +77,25 @@ public class InMemoryOAuth2AuthorizationServiceTests { assertThat(authorization).isEqualTo(expectedAuthorization); } + @Test + public void saveWhenAuthorizationNotUniqueThenThrowIllegalArgumentException() { + OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .build(); + this.authorizationService.save(expectedAuthorization); + + OAuth2Authorization authorization = this.authorizationService.findById( + expectedAuthorization.getId()); + assertThat(authorization).isEqualTo(expectedAuthorization); + + assertThatThrownBy(() -> this.authorizationService.save(authorization)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The authorization must be unique. Found duplicate identifier: " + ID); + } + @Test public void removeWhenAuthorizationNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizationService.remove(null)) @@ -85,6 +106,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { @Test public void removeWhenAuthorizationProvidedThenRemoved() { OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) @@ -101,6 +123,13 @@ public class InMemoryOAuth2AuthorizationServiceTests { assertThat(authorization).isNull(); } + @Test + public void findByIdWhenIdNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizationService.findById(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("id cannot be empty"); + } + @Test public void findByTokenWhenTokenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizationService.findByToken(null, AUTHORIZATION_CODE_TOKEN_TYPE)) @@ -112,6 +141,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void findByTokenWhenStateExistsThenFound() { String state = "state"; OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .attribute(OAuth2ParameterNames.STATE, state) @@ -128,6 +158,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { @Test public void findByTokenWhenAuthorizationCodeExistsThenFound() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) @@ -146,6 +177,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) @@ -164,6 +196,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { public void findByTokenWhenRefreshTokenExistsThenFound() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .refreshToken(refreshToken) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java index 0c819e3..67e98e7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -37,6 +37,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2AuthorizationTests { + private static final String ID = "id"; private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE; @@ -63,6 +64,7 @@ public class OAuth2AuthorizationTests { @Test public void fromWhenAuthorizationProvidedThenCopied() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) @@ -70,6 +72,7 @@ public class OAuth2AuthorizationTests { .build(); OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build(); + assertThat(authorizationResult.getId()).isEqualTo(authorization.getId()); assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId()); assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName()); assertThat(authorizationResult.getAuthorizationGrantType()).isEqualTo(authorization.getAuthorizationGrantType()); @@ -114,6 +117,7 @@ public class OAuth2AuthorizationTests { @Test public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) .principalName(PRINCIPAL_NAME) .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) .token(AUTHORIZATION_CODE) @@ -121,6 +125,7 @@ public class OAuth2AuthorizationTests { .refreshToken(REFRESH_TOKEN) .build(); + assertThat(authorization.getId()).isEqualTo(ID); assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId()); assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AUTHORIZATION_GRANT_TYPE); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 0bb13be..33546cd 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -63,6 +63,7 @@ public class TestOAuth2Authorizations { .state("state") .build(); return OAuth2Authorization.withRegisteredClient(registeredClient) + .id("id") .principalName("principal") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .token(authorizationCode)