diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java new file mode 100644 index 0000000..b9f161b --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -0,0 +1,83 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.springframework.util.Assert; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +/** + * In-memory implementation of {@link OAuth2AuthorizationService}. + * + * @author Krisztian Toth + */ +public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService { + private final List authorizations; + + /** + * Creates an {@link InMemoryOAuth2AuthorizationService}. + */ + public InMemoryOAuth2AuthorizationService() { + this(Collections.emptyList()); + } + + /** + * Creates an {@link InMemoryOAuth2AuthorizationService} with the provided {@link List}<{@link OAuth2Authorization}> + * as the in-memory store. + * + * @param authorizations a {@link List}<{@link OAuth2Authorization}> object to use as the store + */ + public InMemoryOAuth2AuthorizationService(List authorizations) { + Assert.notNull(authorizations, "authorizations cannot be null"); + this.authorizations = new CopyOnWriteArrayList<>(authorizations); + } + + @Override + public void save(OAuth2Authorization authorization) { + Assert.notNull(authorization, "authorization cannot be null"); + this.authorizations.add(authorization); + } + + @Override + public OAuth2Authorization findByTokenAndTokenType(String token, TokenType tokenType) { + Assert.hasText(token, "token cannot be empty"); + Assert.notNull(tokenType, "tokenType cannot be null"); + return this.authorizations.stream() + .filter(authorization -> doesMatch(authorization, token, tokenType)) + .findFirst() + .orElse(null); + + } + + private boolean doesMatch(OAuth2Authorization authorization, String token, TokenType tokenType) { + if (tokenType.equals(TokenType.ACCESS_TOKEN)) { + return isAccessTokenEqual(token, authorization); + } else if (tokenType.equals(TokenType.AUTHORIZATION_CODE)) { + return isAuthorizationCodeEqual(token, authorization); + } + return false; + } + + private boolean isAccessTokenEqual(String token, OAuth2Authorization authorization) { + return authorization.getAccessToken() != null && token.equals(authorization.getAccessToken().getTokenValue()); + } + + private boolean isAuthorizationCodeEqual(String token, OAuth2Authorization authorization) { + return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue())); + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java index e49d53a..068d14b 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java @@ -16,11 +16,19 @@ package org.springframework.security.oauth2.server.authorization; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.util.Assert; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; /** + * Represents a collection of attributes which describe an OAuth 2.0 authorization context. + * * @author Joe Grandja + * @author Krisztian Toth */ public class OAuth2Authorization { private String registeredClientId; @@ -28,4 +36,176 @@ public class OAuth2Authorization { private OAuth2AccessToken accessToken; private Map attributes; + protected OAuth2Authorization() { + } + + public String getRegisteredClientId() { + return this.registeredClientId; + } + + public String getPrincipalName() { + return this.principalName; + } + + public OAuth2AccessToken getAccessToken() { + return this.accessToken; + } + + public Map getAttributes() { + return this.attributes; + } + + /** + * Returns an attribute with the provided name or {@code null} if not found. + * + * @param name the name of the attribute + * @param the type of the attribute + * @return the found attribute or {@code null} + */ + public T getAttribute(String name) { + Assert.hasText(name, "name cannot be empty"); + return (T) this.attributes.get(name); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + OAuth2Authorization that = (OAuth2Authorization) o; + return Objects.equals(this.registeredClientId, that.registeredClientId) && + Objects.equals(this.principalName, that.principalName) && + Objects.equals(this.accessToken, that.accessToken) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() { + return Objects.hash(this.registeredClientId, this.principalName, this.accessToken, this.attributes); + } + + /** + * Returns an empty {@link Builder}. + * + * @return the {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@link OAuth2Authorization}. + * + * @param authorization the {@link OAuth2Authorization} to copy from + * @return the {@link Builder} + */ + public static Builder withAuthorization(OAuth2Authorization authorization) { + Assert.notNull(authorization, "authorization cannot be null"); + return new Builder(authorization); + } + + /** + * Builder class for {@link OAuth2Authorization}. + */ + public static class Builder { + private String registeredClientId; + private String principalName; + private OAuth2AccessToken accessToken; + private Map attributes = new HashMap<>(); + + protected Builder() { + } + + protected Builder(OAuth2Authorization authorization) { + this.registeredClientId = authorization.registeredClientId; + this.principalName = authorization.principalName; + this.accessToken = authorization.accessToken; + this.attributes = authorization.attributes; + } + + /** + * Sets the registered client identifier. + * + * @param registeredClientId the client id + * @return the {@link Builder} + */ + public Builder registeredClientId(String registeredClientId) { + this.registeredClientId = registeredClientId; + return this; + } + + /** + * Sets the principal name. + * + * @param principalName the principal name + * @return the {@link Builder} + */ + public Builder principalName(String principalName) { + this.principalName = principalName; + return this; + } + + /** + * Sets the {@link OAuth2AccessToken}. + * + * @param accessToken the {@link OAuth2AccessToken} + * @return the {@link Builder} + */ + public Builder accessToken(OAuth2AccessToken accessToken) { + this.accessToken = accessToken; + return this; + } + + /** + * Adds the attribute with the specified name and {@link String} value to the attributes map. + * + * @param name the name of the attribute + * @param value the value of the attribute + * @return the {@link Builder} + */ + public Builder attribute(String name, String value) { + Assert.hasText(name, "name cannot be empty"); + Assert.hasText(value, "value cannot be empty"); + this.attributes.put(name, value); + return this; + } + + /** + * A {@code Consumer} of the attributes map allowing to access or modify its content. + * + * @param attributesConsumer a {@link Consumer} of the attributes map + * @return the {@link Builder} + */ + public Builder attributes(Consumer> attributesConsumer) { + attributesConsumer.accept(this.attributes); + return this; + } + + /** + * Builds a new {@link OAuth2Authorization}. + * + * @return the {@link OAuth2Authorization} + */ + public OAuth2Authorization build() { + Assert.hasText(this.registeredClientId, "registeredClientId cannot be empty"); + Assert.hasText(this.principalName, "principalName cannot be empty"); + if (this.accessToken == null && this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()) == null) { + throw new IllegalArgumentException("either accessToken has to be set or the authorization code with key '" + + TokenType.AUTHORIZATION_CODE.getValue() + "' must be provided in the attributes map"); + } + return create(); + } + + private OAuth2Authorization create() { + OAuth2Authorization oAuth2Authorization = new OAuth2Authorization(); + oAuth2Authorization.registeredClientId = this.registeredClientId; + oAuth2Authorization.principalName = this.principalName; + oAuth2Authorization.accessToken = this.accessToken; + oAuth2Authorization.attributes = Collections.unmodifiableMap(this.attributes); + return oAuth2Authorization; + } + } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java new file mode 100644 index 0000000..53d3066 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.junit.Test; +import org.springframework.security.oauth2.core.OAuth2AccessToken; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link InMemoryOAuth2AuthorizationService}. + * + * @author Krisztian Toth + */ +public class InMemoryOAuth2AuthorizationServiceTests { + + private static final String TOKEN = "token"; + private static final TokenType AUTHORIZATION_CODE = TokenType.AUTHORIZATION_CODE; + private static final TokenType ACCESS_TOKEN = TokenType.ACCESS_TOKEN; + private static final Instant ISSUED_AT = Instant.now().minusSeconds(60); + private static final Instant EXPIRES_AT = Instant.now(); + + private InMemoryOAuth2AuthorizationService authorizationService; + + @Test + public void saveWhenAuthorizationProvidedThenSavedInList() { + authorizationService = new InMemoryOAuth2AuthorizationService(new ArrayList<>()); + + OAuth2Authorization authorization = OAuth2Authorization.builder() + .registeredClientId("clientId") + .principalName("principalName") + .attribute(AUTHORIZATION_CODE.getValue(), TOKEN) + .build(); + authorizationService.save(authorization); + + assertThat(authorizationService.findByTokenAndTokenType(TOKEN, AUTHORIZATION_CODE)).isEqualTo(authorization); + } + + @Test + public void saveWhenAuthorizationNotProvidedThenThrowIllegalArgumentException() { + authorizationService = new InMemoryOAuth2AuthorizationService(new ArrayList<>()); + + assertThatThrownBy(() -> authorizationService.save(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void findByTokenAndTokenTypeWhenTokenTypeIsAuthorizationCodeThenFound() { + OAuth2Authorization authorization = OAuth2Authorization.builder() + .registeredClientId("clientId") + .principalName("principalName") + .attribute(AUTHORIZATION_CODE.getValue(), TOKEN) + .build(); + authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); + + OAuth2Authorization result = authorizationService.findByTokenAndTokenType(TOKEN, TokenType.AUTHORIZATION_CODE); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenAndTokenTypeWhenTokenTypeIsAccessTokenThenFound() { + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, TOKEN, ISSUED_AT, + EXPIRES_AT); + OAuth2Authorization authorization = OAuth2Authorization.builder() + .registeredClientId("clientId") + .principalName("principalName") + .accessToken(accessToken) + .build(); + authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); + + OAuth2Authorization result = authorizationService.findByTokenAndTokenType(TOKEN, ACCESS_TOKEN); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenAndTokenTypeWhenTokenWithTokenTypeDoesNotExistThenNull() { + OAuth2Authorization authorization = OAuth2Authorization.builder() + .registeredClientId("clientId") + .principalName("principalName") + .attribute(AUTHORIZATION_CODE.getValue(), TOKEN) + .build(); + authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); + + OAuth2Authorization result = authorizationService.findByTokenAndTokenType(TOKEN, ACCESS_TOKEN); + assertThat(result).isNull(); + } + + @Test + public void findByTokenAndTokenTypeWhenTokenNullThenThrowIllegalArgumentException() { + authorizationService = new InMemoryOAuth2AuthorizationService(); + assertThatThrownBy(() -> authorizationService.findByTokenAndTokenType(null, TokenType.AUTHORIZATION_CODE)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void findByTokenAndTokenTypeWhenTokenTypeNullThenThrowIllegalArgumentException() { + authorizationService = new InMemoryOAuth2AuthorizationService(); + assertThatThrownBy(() -> authorizationService.findByTokenAndTokenType(TOKEN, null)) + .isInstanceOf(IllegalArgumentException.class); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java new file mode 100644 index 0000000..e7b48e8 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java @@ -0,0 +1,136 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.junit.Test; +import org.springframework.security.oauth2.core.OAuth2AccessToken; + +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests For {@link OAuth2Authorization}. + * + * @author Krisztian Toth + */ +public class OAuth2AuthorizationTests { + + public static final String REGISTERED_CLIENT_ID = "clientId"; + public static final String PRINCIPAL_NAME = "principal"; + public static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "token", Instant.now().minusSeconds(60), Instant.now()); + public static final String AUTHORIZATION_CODE_VALUE = TokenType.AUTHORIZATION_CODE.getValue(); + public static final String CODE = "code"; + public static final Map ATTRIBUTES = Collections.singletonMap(AUTHORIZATION_CODE_VALUE, CODE); + + @Test + public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() { + OAuth2Authorization authorization = OAuth2Authorization.builder() + .registeredClientId(REGISTERED_CLIENT_ID) + .principalName(PRINCIPAL_NAME) + .accessToken(ACCESS_TOKEN) + .attribute(AUTHORIZATION_CODE_VALUE, CODE) + .build(); + + assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT_ID); + assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); + assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN); + assertThat(authorization.getAttributes()).isEqualTo(ATTRIBUTES); + } + + @Test + public void buildWhenBuildThenImmutableMapIsCreated() { + OAuth2Authorization authorization = OAuth2Authorization.builder() + .registeredClientId(REGISTERED_CLIENT_ID) + .principalName(PRINCIPAL_NAME) + .accessToken(ACCESS_TOKEN) + .attribute("any", "value") + .build(); + + assertThatThrownBy(() -> authorization.getAttributes().put("any", "value")) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + public void buildWhenAccessTokenAndAuthorizationCodeNotProvidedThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2Authorization.builder() + .registeredClientId(REGISTERED_CLIENT_ID) + .principalName(PRINCIPAL_NAME) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenRegisteredClientIdNotProvidedThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2Authorization.builder() + .principalName(PRINCIPAL_NAME) + .accessToken(ACCESS_TOKEN) + .attribute(AUTHORIZATION_CODE_VALUE, CODE) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2Authorization.builder() + .registeredClientId(REGISTERED_CLIENT_ID) + .accessToken(ACCESS_TOKEN) + .attribute(AUTHORIZATION_CODE_VALUE, CODE) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAttributeSetWithNullNameThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2Authorization.builder() + .attribute(null, CODE) + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAttributeSetWithNullValueThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + OAuth2Authorization.builder() + .attribute(AUTHORIZATION_CODE_VALUE, null) + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void withOAuth2AuthorizationWhenAuthorizationProvidedThenAllAttributesAreCopied() { + OAuth2Authorization authorizationToCopy = OAuth2Authorization.builder() + .registeredClientId(REGISTERED_CLIENT_ID) + .principalName(PRINCIPAL_NAME) + .attribute(AUTHORIZATION_CODE_VALUE, CODE) + .build(); + + OAuth2Authorization authorization = OAuth2Authorization.withAuthorization(authorizationToCopy) + .accessToken(ACCESS_TOKEN) + .build(); + + assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT_ID); + assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME); + assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN); + assertThat(authorization.getAttributes()).isEqualTo(ATTRIBUTES); + } +}