[BUG]Java] Fix a race condition in RetryingOAuth.mustache (#10087)

If there were multiple concurrent requests at a time at which the OAuth token had expired, only a single request would be retried. The other requests would fail because of the expired token, but not be retried and so the failures would be propagated to the caller.
This commit is contained in:
Gareth Smith
2021-08-04 03:06:39 +01:00
committed by GitHub
parent 4d0a40e982
commit 98e4eb708f
10 changed files with 141 additions and 12 deletions

View File

@@ -156,14 +156,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {

View File

@@ -332,6 +332,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>{{#java8}}1.8{{/java8}}{{^java8}}1.7{{/java8}}</java.version>

View File

@@ -279,6 +279,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>1.7</java.version>

View File

@@ -155,14 +155,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {

View File

@@ -281,6 +281,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>1.7</java.version>

View File

@@ -155,14 +155,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {

View File

@@ -5,6 +5,7 @@ src/test/java/org/openapitools/client/ApiClientTest.java
src/test/java/org/openapitools/client/ConfigurationTest.java
src/test/java/org/openapitools/client/auth/ApiKeyAuthTest.java
src/test/java/org/openapitools/client/auth/HttpBasicAuthTest.java
src/test/java/org/openapitools/client/auth/RetryingOAuthTest.java
src/test/java/org/openapitools/client/model/EnumValueTest.java
src/test/java/org/openapitools/client/model/PetTest.java
src/test/java/org/openapitools/client/model/ArrayOfArrayOfNumberOnlyTest.java

View File

@@ -274,6 +274,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>1.7</java.version>

View File

@@ -155,14 +155,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {

View File

@@ -0,0 +1,112 @@
package org.openapitools.client.auth;
import okhttp3.Interceptor.Chain;
import okhttp3.*;
import okhttp3.Response.Builder;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.oltu.oauth2.client.OAuthClient;
import org.apache.oltu.oauth2.client.request.OAuthClientRequest;
import org.apache.oltu.oauth2.client.response.OAuthJSONAccessTokenResponse;
import org.apache.oltu.oauth2.common.exception.OAuthProblemException;
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.Collections;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RetryingOAuthTest {
private RetryingOAuth oauth;
@Before
public void setUp() throws Exception {
oauth = new RetryingOAuth("_clientId", "_clientSecret", OAuthFlow.accessCode,
"https://token.example.com", Collections.<String, String>emptyMap());
oauth.setAccessToken("expired-access-token");
FieldUtils.writeField(oauth, "oAuthClient", mockOAuthClient(), true);
}
@Test
public void testSingleRequestUnauthorized() throws Exception {
Response response = oauth.intercept(mockChain());
assertEquals(HttpURLConnection.HTTP_OK, response.code());
}
@Test
public void testTwoConcurrentRequestsUnauthorized() throws Exception {
Callable<Response> callable = new Callable<Response>() {
@Override
public Response call() throws Exception {
return oauth.intercept(mockChain());
}
};
ExecutorService executor = Executors.newFixedThreadPool(2);
try {
Future<Response> response1 = executor.submit(callable);
Future<Response> response2 = executor.submit(callable);
assertEquals(HttpURLConnection.HTTP_OK, response1.get().code());
assertEquals(HttpURLConnection.HTTP_OK, response2.get().code());
} finally {
executor.shutdown();
}
}
private OAuthClient mockOAuthClient() throws OAuthProblemException, OAuthSystemException {
OAuthJSONAccessTokenResponse response = mock(OAuthJSONAccessTokenResponse.class);
when(response.getAccessToken()).thenAnswer(new Answer<String>() {
@Override
public String answer(InvocationOnMock invocation) throws Throwable {
// sleep ensures that the bug is triggered.
Thread.sleep(1000);
return "new-access-token";
}
});
OAuthClient client = mock(OAuthClient.class);
when(client.accessToken(any(OAuthClientRequest.class))).thenReturn(response);
return client;
}
private Chain mockChain() throws IOException {
Chain chain = mock(Chain.class);
final Request request = new Request.Builder()
.url("https://api.example.com")
.build();
when(chain.request()).thenReturn(request);
when(chain.proceed(any(Request.class))).thenAnswer(new Answer<Response>() {
@Override
public Response answer(InvocationOnMock inv) {
Request r = inv.getArgument(0);
int responseCode = "Bearer new-access-token".equals(r.header("Authorization"))
? HttpURLConnection.HTTP_OK
: HttpURLConnection.HTTP_UNAUTHORIZED;
return new Builder()
.protocol(Protocol.HTTP_1_0)
.message("sup")
.request(request)
.body(ResponseBody.create(new byte[0], MediaType.get("application/test")))
.code(responseCode)
.build();
}
});
return chain;
}
}