OAuth2TokenRevocationAuthenticationProvider ignores token_type_hint

Closes gh-175
This commit is contained in:
Gerardo Roza 2021-01-06 19:52:22 -03:00 committed by Joe Grandja
parent 17c20e98d4
commit 4bcc1afac7
5 changed files with 16 additions and 24 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -54,6 +54,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
this.authorizations.remove(authorizationId, authorization); this.authorizations.remove(authorizationId, authorization);
} }
@Nullable
@Override @Override
public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) { public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) {
Assert.hasText(token, "token cannot be empty"); Assert.hasText(token, "token cannot be empty");

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -49,6 +49,7 @@ public interface OAuth2AuthorizationService {
* @param tokenType the {@link TokenType token type} * @param tokenType the {@link TokenType token type}
* @return the {@link OAuth2Authorization} if found, otherwise {@code null} * @return the {@link OAuth2Authorization} if found, otherwise {@code null}
*/ */
@Nullable
OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType); OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,10 +24,8 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
@ -63,18 +61,8 @@ public class OAuth2TokenRevocationAuthenticationProvider implements Authenticati
getAuthenticatedClientElseThrowInvalidClient(tokenRevocationAuthentication); getAuthenticatedClientElseThrowInvalidClient(tokenRevocationAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
TokenType tokenType = null;
String tokenTypeHint = tokenRevocationAuthentication.getTokenTypeHint();
if (StringUtils.hasText(tokenTypeHint)) {
if (TokenType.REFRESH_TOKEN.getValue().equals(tokenTypeHint)) {
tokenType = TokenType.REFRESH_TOKEN;
} else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) {
tokenType = TokenType.ACCESS_TOKEN;
}
}
OAuth2Authorization authorization = this.authorizationService.findByToken( OAuth2Authorization authorization = this.authorizationService.findByToken(
tokenRevocationAuthentication.getToken(), tokenType); tokenRevocationAuthentication.getToken(), null);
if (authorization == null) { if (authorization == null) {
// Return the authentication request when token not found // Return the authentication request when token not found
return tokenRevocationAuthentication; return tokenRevocationAuthentication;

View File

@ -55,6 +55,7 @@ import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -102,7 +103,7 @@ public class OAuth2TokenRevocationTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
OAuth2RefreshToken token = authorization.getTokens().getRefreshToken(); OAuth2RefreshToken token = authorization.getTokens().getRefreshToken();
TokenType tokenType = TokenType.REFRESH_TOKEN; TokenType tokenType = TokenType.REFRESH_TOKEN;
when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization); when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization);
this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI) this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI)
.params(getTokenRevocationRequestParameters(token, tokenType)) .params(getTokenRevocationRequestParameters(token, tokenType))
@ -111,7 +112,7 @@ public class OAuth2TokenRevocationTests {
.andExpect(status().isOk()); .andExpect(status().isOk());
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType)); verify(authorizationService).findByToken(eq(token.getTokenValue()), isNull());
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(authorizationService).save(authorizationCaptor.capture()); verify(authorizationService).save(authorizationCaptor.capture());
@ -134,7 +135,7 @@ public class OAuth2TokenRevocationTests {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
OAuth2AccessToken token = authorization.getTokens().getAccessToken(); OAuth2AccessToken token = authorization.getTokens().getAccessToken();
TokenType tokenType = TokenType.ACCESS_TOKEN; TokenType tokenType = TokenType.ACCESS_TOKEN;
when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization); when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization);
this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI) this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI)
.params(getTokenRevocationRequestParameters(token, tokenType)) .params(getTokenRevocationRequestParameters(token, tokenType))
@ -143,7 +144,7 @@ public class OAuth2TokenRevocationTests {
.andExpect(status().isOk()); .andExpect(status().isOk());
verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType)); verify(authorizationService).findByToken(eq(token.getTokenValue()), isNull());
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(authorizationService).save(authorizationCaptor.capture()); verify(authorizationService).save(authorizationCaptor.capture());

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2020 the original author or authors. * Copyright 2020-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -35,6 +35,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -115,7 +116,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
TestRegisteredClients.registeredClient2().build()).build(); TestRegisteredClients.registeredClient2().build()).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq("token"), eq("token"),
eq(TokenType.ACCESS_TOKEN))) isNull()))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
@ -136,7 +137,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
registeredClient).build(); registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getRefreshToken().getTokenValue()), eq(authorization.getTokens().getRefreshToken().getTokenValue()),
eq(TokenType.REFRESH_TOKEN))) isNull()))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
@ -164,7 +165,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
registeredClient).build(); registeredClient).build();
when(this.authorizationService.findByToken( when(this.authorizationService.findByToken(
eq(authorization.getTokens().getAccessToken().getTokenValue()), eq(authorization.getTokens().getAccessToken().getTokenValue()),
eq(TokenType.ACCESS_TOKEN))) isNull()))
.thenReturn(authorization); .thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);