diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index ddc8a02..87e8033 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -16,48 +16,37 @@ package org.springframework.security.oauth2.server.authorization; import org.springframework.lang.Nullable; +import org.springframework.security.core.SpringSecurityCoreVersion2; import org.springframework.util.Assert; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; +import java.io.Serializable; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; /** * An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory. * * @author Krisztian Toth + * @author Joe Grandja * @since 0.0.1 * @see OAuth2AuthorizationService */ public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService { - private final List authorizations; - - /** - * Constructs an {@code InMemoryOAuth2AuthorizationService}. - */ - public InMemoryOAuth2AuthorizationService() { - this.authorizations = new CopyOnWriteArrayList<>(); - } - - /** - * Constructs an {@code InMemoryOAuth2AuthorizationService} using the provided parameters. - * - * @param authorizations the initial {@code List} of {@link OAuth2Authorization}(s) - */ - public InMemoryOAuth2AuthorizationService(List authorizations) { - Assert.notEmpty(authorizations, "authorizations cannot be empty"); - this.authorizations = new CopyOnWriteArrayList<>(authorizations); - } + private final Map authorizations = new ConcurrentHashMap<>(); @Override public void save(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); - this.authorizations.add(authorization); + OAuth2AuthorizationId authorizationId = new OAuth2AuthorizationId( + authorization.getRegisteredClientId(), authorization.getPrincipalName()); + this.authorizations.put(authorizationId, authorization); } @Override public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) { Assert.hasText(token, "token cannot be empty"); - return this.authorizations.stream() + return this.authorizations.values().stream() .filter(authorization -> hasToken(authorization, token, tokenType)) .findFirst() .orElse(null); @@ -72,4 +61,33 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza } return false; } + + private static class OAuth2AuthorizationId implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID; + private final String registeredClientId; + private final String principalName; + + private OAuth2AuthorizationId(String registeredClientId, String principalName) { + this.registeredClientId = registeredClientId; + this.principalName = principalName; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OAuth2AuthorizationId that = (OAuth2AuthorizationId) obj; + return Objects.equals(this.registeredClientId, that.registeredClientId) && + Objects.equals(this.principalName, that.principalName); + } + + @Override + public int hashCode() { + return Objects.hash(this.registeredClientId, this.principalName); + } + } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index 3b19fed..6dabb77 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -22,7 +22,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import java.time.Instant; -import java.util.Collections; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -43,13 +42,6 @@ public class InMemoryOAuth2AuthorizationServiceTests { this.authorizationService = new InMemoryOAuth2AuthorizationService(); } - @Test - public void constructorWhenAuthorizationListNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new InMemoryOAuth2AuthorizationService(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizations cannot be empty"); - } - @Test public void saveWhenAuthorizationNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizationService.save(null)) @@ -83,7 +75,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { .principalName(PRINCIPAL_NAME) .attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE) .build(); - this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); + this.authorizationService.save(authorization); OAuth2Authorization result = this.authorizationService.findByToken( AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);