Skip to content

Commit bd390d2

Browse files
committed
[NOID] Fixes #4153: Handling OpenAI 429's gracefully (#4284)
* Fixes #4153: Handling OpenAI 429's gracefully * cleanup * fix tests
1 parent 25131b2 commit bd390d2

File tree

4 files changed

+185
-1
lines changed

4 files changed

+185
-1
lines changed

docs/asciidoc/modules/ROOT/pages/ml/openai.adoc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ If present, they take precedence over the analogous APOC configs.
2828
| endpoint | analogous to `apoc.ml.openai.url` APOC config
2929
| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config
3030
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
31+
| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false)
32+
| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5)
33+
| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries. If set to false, the wait time increases linearly. (default: false)
3134
|===
3235

3336

full/src/main/java/apoc/ml/OpenAI.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import apoc.ApocConfig;
1313
import apoc.Extended;
1414
import apoc.result.MapResult;
15+
import apoc.util.ExtendedUtil;
1516
import apoc.util.JsonUtil;
1617
import apoc.util.Util;
1718
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -35,7 +36,11 @@
3536

3637
@Extended
3738
public class OpenAI {
39+
public static final String JSON_PATH_CONF_KEY = "jsonPath";
3840
public static final String FAIL_ON_ERROR_CONF = "failOnError";
41+
public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
42+
public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
43+
public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";
3944

4045
@Context
4146
public ApocConfig apocConfig;
@@ -63,6 +68,9 @@ static Stream<Object> executeRequest(
6368
ApocConfig apocConfig)
6469
throws JsonProcessingException, MalformedURLException {
6570
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
71+
boolean enableBackOffRetries = Util.toBoolean(configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY));
72+
Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
73+
boolean exponentialBackoff = Util.toBoolean(configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY));
6674
if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty");
6775
String apiTypeString = (String) configuration.getOrDefault(
6876
API_TYPE_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name()));
@@ -83,6 +91,7 @@ static Stream<Object> executeRequest(
8391
OpenAIRequestHandler apiType = type.get();
8492

8593
final Map<String, Object> headers = new HashMap<>();
94+
String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
8695
headers.put("Content-Type", "application/json");
8796

8897
apiType.addApiKey(headers, apiKey);
@@ -93,7 +102,14 @@ static Stream<Object> executeRequest(
93102
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
94103
// therefore is better to join the not-empty path pieces
95104
var url = apiType.getFullUrl(path, configuration, apocConfig);
96-
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of());
105+
return ExtendedUtil.withBackOffRetries(
106+
() -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of()),
107+
enableBackOffRetries,
108+
backOffRetries,
109+
exponentialBackoff,
110+
exception -> {
111+
if (!exception.getMessage().contains("429")) throw new RuntimeException(exception);
112+
});
97113
}
98114

99115
@Procedure("apoc.ml.openai.embedding")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package apoc.util;
2+
3+
import java.time.Duration;
4+
import java.util.function.Consumer;
5+
import java.util.function.Supplier;
6+
7+
public class ExtendedUtil {
8+
public static <T> T withBackOffRetries(
9+
Supplier<T> func,
10+
boolean retry,
11+
int backoffRetry,
12+
boolean exponential,
13+
Consumer<Exception> exceptionHandler) {
14+
T result;
15+
backoffRetry = backoffRetry < 1 ? 5 : backoffRetry;
16+
int countDown = backoffRetry;
17+
exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {});
18+
while (true) {
19+
try {
20+
result = func.get();
21+
break;
22+
} catch (Exception e) {
23+
if (!retry || countDown < 1) throw e;
24+
exceptionHandler.accept(e);
25+
countDown--;
26+
long delay = getDelay(backoffRetry, countDown, exponential);
27+
backoffSleep(delay);
28+
}
29+
}
30+
return result;
31+
}
32+
33+
private static void backoffSleep(long millis) {
34+
sleep(millis, "Operation interrupted during backoff");
35+
}
36+
37+
public static void sleep(long millis, String interruptedMessage) {
38+
try {
39+
Thread.sleep(millis);
40+
} catch (InterruptedException ie) {
41+
Thread.currentThread().interrupt();
42+
throw new RuntimeException(interruptedMessage, ie);
43+
}
44+
}
45+
46+
private static long getDelay(int backoffRetry, int countDown, boolean exponential) {
47+
int backOffTime = backoffRetry - countDown;
48+
long sleepMultiplier = exponential
49+
? (long) Math.pow(2, backOffTime)
50+
: // Exponential retry progression
51+
backOffTime; // Linear retry progression
52+
return Math.min(
53+
Duration.ofSeconds(1).multipliedBy(sleepMultiplier).toMillis(),
54+
Duration.ofSeconds(30).toMillis() // Max 30s
55+
);
56+
}
57+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package apoc.util;
2+
3+
import static org.junit.Assert.*;
4+
import static org.junit.Assert.assertTrue;
5+
6+
import org.junit.Test;
7+
8+
public class ExtendedUtilTest {
9+
10+
private static int i = 0;
11+
12+
@Test
13+
public void testWithLinearBackOffRetriesWithSuccess() {
14+
i = 0;
15+
long start = System.currentTimeMillis();
16+
int result = ExtendedUtil.withBackOffRetries(
17+
this::testFunction,
18+
true,
19+
-1, // test backoffRetry default value -> 5
20+
false,
21+
runEx -> {
22+
if (!runEx.getMessage().contains("Expected")) throw new RuntimeException("Some Bad News...");
23+
});
24+
long time = System.currentTimeMillis() - start;
25+
26+
assertEquals(4, result);
27+
28+
// The method will attempt to execute the operation with a linear backoff strategy,
29+
// sleeping for 1 second, 2 seconds, and 3 seconds between retries.
30+
// This results in a total wait time of 6 seconds (1s + 2s + 3s + 4s) if the operation succeeds on the third
31+
// attempt,
32+
// leading to an approximate execution time of 6 seconds.
33+
assertTrue("Current time is: " + time, time > 9000 && time < 11000);
34+
}
35+
36+
@Test
37+
public void testWithExponentialBackOffRetriesWithSuccess() {
38+
i = 0;
39+
long start = System.currentTimeMillis();
40+
int result = ExtendedUtil.withBackOffRetries(
41+
this::testFunction,
42+
true,
43+
0, // test backoffRetry default value -> 5
44+
true,
45+
runEx -> {});
46+
long time = System.currentTimeMillis() - start;
47+
48+
assertEquals(4, result);
49+
50+
// The method will attempt to execute the operation with an exponential backoff strategy,
51+
// sleeping for 2 second, 4 seconds, and 8 seconds between retries.
52+
// This results in a total wait time of 30 seconds (2s + 4s + 8s + 16s) if the operation succeeds on the third
53+
// attempt,
54+
// leading to an approximate execution time of 14 seconds.
55+
assertTrue("Current time is: " + time, time > 29000 && time < 31000);
56+
}
57+
58+
@Test
59+
public void testBackOffRetriesWithError() {
60+
i = 0;
61+
long start = System.currentTimeMillis();
62+
assertThrows(
63+
RuntimeException.class,
64+
() -> ExtendedUtil.withBackOffRetries(this::testFunction, true, 2, false, runEx -> {}));
65+
long time = System.currentTimeMillis() - start;
66+
67+
// The method is configured to retry the operation twice.
68+
// So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an
69+
// exception.
70+
// Resulting in an approximate execution time of 3 seconds.
71+
assertTrue("Current time is: " + time, time > 2000 && time < 4000);
72+
}
73+
74+
@Test
75+
public void testBackOffRetriesWithErrorAndExponential() {
76+
i = 0;
77+
long start = System.currentTimeMillis();
78+
assertThrows(
79+
RuntimeException.class,
80+
() -> ExtendedUtil.withBackOffRetries(this::testFunction, true, 2, true, runEx -> {}));
81+
long time = System.currentTimeMillis() - start;
82+
83+
// The method is configured to retry the operation twice.
84+
// So, it will make two extra-attempts, waiting for 2 second and 4 seconds before failing and throwing an
85+
// exception.
86+
// Resulting in an approximate execution time of 6 seconds.
87+
assertTrue("Current time is: " + time, time > 5000 && time < 7000);
88+
}
89+
90+
@Test
91+
public void testWithoutBackOffRetriesWithError() {
92+
i = 0;
93+
assertThrows(
94+
RuntimeException.class,
95+
() -> ExtendedUtil.withBackOffRetries(this::testFunction, false, 30, false, runEx -> {}));
96+
97+
// Retry strategy is not active and the testFunction is executed only once by raising an exception.
98+
assertEquals(1, i);
99+
}
100+
101+
private int testFunction() {
102+
if (i == 4) {
103+
return i;
104+
}
105+
i++;
106+
throw new RuntimeException("Expected i not equal to 4");
107+
}
108+
}

0 commit comments

Comments
 (0)