diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index cdcbfd5..b78f92e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza this.authorizations.remove(authorizationId, authorization); } + @Nullable @Override public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) { Assert.hasText(token, "token cannot be empty"); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java index 1cc2b0f..34293a1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ public interface OAuth2AuthorizationService { * @param tokenType the {@link TokenType token type} * @return the {@link OAuth2Authorization} if found, otherwise {@code null} */ + @Nullable OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java index c748c01..9127d31 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,10 +24,8 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; @@ -63,18 +61,8 @@ public class OAuth2TokenRevocationAuthenticationProvider implements Authenticati getAuthenticatedClientElseThrowInvalidClient(tokenRevocationAuthentication); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); - TokenType tokenType = null; - String tokenTypeHint = tokenRevocationAuthentication.getTokenTypeHint(); - if (StringUtils.hasText(tokenTypeHint)) { - if (TokenType.REFRESH_TOKEN.getValue().equals(tokenTypeHint)) { - tokenType = TokenType.REFRESH_TOKEN; - } else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) { - tokenType = TokenType.ACCESS_TOKEN; - } - } - OAuth2Authorization authorization = this.authorizationService.findByToken( - tokenRevocationAuthentication.getToken(), tokenType); + tokenRevocationAuthentication.getToken(), null); if (authorization == null) { // Return the authentication request when token not found return tokenRevocationAuthentication; diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java index bf9967b..5325221 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java @@ -55,6 +55,7 @@ import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; @@ -102,7 +103,7 @@ public class OAuth2TokenRevocationTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2RefreshToken token = authorization.getTokens().getRefreshToken(); TokenType tokenType = TokenType.REFRESH_TOKEN; - when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization); + when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization); this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI) .params(getTokenRevocationRequestParameters(token, tokenType)) @@ -111,7 +112,7 @@ public class OAuth2TokenRevocationTests { .andExpect(status().isOk()); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); - verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType)); + verify(authorizationService).findByToken(eq(token.getTokenValue()), isNull()); ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(authorizationService).save(authorizationCaptor.capture()); @@ -134,7 +135,7 @@ public class OAuth2TokenRevocationTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2AccessToken token = authorization.getTokens().getAccessToken(); TokenType tokenType = TokenType.ACCESS_TOKEN; - when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization); + when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization); this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI) .params(getTokenRevocationRequestParameters(token, tokenType)) @@ -143,7 +144,7 @@ public class OAuth2TokenRevocationTests { .andExpect(status().isOk()); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); - verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType)); + verify(authorizationService).findByToken(eq(token.getTokenValue()), isNull()); ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(authorizationService).save(authorizationCaptor.capture()); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java index aeb8392..5629aa3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -115,7 +116,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests { TestRegisteredClients.registeredClient2().build()).build(); when(this.authorizationService.findByToken( eq("token"), - eq(TokenType.ACCESS_TOKEN))) + isNull())) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); @@ -136,7 +137,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests { registeredClient).build(); when(this.authorizationService.findByToken( eq(authorization.getTokens().getRefreshToken().getTokenValue()), - eq(TokenType.REFRESH_TOKEN))) + isNull())) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); @@ -164,7 +165,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests { registeredClient).build(); when(this.authorizationService.findByToken( eq(authorization.getTokens().getAccessToken().getTokenValue()), - eq(TokenType.ACCESS_TOKEN))) + isNull())) .thenReturn(authorization); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);