From 8700ff19df6fd5fe675c06b7a33e5d7931d01467 Mon Sep 17 00:00:00 2001 From: Anoop Garlapati Date: Sun, 26 Apr 2020 00:55:03 +0530 Subject: [PATCH] Add implementation for in memory RegisteredClientRepository Fixes gh-40 --- core/spring-authorization-server-core.gradle | 4 + .../oauth2/server/authorization/Version.java | 28 ++ .../InMemoryRegisteredClientRepository.java | 84 ++++ .../client/RegisteredClient.java | 351 ++++++++++++++- .../client/RegisteredClientRepository.java | 16 + ...MemoryRegisteredClientRepositoryTests.java | 115 +++++ .../client/RegisteredClientTests.java | 410 ++++++++++++++++++ .../client/TestRegisteredClients.java | 49 +++ 8 files changed, 1054 insertions(+), 3 deletions(-) create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/Version.java create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java diff --git a/core/spring-authorization-server-core.gradle b/core/spring-authorization-server-core.gradle index e4235be..d646ba0 100644 --- a/core/spring-authorization-server-core.gradle +++ b/core/spring-authorization-server-core.gradle @@ -18,3 +18,7 @@ dependencies { provided 'javax.servlet:javax.servlet-api' } + +jacoco { + toolVersion = '0.8.5' +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/Version.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/Version.java new file mode 100644 index 0000000..defa5b4 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/Version.java @@ -0,0 +1,28 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +/** + * Internal class used for serialization across Spring Security Authorization Server classes. + * + * @author Anoop Garlapati + */ +public class Version { + /** + * Global Serialization value for Spring Security Authorization Server classes. + */ + public static final long SERIAL_VERSION_UID = "0.0.1".hashCode(); +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java new file mode 100644 index 0000000..577816e --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.client; + +import org.springframework.util.Assert; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A {@link RegisteredClientRepository} that stores {@link RegisteredClient}(s) in-memory. + * + * @author Anoop Garlapati + * @see RegisteredClientRepository + * @see RegisteredClient + */ +public final class InMemoryRegisteredClientRepository implements RegisteredClientRepository { + private final Map idRegistrationMap; + private final Map clientIdRegistrationMap; + + /** + * Constructs an {@code InMemoryRegisteredClientRepository} using the provided parameters. + * + * @param registrations the client registration(s) + */ + public InMemoryRegisteredClientRepository(RegisteredClient... registrations) { + this(Arrays.asList(registrations)); + } + + /** + * Constructs an {@code InMemoryRegisteredClientRepository} using the provided parameters. + * + * @param registrations the client registration(s) + */ + public InMemoryRegisteredClientRepository(List registrations) { + Assert.notEmpty(registrations, "registrations cannot be empty"); + ConcurrentHashMap idRegistrationMapResult = new ConcurrentHashMap<>(); + ConcurrentHashMap clientIdRegistrationMapResult = new ConcurrentHashMap<>(); + for (RegisteredClient registration : registrations) { + Assert.notNull(registration, "registration cannot be null"); + String id = registration.getId(); + if (idRegistrationMapResult.containsKey(id)) { + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate identifier: " + id); + } + String clientId = registration.getClientId(); + if (clientIdRegistrationMapResult.containsKey(clientId)) { + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate client identifier: " + clientId); + } + idRegistrationMapResult.put(id, registration); + clientIdRegistrationMapResult.put(clientId, registration); + } + idRegistrationMap = idRegistrationMapResult; + clientIdRegistrationMap = clientIdRegistrationMapResult; + } + + @Override + public RegisteredClient findById(String id) { + Assert.hasText(id, "id cannot be empty"); + return this.idRegistrationMap.get(id); + } + + @Override + public RegisteredClient findByClientId(String clientId) { + Assert.hasText(clientId, "clientId cannot be empty"); + return this.clientIdRegistrationMap.get(clientId); + } +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java index 807a396..4f24b3c 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java @@ -15,25 +15,370 @@ */ package org.springframework.security.oauth2.server.authorization.client; -import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.server.authorization.Version; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import java.io.Serializable; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; +import java.util.function.Consumer; /** + * A representation of a client registration with an OAuth 2.0 Authorization Server. + * * @author Joe Grandja + * @author Anoop Garlapati + * @see Section 2 Client Registration */ public class RegisteredClient implements Serializable { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String id; private String clientId; private String clientSecret; - private Set clientAuthenticationMethods = Collections.emptySet(); + private Set clientAuthenticationMethods = + Collections.singleton(ClientAuthenticationMethod.BASIC); private Set authorizationGrantTypes = Collections.emptySet(); private Set redirectUris = Collections.emptySet(); private Set scopes = Collections.emptySet(); + protected RegisteredClient() { + } + + /** + * Returns the identifier for the registration. + * + * @return the identifier for the registration + */ + public String getId() { + return this.id; + } + + /** + * Returns the client identifier. + * + * @return the client identifier + */ + public String getClientId() { + return this.clientId; + } + + /** + * Returns the client secret. + * + * @return the client secret + */ + public String getClientSecret() { + return this.clientSecret; + } + + /** + * Returns the {@link ClientAuthenticationMethod authentication method(s)} used + * when authenticating the client with the authorization server. + * + * @return the {@code Set} of {@link ClientAuthenticationMethod authentication method(s)} + */ + public Set getClientAuthenticationMethods() { + return this.clientAuthenticationMethods; + } + + /** + * Returns the {@link AuthorizationGrantType authorization grant type(s)} that the client may use. + * + * @return the {@code Set} of {@link AuthorizationGrantType authorization grant type(s)} + */ + public Set getAuthorizationGrantTypes() { + return this.authorizationGrantTypes; + } + + /** + * Returns the redirect URI(s) that the client may use in redirect-based flows. + * + * @return the {@code Set} of redirect URI(s) + */ + public Set getRedirectUris() { + return this.redirectUris; + } + + /** + * Returns the scope(s) used by the client. + * + * @return the {@code Set} of scope(s) + */ + public Set getScopes() { + return this.scopes; + } + + @Override + public String toString() { + return "RegisteredClient{" + + "id='" + this.id + '\'' + + ", clientId='" + this.clientId + '\'' + + ", clientAuthenticationMethods=" + this.clientAuthenticationMethods + + ", authorizationGrantTypes=" + this.authorizationGrantTypes + + ", redirectUris=" + this.redirectUris + + ", scopes=" + this.scopes + + '}'; + } + + /** + * Returns a new {@link Builder}, initialized with the provided registration identifier. + * + * @param id the identifier for the registration + * @return the {@link Builder} + */ + public static Builder withId(String id) { + Assert.hasText(id, "id cannot be empty"); + return new Builder(id); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@link RegisteredClient}. + * + * @param registeredClient the {@link RegisteredClient} to copy from + * @return the {@link Builder} + */ + public static Builder withRegisteredClient(RegisteredClient registeredClient) { + Assert.notNull(registeredClient, "registeredClient cannot be null"); + return new Builder(registeredClient); + } + + /** + * A builder for {@link RegisteredClient}. + */ + public static class Builder implements Serializable { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private String id; + private String clientId; + private String clientSecret; + private Set clientAuthenticationMethods = + new LinkedHashSet<>(Collections.singletonList(ClientAuthenticationMethod.BASIC)); + private Set authorizationGrantTypes = new LinkedHashSet<>(); + private Set redirectUris = new LinkedHashSet<>(); + private Set scopes = new LinkedHashSet<>(); + + protected Builder(String id) { + this.id = id; + } + + protected Builder(RegisteredClient registeredClient) { + this.id = registeredClient.id; + this.clientId = registeredClient.clientId; + this.clientSecret = registeredClient.clientSecret; + this.clientAuthenticationMethods = registeredClient.clientAuthenticationMethods == null ? null : + new HashSet<>(registeredClient.clientAuthenticationMethods); + this.authorizationGrantTypes = registeredClient.authorizationGrantTypes == null ? null : + new HashSet<>(registeredClient.authorizationGrantTypes); + this.redirectUris = registeredClient.redirectUris == null ? null : + new HashSet<>(registeredClient.redirectUris); + this.scopes = registeredClient.scopes == null ? null : new HashSet<>(registeredClient.scopes); + } + + /** + * Sets the identifier for the registration. + * + * @param id the identifier for the registration + * @return the {@link Builder} + */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Sets the client identifier. + * + * @param clientId the client identifier + * @return the {@link Builder} + */ + public Builder clientId(String clientId) { + this.clientId = clientId; + return this; + } + + /** + * Sets the client secret. + * + * @param clientSecret the client secret + * @return the {@link Builder} + */ + public Builder clientSecret(String clientSecret) { + this.clientSecret = clientSecret; + return this; + } + + /** + * Adds the {@link ClientAuthenticationMethod authentication method} to the set of + * client authentication methods used when authenticating the client with the authorization server. + * + * @param clientAuthenticationMethod the authentication method + * @return the {@link Builder} + */ + public Builder clientAuthenticationMethod(ClientAuthenticationMethod clientAuthenticationMethod) { + this.clientAuthenticationMethods.add(clientAuthenticationMethod); + return this; + } + + /** + * Sets the {@link ClientAuthenticationMethod authentication method(s)} used + * when authenticating the client with the authorization server. + * + * @param clientAuthenticationMethodsConsumer the authentication method(s) {@link Consumer} + * @return the {@link Builder} + */ + public Builder clientAuthenticationMethods( + Consumer> clientAuthenticationMethodsConsumer) { + clientAuthenticationMethodsConsumer.accept(this.clientAuthenticationMethods); + return this; + } + + /** + * Adds the {@link AuthorizationGrantType authorization grant type} to + * the set of authorization grant types that the client may use. + * + * @param authorizationGrantType the authorization grant type + * @return the {@link Builder} + */ + public Builder authorizationGrantType(AuthorizationGrantType authorizationGrantType) { + this.authorizationGrantTypes.add(authorizationGrantType); + return this; + } + + /** + * Sets the {@link AuthorizationGrantType authorization grant type(s)} that the client may use. + * + * @param authorizationGrantTypesConsumer the authorization grant type(s) {@link Consumer} + * @return the {@link Builder} + */ + public Builder authorizationGrantTypes(Consumer> authorizationGrantTypesConsumer) { + authorizationGrantTypesConsumer.accept(this.authorizationGrantTypes); + return this; + } + + /** + * Adds the redirect URI to the set of redirect URIs that the client may use in redirect-based flows. + * + * @param redirectUri the redirect URI to add + * @return the {@link Builder} + */ + public Builder redirectUri(String redirectUri) { + this.redirectUris.add(redirectUri); + return this; + } + + /** + * Sets the redirect URI(s) that the client may use in redirect-based flows. + * + * @param redirectUrisConsumer the redirect URI(s) {@link Consumer} + * @return the {@link Builder} + */ + public Builder redirectUris(Consumer> redirectUrisConsumer) { + redirectUrisConsumer.accept(this.redirectUris); + return this; + } + + /** + * Adds the scope to the set of scopes used by the client. + * + * @param scope the scope to add + * @return the {@link Builder} + */ + public Builder scope(String scope) { + this.scopes.add(scope); + return this; + } + + /** + * Sets the scope(s) used by the client. + * + * @param scopesConsumer the scope(s) {@link Consumer} + * @return the {@link Builder} + */ + public Builder scopes(Consumer> scopesConsumer) { + scopesConsumer.accept(this.scopes); + return this; + } + + /** + * Builds a new {@link RegisteredClient}. + * + * @return a {@link RegisteredClient} + */ + public RegisteredClient build() { + Assert.notEmpty(this.clientAuthenticationMethods, "clientAuthenticationMethods cannot be empty"); + Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty"); + if (authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { + Assert.hasText(this.id, "id cannot be empty"); + Assert.hasText(this.clientId, "clientId cannot be empty"); + Assert.hasText(this.clientSecret, "clientSecret cannot be empty"); + Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty"); + } + this.validateScopes(); + this.validateRedirectUris(); + return this.create(); + } + + private RegisteredClient create() { + RegisteredClient registeredClient = new RegisteredClient(); + + registeredClient.id = this.id; + registeredClient.clientId = this.clientId; + registeredClient.clientSecret = this.clientSecret; + registeredClient.clientAuthenticationMethods = + Collections.unmodifiableSet(this.clientAuthenticationMethods); + registeredClient.authorizationGrantTypes = Collections.unmodifiableSet(this.authorizationGrantTypes); + registeredClient.redirectUris = Collections.unmodifiableSet(this.redirectUris); + registeredClient.scopes = Collections.unmodifiableSet(this.scopes); + + return registeredClient; + } + + private void validateScopes() { + if (CollectionUtils.isEmpty(this.scopes)) { + return; + } + + for (String scope : this.scopes) { + Assert.isTrue(validateScope(scope), "scope \"" + scope + "\" contains invalid characters"); + } + } + + private static boolean validateScope(String scope) { + return scope == null || + scope.chars().allMatch(c -> withinTheRangeOf(c, 0x21, 0x21) || + withinTheRangeOf(c, 0x23, 0x5B) || + withinTheRangeOf(c, 0x5D, 0x7E)); + } + + private static boolean withinTheRangeOf(int c, int min, int max) { + return c >= min && c <= max; + } + + private void validateRedirectUris() { + if (CollectionUtils.isEmpty(this.redirectUris)) { + return; + } + + for (String redirectUri : redirectUris) { + Assert.isTrue(validateRedirectUri(redirectUri), + "redirect_uri \"" + redirectUri + "\" is not a valid redirect URI or contains fragment"); + } + } + + private static boolean validateRedirectUri(String redirectUri) { + try { + URI validRedirectUri = new URI(redirectUri); + return validRedirectUri.getFragment() == null; + } catch (URISyntaxException ex) { + return false; + } + } + } + } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java index 9e37d6b..4da1374 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java @@ -16,12 +16,28 @@ package org.springframework.security.oauth2.server.authorization.client; /** + * A repository for OAuth 2.0 {@link RegisteredClient}(s). + * * @author Joe Grandja + * @author Anoop Garlapati + * @see RegisteredClient */ public interface RegisteredClientRepository { + /** + * Returns the registered client identified by the provided {@code id}, or {@code null} if not found. + * + * @param id the registration identifier + * @return the {@link RegisteredClient} if found, otherwise {@code null} + */ RegisteredClient findById(String id); + /** + * Returns the registered client identified by the provided {@code clientId}, or {@code null} if not found. + * + * @param clientId the client identifier + * @return the {@link RegisteredClient} if found, otherwise {@code null} + */ RegisteredClient findByClientId(String clientId); } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java new file mode 100644 index 0000000..54cc8e6 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java @@ -0,0 +1,115 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.client; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link InMemoryRegisteredClientRepository}. + * + * @author Anoop Garlapati + */ +public class InMemoryRegisteredClientRepositoryTests { + private RegisteredClient registration = TestRegisteredClients.registeredClient().build(); + + private InMemoryRegisteredClientRepository clients = new InMemoryRegisteredClientRepository(this.registration); + + @Test + public void constructorVarargsRegisteredClientWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + RegisteredClient registration = null; + new InMemoryRegisteredClientRepository(registration); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorListRegisteredClientWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + List registrations = null; + new InMemoryRegisteredClientRepository(registrations); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorListClientRegistrationWhenEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + List registrations = Collections.emptyList(); + new InMemoryRegisteredClientRepository(registrations); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorListRegisteredClientWhenDuplicateIdThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + RegisteredClient anotherRegistrationWithSameId = TestRegisteredClients.registeredClient2() + .id(this.registration.getId()).build(); + List registrations = Arrays.asList(this.registration, anotherRegistrationWithSameId); + new InMemoryRegisteredClientRepository(registrations); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorListRegisteredClientWhenDuplicateClientIdThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + RegisteredClient anotherRegistrationWithSameClientId = TestRegisteredClients.registeredClient2() + .clientId(this.registration.getClientId()).build(); + List registrations = Arrays.asList(this.registration, + anotherRegistrationWithSameClientId); + new InMemoryRegisteredClientRepository(registrations); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void findByIdWhenFoundThenFound() { + String id = this.registration.getId(); + assertThat(clients.findById(id)).isEqualTo(this.registration); + } + + @Test + public void findByIdWhenNotFoundThenNull() { + String missingId = this.registration.getId() + "MISSING"; + assertThat(clients.findById(missingId)).isNull(); + } + + @Test + public void findByIdWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> clients.findById(null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void findByClientIdWhenFoundThenFound() { + String clientId = this.registration.getClientId(); + assertThat(clients.findByClientId(clientId)).isEqualTo(this.registration); + } + + @Test + public void findByClientIdWhenNotFoundThenNull() { + String missingClientId = this.registration.getClientId() + "MISSING"; + assertThat(clients.findByClientId(missingClientId)).isNull(); + } + + @Test + public void findByClientIdWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> clients.findByClientId(null)).isInstanceOf(IllegalArgumentException.class); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java new file mode 100644 index 0000000..9875e86 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java @@ -0,0 +1,410 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.client; + +import org.junit.Test; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + +import java.util.Collections; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link RegisteredClient}. + * + * @author Anoop Garlapati + */ +public class RegisteredClientTests { + private static final String ID = "registration-1"; + private static final String CLIENT_ID = "client-1"; + private static final String CLIENT_SECRET = "secret"; + private static final Set REDIRECT_URIS = Collections.singleton("https://example.com"); + private static final Set SCOPES = Collections.unmodifiableSet( + Stream.of("openid", "profile", "email").collect(Collectors.toSet())); + private static final Set CLIENT_AUTHENTICATION_METHODS = + Collections.singleton(ClientAuthenticationMethod.BASIC); + + @Test + public void buildWhenAuthorizationGrantTypesIsNotSetThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeGrantAllAttributesProvidedThenAllAttributesAreSet() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getId()).isEqualTo(ID); + assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); + assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); + assertThat(registration.getAuthorizationGrantTypes()) + .isEqualTo(Collections.singleton(AuthorizationGrantType.AUTHORIZATION_CODE)); + assertThat(registration.getClientAuthenticationMethods()).isEqualTo(CLIENT_AUTHENTICATION_METHODS); + assertThat(registration.getRedirectUris()).isEqualTo(REDIRECT_URIS); + assertThat(registration.getScopes()).isEqualTo(SCOPES); + } + + @Test + public void buildWhenAuthorizationCodeGrantIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(null) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeGrantClientIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(null) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeGrantClientSecretIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(null) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeGrantRedirectUrisNotProvidedThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeGrantRedirectUrisConsumerClearsSetThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("https://example.com") + .redirectUris(Set::clear) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getClientAuthenticationMethods()) + .isEqualTo(Collections.singleton(ClientAuthenticationMethod.BASIC)); + } + + @Test + public void buildWhenAuthorizationCodeGrantScopeIsEmptyThenScopeNotRequired() { + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .build(); + } + + @Test + public void buildWhenAuthorizationCodeGrantScopeConsumerIsProvidedThenConsumerAccepted() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getScopes()).isEqualTo(SCOPES); + } + + @Test + public void buildWhenScopeContainsASpaceThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(null) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scope("openid profile") + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenScopesContainsAnInvalidCharacterThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scope("an\"invalid\"scope") + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenRedirectUrisContainInvalidUriThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("invalid URI") + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenRedirectUrisContainUriWithFragmentThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("https://example.com/page#fragment") + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenTwoAuthorizationGrantTypesAreProvidedThenBothAreRegistered() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getAuthorizationGrantTypes()) + .containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); + } + + @Test + public void buildWhenAuthorizationGrantTypesConsumerIsProvidedThenConsumerAccepted() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantTypes(authorizationGrantTypes -> { + authorizationGrantTypes.add(AuthorizationGrantType.AUTHORIZATION_CODE); + authorizationGrantTypes.add(AuthorizationGrantType.CLIENT_CREDENTIALS); + }) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getAuthorizationGrantTypes()) + .containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); + } + + @Test + public void buildWhenAuthorizationGrantTypesConsumerClearsSetThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantTypes(Set::clear) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenTwoClientAuthenticationMethodsAreProvidedThenBothAreRegistered() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getClientAuthenticationMethods()) + .containsExactly(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST); + } + + @Test + public void buildWhenClientAuthenticationMethodsConsumerIsProvidedThenConsumerAccepted() { + RegisteredClient registration = RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethods(clientAuthenticationMethods -> { + clientAuthenticationMethods.add(ClientAuthenticationMethod.BASIC); + clientAuthenticationMethods.add(ClientAuthenticationMethod.POST); + }) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getClientAuthenticationMethods()) + .containsExactly(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST); + } + + @Test + public void buildWhenClientAuthenticationMethodsConsumerClearsSetThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> { + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethods(Set::clear) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + }).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenOverrideIdThenOverridden() { + String overriddenId = "override"; + RegisteredClient registration = RegisteredClient.withId(ID) + .id(overriddenId) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .scopes(scopes -> scopes.addAll(SCOPES)) + .build(); + + assertThat(registration.getId()).isEqualTo(overriddenId); + } + + @Test + public void buildWhenRegisteredClientProvidedThenMakesACopy() { + RegisteredClient registration = TestRegisteredClients.registeredClient().build(); + RegisteredClient updated = RegisteredClient.withRegisteredClient(registration).build(); + + assertThat(registration.getClientAuthenticationMethods()).isEqualTo(updated.getClientAuthenticationMethods()); + assertThat(registration.getClientAuthenticationMethods()).isNotSameAs(updated.getClientAuthenticationMethods()); + assertThat(registration.getAuthorizationGrantTypes()).isEqualTo(updated.getAuthorizationGrantTypes()); + assertThat(registration.getAuthorizationGrantTypes()).isNotSameAs(updated.getAuthorizationGrantTypes()); + assertThat(registration.getRedirectUris()).isEqualTo(updated.getRedirectUris()); + assertThat(registration.getRedirectUris()).isNotSameAs(updated.getRedirectUris()); + assertThat(registration.getScopes()).isEqualTo(updated.getScopes()); + assertThat(registration.getScopes()).isNotSameAs(updated.getScopes()); + } + + @Test + public void buildWhenRegisteredClientProvidedThenEachPropertyMatches() { + RegisteredClient registration = TestRegisteredClients.registeredClient().build(); + RegisteredClient updated = RegisteredClient.withRegisteredClient(registration).build(); + + assertThat(registration.getId()).isEqualTo(updated.getId()); + assertThat(registration.getClientId()).isEqualTo(updated.getClientId()); + assertThat(registration.getClientSecret()).isEqualTo(updated.getClientSecret()); + assertThat(registration.getClientAuthenticationMethods()).isEqualTo(updated.getClientAuthenticationMethods()); + assertThat(registration.getAuthorizationGrantTypes()).isEqualTo(updated.getAuthorizationGrantTypes()); + assertThat(registration.getRedirectUris()).isEqualTo(updated.getRedirectUris()); + assertThat(registration.getScopes()).isEqualTo(updated.getScopes()); + } + + @Test + public void buildWhenClientRegistrationValuesOverriddenThenPropagated() { + RegisteredClient registration = TestRegisteredClients.registeredClient().build(); + String newSecret = "new-secret"; + String newScope = "new-scope"; + String newRedirectUri = "https://another-redirect-uri.com"; + RegisteredClient updated = RegisteredClient.withRegisteredClient(registration) + .clientSecret(newSecret) + .scopes(scopes -> { + scopes.clear(); + scopes.add(newScope); + }) + .redirectUris(redirectUris -> { + redirectUris.clear(); + redirectUris.add(newRedirectUri); + }) + .build(); + + assertThat(registration.getClientSecret()).isNotEqualTo(newSecret); + assertThat(updated.getClientSecret()).isEqualTo(newSecret); + assertThat(registration.getScopes()).doesNotContain(newScope); + assertThat(updated.getScopes()).containsExactly(newScope); + assertThat(registration.getRedirectUris()).doesNotContain(newRedirectUri); + assertThat(updated.getRedirectUris()).containsExactly(newRedirectUri); + } +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java new file mode 100644 index 0000000..5aa7bc0 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -0,0 +1,49 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.client; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + +/** + * @author Anoop Garlapati + */ +public class TestRegisteredClients { + + public static RegisteredClient.Builder registeredClient() { + return RegisteredClient.withId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("https://example.com") + .scope("openid") + .scope("profile") + .scope("email"); + } + + public static RegisteredClient.Builder registeredClient2() { + return RegisteredClient.withId("registration-2") + .clientId("client-2") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("https://example.com") + .scope("openid") + .scope("profile") + .scope("email"); + } +}