[BUG]Java] Fix a race condition in RetryingOAuth.mustache (#10087)

If there were multiple concurrent requests at a time at which the OAuth token had expired, only a single request would be retried. The other requests would fail because of the expired token, but not be retried and so the failures would be propagated to the caller.
This commit is contained in:
Gareth Smith
2021-08-04 03:06:39 +01:00
committed by GitHub
parent 4d0a40e982
commit 98e4eb708f
10 changed files with 141 additions and 12 deletions
@@ -156,14 +156,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {
@@ -332,6 +332,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>{{#java8}}1.8{{/java8}}{{^java8}}1.7{{/java8}}</java.version>
@@ -279,6 +279,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>1.7</java.version>
@@ -155,14 +155,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {
@@ -281,6 +281,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>1.7</java.version>
@@ -155,14 +155,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {
@@ -5,6 +5,7 @@ src/test/java/org/openapitools/client/ApiClientTest.java
src/test/java/org/openapitools/client/ConfigurationTest.java
src/test/java/org/openapitools/client/auth/ApiKeyAuthTest.java
src/test/java/org/openapitools/client/auth/HttpBasicAuthTest.java
src/test/java/org/openapitools/client/auth/RetryingOAuthTest.java
src/test/java/org/openapitools/client/model/EnumValueTest.java
src/test/java/org/openapitools/client/model/PetTest.java
src/test/java/org/openapitools/client/model/ArrayOfArrayOfNumberOnlyTest.java
@@ -274,6 +274,12 @@
<version>${junit-version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.11.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<properties>
<java.version>1.7</java.version>
@@ -155,14 +155,12 @@ public class RetryingOAuth extends OAuth implements Interceptor {
oAuthClient.accessToken(tokenRequestBuilder.buildBodyMessage());
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
setAccessToken(accessTokenResponse.getAccessToken());
return !getAccessToken().equals(requestAccessToken);
}
} catch (OAuthSystemException | OAuthProblemException e) {
throw new IOException(e);
}
}
return false;
return getAccessToken() == null || !getAccessToken().equals(requestAccessToken);
}
public TokenRequestBuilder getTokenRequestBuilder() {
@@ -0,0 +1,112 @@
package org.openapitools.client.auth;
import okhttp3.Interceptor.Chain;
import okhttp3.*;
import okhttp3.Response.Builder;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.oltu.oauth2.client.OAuthClient;
import org.apache.oltu.oauth2.client.request.OAuthClientRequest;
import org.apache.oltu.oauth2.client.response.OAuthJSONAccessTokenResponse;
import org.apache.oltu.oauth2.common.exception.OAuthProblemException;
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.Collections;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RetryingOAuthTest {
private RetryingOAuth oauth;
@Before
public void setUp() throws Exception {
oauth = new RetryingOAuth("_clientId", "_clientSecret", OAuthFlow.accessCode,
"https://token.example.com", Collections.<String, String>emptyMap());
oauth.setAccessToken("expired-access-token");
FieldUtils.writeField(oauth, "oAuthClient", mockOAuthClient(), true);
}
@Test
public void testSingleRequestUnauthorized() throws Exception {
Response response = oauth.intercept(mockChain());
assertEquals(HttpURLConnection.HTTP_OK, response.code());
}
@Test
public void testTwoConcurrentRequestsUnauthorized() throws Exception {
Callable<Response> callable = new Callable<Response>() {
@Override
public Response call() throws Exception {
return oauth.intercept(mockChain());
}
};
ExecutorService executor = Executors.newFixedThreadPool(2);
try {
Future<Response> response1 = executor.submit(callable);
Future<Response> response2 = executor.submit(callable);
assertEquals(HttpURLConnection.HTTP_OK, response1.get().code());
assertEquals(HttpURLConnection.HTTP_OK, response2.get().code());
} finally {
executor.shutdown();
}
}
private OAuthClient mockOAuthClient() throws OAuthProblemException, OAuthSystemException {
OAuthJSONAccessTokenResponse response = mock(OAuthJSONAccessTokenResponse.class);
when(response.getAccessToken()).thenAnswer(new Answer<String>() {
@Override
public String answer(InvocationOnMock invocation) throws Throwable {
// sleep ensures that the bug is triggered.
Thread.sleep(1000);
return "new-access-token";
}
});
OAuthClient client = mock(OAuthClient.class);
when(client.accessToken(any(OAuthClientRequest.class))).thenReturn(response);
return client;
}
private Chain mockChain() throws IOException {
Chain chain = mock(Chain.class);
final Request request = new Request.Builder()
.url("https://api.example.com")
.build();
when(chain.request()).thenReturn(request);
when(chain.proceed(any(Request.class))).thenAnswer(new Answer<Response>() {
@Override
public Response answer(InvocationOnMock inv) {
Request r = inv.getArgument(0);
int responseCode = "Bearer new-access-token".equals(r.header("Authorization"))
? HttpURLConnection.HTTP_OK
: HttpURLConnection.HTTP_UNAUTHORIZED;
return new Builder()
.protocol(Protocol.HTTP_1_0)
.message("sup")
.request(request)
.body(ResponseBody.create(new byte[0], MediaType.get("application/test")))
.code(responseCode)
.build();
}
});
return chain;
}
}