mirror of
https://github.com/OpenAPITools/openapi-generator.git
synced 2026-03-23 05:59:14 +00:00
[JAVA][FEIGN] Automatically retry request that fail due to a 401 or 403 (#10021)
* Renew the access token after receiving a 401/403 Feign clients tries to renew the access token after it receives a 401 or 403. It Retries the request 1 time * Add unit test for exhausted retries * Update samples
This commit is contained in:
@@ -19,8 +19,17 @@ import feign.form.FormEncoder;
|
||||
import feign.jackson.JacksonDecoder;
|
||||
import feign.jackson.JacksonEncoder;
|
||||
import feign.slf4j.Slf4jLogger;
|
||||
import org.openapitools.client.auth.*;
|
||||
import org.openapitools.client.auth.HttpBasicAuth;
|
||||
import org.openapitools.client.auth.HttpBearerAuth;
|
||||
import org.openapitools.client.auth.ApiKeyAuth;
|
||||
|
||||
import org.openapitools.client.auth.ApiErrorDecoder;
|
||||
import org.openapitools.client.auth.OAuth;
|
||||
import org.openapitools.client.auth.OAuth.AccessTokenListener;
|
||||
import org.openapitools.client.auth.OAuthFlow;
|
||||
import org.openapitools.client.auth.OauthPasswordGrant;
|
||||
import org.openapitools.client.auth.OauthClientCredentialsGrant;
|
||||
import feign.Retryer;
|
||||
|
||||
@javax.annotation.Generated(value = "org.openapitools.codegen.languages.JavaClientCodegen")
|
||||
public class ApiClient {
|
||||
@@ -40,6 +49,8 @@ public class ApiClient {
|
||||
.client(new OkHttpClient())
|
||||
.encoder(new FormEncoder(new JacksonEncoder(objectMapper)))
|
||||
.decoder(new JacksonDecoder(objectMapper))
|
||||
.errorDecoder(new ApiErrorDecoder())
|
||||
.retryer(new Retryer.Default(0, 0, 2))
|
||||
.logger(new Slf4jLogger());
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package org.openapitools.client.auth;
|
||||
|
||||
import feign.Response;
|
||||
import feign.RetryableException;
|
||||
import feign.codec.ErrorDecoder;
|
||||
|
||||
/**
|
||||
* Error decoder that makes the HTTP 401 and 403 Retryable. Sometimes the 401 or 402 may indicate an expired token
|
||||
* All the other HTTP status are handled by the {@link feign.codec.ErrorDecoder.Default} decoder
|
||||
*/
|
||||
public class ApiErrorDecoder implements ErrorDecoder {
|
||||
|
||||
private final Default defaultErrorDecoder = new Default();
|
||||
|
||||
@Override
|
||||
public Exception decode(String methodKey, Response response) {
|
||||
//401/403 response codes most likely indicate an expired access token, unless it happens two times in a row
|
||||
Exception httpException = defaultErrorDecoder.decode(methodKey, response);
|
||||
if (response.status() == 401 || response.status() == 403) {
|
||||
return new RetryableException(response.status(), "Received status " + response.status() + " trying to renew access token",
|
||||
response.request().httpMethod(), httpException, null, response.request());
|
||||
}
|
||||
return httpException;
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,10 @@ package org.openapitools.client.auth;
|
||||
|
||||
import com.github.scribejava.core.model.OAuth2AccessToken;
|
||||
import com.github.scribejava.core.oauth.OAuth20Service;
|
||||
import feign.Request.HttpMethod;
|
||||
import feign.RequestInterceptor;
|
||||
import feign.RequestTemplate;
|
||||
import feign.RetryableException;
|
||||
|
||||
import java.util.Collection;
|
||||
|
||||
@javax.annotation.Generated(value = "org.openapitools.codegen.languages.JavaClientCodegen")
|
||||
public abstract class OAuth implements RequestInterceptor {
|
||||
@@ -34,25 +34,27 @@ public abstract class OAuth implements RequestInterceptor {
|
||||
@Override
|
||||
public void apply(RequestTemplate template) {
|
||||
// If the request already have an authorization (eg. Basic auth), do nothing
|
||||
if (template.headers().containsKey("Authorization")) {
|
||||
if (requestContainsNonOauthAuthorization(template)) {
|
||||
return;
|
||||
}
|
||||
// If first time, get the token
|
||||
if (expirationTimeMillis == null || System.currentTimeMillis() >= expirationTimeMillis) {
|
||||
updateAccessToken(template);
|
||||
}
|
||||
if (getAccessToken() != null) {
|
||||
template.header("Authorization", "Bearer " + getAccessToken());
|
||||
String accessToken = getAccessToken();
|
||||
if (accessToken != null) {
|
||||
template.header("Authorization", "Bearer " + accessToken);
|
||||
}
|
||||
}
|
||||
|
||||
private synchronized void updateAccessToken(RequestTemplate template) {
|
||||
OAuth2AccessToken accessTokenResponse;
|
||||
try {
|
||||
accessTokenResponse = getOAuth2AccessToken();
|
||||
} catch (Exception e) {
|
||||
throw new RetryableException(0, e.getMessage(), HttpMethod.POST, e, null, template.request());
|
||||
private boolean requestContainsNonOauthAuthorization(RequestTemplate template) {
|
||||
Collection<String> authorizations = template.headers().get("Authorization");
|
||||
if (authorizations == null) {
|
||||
return false;
|
||||
}
|
||||
return !authorizations.stream()
|
||||
.anyMatch(authHeader -> !authHeader.equalsIgnoreCase("Bearer"));
|
||||
}
|
||||
|
||||
private synchronized void updateAccessToken() {
|
||||
OAuth2AccessToken accessTokenResponse;
|
||||
accessTokenResponse = getOAuth2AccessToken();
|
||||
if (accessTokenResponse != null && accessTokenResponse.getAccessToken() != null) {
|
||||
setAccessToken(accessTokenResponse.getAccessToken(), accessTokenResponse.getExpiresIn());
|
||||
if (accessTokenListener != null) {
|
||||
@@ -70,9 +72,18 @@ public abstract class OAuth implements RequestInterceptor {
|
||||
}
|
||||
|
||||
public synchronized String getAccessToken() {
|
||||
// If first time, get the token
|
||||
if (expirationTimeMillis == null || System.currentTimeMillis() >= expirationTimeMillis) {
|
||||
updateAccessToken();
|
||||
}
|
||||
return accessToken;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually sets the access token
|
||||
* @param accessToken The access token
|
||||
* @param expiresIn Seconds until the token expires
|
||||
*/
|
||||
public synchronized void setAccessToken(String accessToken, Integer expiresIn) {
|
||||
this.accessToken = accessToken;
|
||||
this.expirationTimeMillis = expiresIn == null ? null : System.currentTimeMillis() + expiresIn * MILLIS_PER_SECOND;
|
||||
|
||||
Reference in New Issue
Block a user