diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 9dc13f9..ddc8a02 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -15,6 +15,7 @@ */ package org.springframework.security.oauth2.server.authorization; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import java.util.List; @@ -54,9 +55,8 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza } @Override - public OAuth2Authorization findByTokenAndTokenType(String token, TokenType tokenType) { + public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) { Assert.hasText(token, "token cannot be empty"); - Assert.notNull(tokenType, "tokenType cannot be null"); return this.authorizations.stream() .filter(authorization -> hasToken(authorization, token, tokenType)) .findFirst() diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java index de30714..0151e36 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java @@ -15,6 +15,8 @@ */ package org.springframework.security.oauth2.server.authorization; +import org.springframework.lang.Nullable; + /** * Implementations of this interface are responsible for the management * of {@link OAuth2Authorization OAuth 2.0 Authorization(s)}. @@ -40,6 +42,6 @@ public interface OAuth2AuthorizationService { * @param tokenType the {@link TokenType token type} * @return the {@link OAuth2Authorization} if found, otherwise {@code null} */ - OAuth2Authorization findByTokenAndTokenType(String token, TokenType tokenType); + OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType); } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 0afd01b..eb87295 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -89,7 +89,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica // from inadvertently accepting a code intended for a client with a different "client_id". // This protects the client from substitution of the authentication code. - OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType( + OAuth2Authorization authorization = this.authorizationService.findByToken( authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE); if (authorization == null) { throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index f9d79ea..3b19fed 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -65,25 +65,18 @@ public class InMemoryOAuth2AuthorizationServiceTests { .build(); this.authorizationService.save(expectedAuthorization); - OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType( + OAuth2Authorization authorization = this.authorizationService.findByToken( AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); assertThat(authorization).isEqualTo(expectedAuthorization); } @Test public void findByTokenAndTokenTypeWhenTokenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizationService.findByTokenAndTokenType(null, TokenType.AUTHORIZATION_CODE)) + assertThatThrownBy(() -> this.authorizationService.findByToken(null, TokenType.AUTHORIZATION_CODE)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("token cannot be empty"); } - @Test - public void findByTokenAndTokenTypeWhenTokenTypeNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizationService.findByTokenAndTokenType(AUTHORIZATION_CODE, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("tokenType cannot be null"); - } - @Test public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() { OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) @@ -92,7 +85,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { .build(); this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization)); - OAuth2Authorization result = this.authorizationService.findByTokenAndTokenType( + OAuth2Authorization result = this.authorizationService.findByToken( AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE); assertThat(authorization).isEqualTo(result); } @@ -108,14 +101,14 @@ public class InMemoryOAuth2AuthorizationServiceTests { .build(); this.authorizationService.save(authorization); - OAuth2Authorization result = this.authorizationService.findByTokenAndTokenType( + OAuth2Authorization result = this.authorizationService.findByToken( "access-token", TokenType.ACCESS_TOKEN); assertThat(authorization).isEqualTo(result); } @Test public void findByTokenAndTokenTypeWhenTokenDoesNotExistThenNull() { - OAuth2Authorization result = this.authorizationService.findByTokenAndTokenType( + OAuth2Authorization result = this.authorizationService.findByToken( "access-token", TokenType.ACCESS_TOKEN); assertThat(result).isNull(); } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 8187353..78ca0b3 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -119,7 +119,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenCodeIssuedToAnotherClientThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + when(this.authorizationService.findByToken(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( @@ -136,7 +136,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + when(this.authorizationService.findByToken(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); @@ -154,7 +154,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenValidCodeThenReturnAccessToken() { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); - when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) + when(this.authorizationService.findByToken(eq("code"), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient); diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationCodeGrantTests.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationCodeGrantTests.java index d8938c7..4c9b431 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationCodeGrantTests.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationCodeGrantTests.java @@ -136,7 +136,7 @@ public class OAuth2AuthorizationCodeGrantTests { .thenReturn(registeredClient); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); - when(authorizationService.findByTokenAndTokenType( + when(authorizationService.findByToken( eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), eq(TokenType.AUTHORIZATION_CODE))) .thenReturn(authorization); @@ -151,7 +151,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); - verify(authorizationService).findByTokenAndTokenType( + verify(authorizationService).findByToken( eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)), eq(TokenType.AUTHORIZATION_CODE)); verify(authorizationService).save(any());