Improve RegisteredClient model

Closes gh-221
This commit is contained in:
Joe Grandja 2021-02-09 15:49:49 -05:00
parent 313b4cc5d3
commit afd5491ced
4 changed files with 87 additions and 45 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,13 +15,14 @@
*/ */
package org.springframework.security.oauth2.server.authorization.client; package org.springframework.security.oauth2.server.authorization.client;
import org.springframework.util.Assert;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/** /**
* A {@link RegisteredClientRepository} that stores {@link RegisteredClient}(s) in-memory. * A {@link RegisteredClientRepository} that stores {@link RegisteredClient}(s) in-memory.
* *
@ -74,12 +75,14 @@ public final class InMemoryRegisteredClientRepository implements RegisteredClien
this.clientIdRegistrationMap = clientIdRegistrationMapResult; this.clientIdRegistrationMap = clientIdRegistrationMapResult;
} }
@Nullable
@Override @Override
public RegisteredClient findById(String id) { public RegisteredClient findById(String id) {
Assert.hasText(id, "id cannot be empty"); Assert.hasText(id, "id cannot be empty");
return this.idRegistrationMap.get(id); return this.idRegistrationMap.get(id);
} }
@Nullable
@Override @Override
public RegisteredClient findByClientId(String clientId) { public RegisteredClient findByClientId(String clientId) {
Assert.hasText(clientId, "clientId cannot be empty"); Assert.hasText(clientId, "clientId cannot be empty");

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,22 +15,23 @@
*/ */
package org.springframework.security.oauth2.server.authorization.client; package org.springframework.security.oauth2.server.authorization.client;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import java.io.Serializable; import java.io.Serializable;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashSet; import java.util.HashSet;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/** /**
* A representation of a client registration with an OAuth 2.0 Authorization Server. * A representation of a client registration with an OAuth 2.0 Authorization Server.
* *
@ -82,8 +83,7 @@ public class RegisteredClient implements Serializable {
} }
/** /**
* Returns the {@link ClientAuthenticationMethod authentication method(s)} used * Returns the {@link ClientAuthenticationMethod authentication method(s)} that the client may use.
* when authenticating the client with the authorization server.
* *
* @return the {@code Set} of {@link ClientAuthenticationMethod authentication method(s)} * @return the {@code Set} of {@link ClientAuthenticationMethod authentication method(s)}
*/ */
@ -110,7 +110,7 @@ public class RegisteredClient implements Serializable {
} }
/** /**
* Returns the scope(s) used by the client. * Returns the scope(s) that the client may use.
* *
* @return the {@code Set} of scope(s) * @return the {@code Set} of scope(s)
*/ */
@ -136,6 +136,33 @@ public class RegisteredClient implements Serializable {
return this.tokenSettings; return this.tokenSettings;
} }
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RegisteredClient that = (RegisteredClient) obj;
return Objects.equals(this.id, that.id) &&
Objects.equals(this.clientId, that.clientId) &&
Objects.equals(this.clientSecret, that.clientSecret) &&
Objects.equals(this.clientAuthenticationMethods, that.clientAuthenticationMethods) &&
Objects.equals(this.authorizationGrantTypes, that.authorizationGrantTypes) &&
Objects.equals(this.redirectUris, that.redirectUris) &&
Objects.equals(this.scopes, that.scopes) &&
Objects.equals(this.clientSettings.settings(), that.getClientSettings().settings()) &&
Objects.equals(this.tokenSettings.settings(), that.tokenSettings.settings());
}
@Override
public int hashCode() {
return Objects.hash(this.id, this.clientId, this.clientSecret,
this.clientAuthenticationMethods, this.authorizationGrantTypes, this.redirectUris,
this.scopes, this.clientSettings.settings(), this.tokenSettings.settings());
}
@Override @Override
public String toString() { public String toString() {
return "RegisteredClient {" + return "RegisteredClient {" +
@ -145,6 +172,8 @@ public class RegisteredClient implements Serializable {
", authorizationGrantTypes=" + this.authorizationGrantTypes + ", authorizationGrantTypes=" + this.authorizationGrantTypes +
", redirectUris=" + this.redirectUris + ", redirectUris=" + this.redirectUris +
", scopes=" + this.scopes + ", scopes=" + this.scopes +
", clientSettings=" + this.clientSettings.settings() +
", tokenSettings=" + this.tokenSettings.settings() +
'}'; '}';
} }
@ -160,12 +189,12 @@ public class RegisteredClient implements Serializable {
} }
/** /**
* Returns a new {@link Builder}, initialized with the provided {@link RegisteredClient}. * Returns a new {@link Builder}, initialized with the values from the provided {@link RegisteredClient}.
* *
* @param registeredClient the {@link RegisteredClient} to copy from * @param registeredClient the {@link RegisteredClient} used for initializing the {@link Builder}
* @return the {@link Builder} * @return the {@link Builder}
*/ */
public static Builder withRegisteredClient(RegisteredClient registeredClient) { public static Builder from(RegisteredClient registeredClient) {
Assert.notNull(registeredClient, "registeredClient cannot be null"); Assert.notNull(registeredClient, "registeredClient cannot be null");
return new Builder(registeredClient); return new Builder(registeredClient);
} }
@ -178,10 +207,10 @@ public class RegisteredClient implements Serializable {
private String id; private String id;
private String clientId; private String clientId;
private String clientSecret; private String clientSecret;
private Set<ClientAuthenticationMethod> clientAuthenticationMethods = new LinkedHashSet<>(); private Set<ClientAuthenticationMethod> clientAuthenticationMethods = new HashSet<>();
private Set<AuthorizationGrantType> authorizationGrantTypes = new LinkedHashSet<>(); private Set<AuthorizationGrantType> authorizationGrantTypes = new HashSet<>();
private Set<String> redirectUris = new LinkedHashSet<>(); private Set<String> redirectUris = new HashSet<>();
private Set<String> scopes = new LinkedHashSet<>(); private Set<String> scopes = new HashSet<>();
private ClientSettings clientSettings = new ClientSettings(); private ClientSettings clientSettings = new ClientSettings();
private TokenSettings tokenSettings = new TokenSettings(); private TokenSettings tokenSettings = new TokenSettings();
@ -385,13 +414,16 @@ public class RegisteredClient implements Serializable {
registeredClient.id = this.id; registeredClient.id = this.id;
registeredClient.clientId = this.clientId; registeredClient.clientId = this.clientId;
registeredClient.clientSecret = this.clientSecret; registeredClient.clientSecret = this.clientSecret;
registeredClient.clientAuthenticationMethods = registeredClient.clientAuthenticationMethods = Collections.unmodifiableSet(
Collections.unmodifiableSet(this.clientAuthenticationMethods); new HashSet<>(this.clientAuthenticationMethods));
registeredClient.authorizationGrantTypes = Collections.unmodifiableSet(this.authorizationGrantTypes); registeredClient.authorizationGrantTypes = Collections.unmodifiableSet(
registeredClient.redirectUris = Collections.unmodifiableSet(this.redirectUris); new HashSet<>(this.authorizationGrantTypes));
registeredClient.scopes = Collections.unmodifiableSet(this.scopes); registeredClient.redirectUris = Collections.unmodifiableSet(
registeredClient.clientSettings = this.clientSettings; new HashSet<>(this.redirectUris));
registeredClient.tokenSettings = this.tokenSettings; registeredClient.scopes = Collections.unmodifiableSet(
new HashSet<>(this.scopes));
registeredClient.clientSettings = new ClientSettings(this.clientSettings.settings());
registeredClient.tokenSettings = new TokenSettings(this.tokenSettings.settings());
return registeredClient; return registeredClient;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,6 +15,8 @@
*/ */
package org.springframework.security.oauth2.server.authorization.client; package org.springframework.security.oauth2.server.authorization.client;
import org.springframework.lang.Nullable;
/** /**
* A repository for OAuth 2.0 {@link RegisteredClient}(s). * A repository for OAuth 2.0 {@link RegisteredClient}(s).
* *
@ -26,19 +28,23 @@ package org.springframework.security.oauth2.server.authorization.client;
public interface RegisteredClientRepository { public interface RegisteredClientRepository {
/** /**
* Returns the registered client identified by the provided {@code id}, or {@code null} if not found. * Returns the registered client identified by the provided {@code id},
* or {@code null} if not found.
* *
* @param id the registration identifier * @param id the registration identifier
* @return the {@link RegisteredClient} if found, otherwise {@code null} * @return the {@link RegisteredClient} if found, otherwise {@code null}
*/ */
@Nullable
RegisteredClient findById(String id); RegisteredClient findById(String id);
/** /**
* Returns the registered client identified by the provided {@code clientId}, or {@code null} if not found. * Returns the registered client identified by the provided {@code clientId},
* or {@code null} if not found.
* *
* @param clientId the client identifier * @param clientId the client identifier
* @return the {@link RegisteredClient} if found, otherwise {@code null} * @return the {@link RegisteredClient} if found, otherwise {@code null}
*/ */
@Nullable
RegisteredClient findByClientId(String clientId); RegisteredClient findByClientId(String clientId);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,15 +15,16 @@
*/ */
package org.springframework.security.oauth2.server.authorization.client; package org.springframework.security.oauth2.server.authorization.client;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import java.util.Collections; import java.util.Collections;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -231,7 +232,7 @@ public class RegisteredClientTests {
.build(); .build();
assertThat(registration.getAuthorizationGrantTypes()) assertThat(registration.getAuthorizationGrantTypes())
.containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); .containsExactlyInAnyOrder(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS);
} }
@Test @Test
@ -249,7 +250,7 @@ public class RegisteredClientTests {
.build(); .build();
assertThat(registration.getAuthorizationGrantTypes()) assertThat(registration.getAuthorizationGrantTypes())
.containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); .containsExactlyInAnyOrder(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS);
} }
@Test @Test
@ -280,7 +281,7 @@ public class RegisteredClientTests {
.build(); .build();
assertThat(registration.getClientAuthenticationMethods()) assertThat(registration.getClientAuthenticationMethods())
.containsExactly(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST); .containsExactlyInAnyOrder(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST);
} }
@Test @Test
@ -298,7 +299,7 @@ public class RegisteredClientTests {
.build(); .build();
assertThat(registration.getClientAuthenticationMethods()) assertThat(registration.getClientAuthenticationMethods())
.containsExactly(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST); .containsExactlyInAnyOrder(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST);
} }
@Test @Test
@ -320,7 +321,7 @@ public class RegisteredClientTests {
@Test @Test
public void buildWhenRegisteredClientProvidedThenMakesACopy() { public void buildWhenRegisteredClientProvidedThenMakesACopy() {
RegisteredClient registration = TestRegisteredClients.registeredClient().build(); RegisteredClient registration = TestRegisteredClients.registeredClient().build();
RegisteredClient updated = RegisteredClient.withRegisteredClient(registration).build(); RegisteredClient updated = RegisteredClient.from(registration).build();
assertThat(registration.getId()).isEqualTo(updated.getId()); assertThat(registration.getId()).isEqualTo(updated.getId());
assertThat(registration.getClientId()).isEqualTo(updated.getClientId()); assertThat(registration.getClientId()).isEqualTo(updated.getClientId());
@ -345,7 +346,7 @@ public class RegisteredClientTests {
String newSecret = "new-secret"; String newSecret = "new-secret";
String newScope = "new-scope"; String newScope = "new-scope";
String newRedirectUri = "https://another-redirect-uri.com"; String newRedirectUri = "https://another-redirect-uri.com";
RegisteredClient updated = RegisteredClient.withRegisteredClient(registration) RegisteredClient updated = RegisteredClient.from(registration)
.clientSecret(newSecret) .clientSecret(newSecret)
.scopes(scopes -> { .scopes(scopes -> {
scopes.clear(); scopes.clear();