Skip to content

Commit 7add721

Browse files
authored
handle the throttling error in the response header (#2442)
* handle the throttling error in the response header Signed-off-by: Xun Zhang <xunzh@amazon.com> * address comments and UT Signed-off-by: Xun Zhang <xunzh@amazon.com> * more UTs Signed-off-by: Xun Zhang <xunzh@amazon.com> * add more comments Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent d74c623 commit 7add721

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java

+31
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.List;
1717
import java.util.Map;
1818

19+
import org.apache.commons.collections.MapUtils;
1920
import org.apache.http.HttpStatus;
2021
import org.apache.logging.log4j.util.Strings;
2122
import org.opensearch.OpenSearchStatusException;
@@ -38,6 +39,7 @@
3839

3940
@Log4j2
4041
public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler {
42+
public static final String AMZ_ERROR_HEADER = "x-amzn-ErrorType";
4143
@Getter
4244
private Integer statusCode;
4345
@Getter
@@ -80,6 +82,10 @@ public void onHeaders(SdkHttpResponse response) {
8082
SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response;
8183
log.debug("received response headers: " + sdkResponse.headers());
8284
this.statusCode = sdkResponse.statusCode();
85+
if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) {
86+
handleThrottlingInHeader(sdkResponse);
87+
// add more handling here for other exceptions in headers
88+
}
8389
}
8490

8591
@Override
@@ -95,6 +101,31 @@ public void onError(Throwable error) {
95101
actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
96102
}
97103

104+
private void handleThrottlingInHeader(SdkHttpFullResponse sdkResponse) {
105+
if (MapUtils.isEmpty(sdkResponse.headers())) {
106+
return;
107+
}
108+
List<String> errorsInHeader = sdkResponse.headers().get(AMZ_ERROR_HEADER);
109+
if (errorsInHeader == null || errorsInHeader.isEmpty()) {
110+
return;
111+
}
112+
// Check the throttling exception from AMZN servers, e.g. sageMaker.
113+
// See [https://github.com/opensearch-project/ml-commons/issues/2429] for more details.
114+
boolean containsThrottlingException = errorsInHeader.stream().anyMatch(str -> str.startsWith("ThrottlingException"));
115+
if (containsThrottlingException && executionContext.getExceptionHolder().get() == null) {
116+
log.error("Remote server returned error code: {}", statusCode);
117+
executionContext
118+
.getExceptionHolder()
119+
.compareAndSet(
120+
null,
121+
new OpenSearchStatusException(
122+
REMOTE_SERVICE_ERROR + "The request was denied due to remote server throttling.",
123+
RestStatus.fromCode(statusCode)
124+
)
125+
);
126+
}
127+
}
128+
98129
private void processResponse(
99130
Integer statusCode,
100131
String body,

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java

+176
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import static org.mockito.Mockito.times;
1212
import static org.mockito.Mockito.verify;
1313
import static org.mockito.Mockito.when;
14+
import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
15+
import static org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler.AMZ_ERROR_HEADER;
1416

1517
import java.nio.ByteBuffer;
1618
import java.util.Arrays;
@@ -51,6 +53,8 @@ public class MLSdkAsyncHttpResponseHandlerTest {
5153

5254
private Connector noProcessFunctionConnector;
5355

56+
private Map<String, List<String>> headersMap;
57+
5458
@Mock
5559
private SdkHttpFullResponse sdkHttpResponse;
5660
@Mock
@@ -104,6 +108,7 @@ public void setup() {
104108
null
105109
);
106110
responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber();
111+
headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!"));
107112
}
108113

109114
@Test
@@ -112,6 +117,13 @@ public void test_OnHeaders() {
112117
assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 200;
113118
}
114119

120+
@Test
121+
public void test_OnHeaders_withError() {
122+
when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
123+
mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
124+
assert mlSdkAsyncHttpResponseHandler.getStatusCode() == 400;
125+
}
126+
115127
@Test
116128
public void test_OnStream_with_postProcessFunction_bedRock() {
117129
String response = "{\n"
@@ -419,4 +431,168 @@ public void test_onComplete_error_http_status() {
419431
System.out.println(captor.getValue().getMessage());
420432
assert captor.getValue().getMessage().contains("runtime error");
421433
}
434+
435+
@Test
436+
public void test_onComplete_throttle_error_headers() {
437+
String error = "{\"message\": null}";
438+
SdkHttpResponse response = mock(SdkHttpFullResponse.class);
439+
when(response.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
440+
when(response.headers()).thenReturn(headersMap);
441+
mlSdkAsyncHttpResponseHandler.onHeaders(response);
442+
Publisher<ByteBuffer> stream = s -> {
443+
try {
444+
s.onSubscribe(mock(Subscription.class));
445+
s.onNext(ByteBuffer.wrap(error.getBytes()));
446+
s.onComplete();
447+
} catch (Throwable e) {
448+
s.onError(e);
449+
}
450+
};
451+
mlSdkAsyncHttpResponseHandler.onStream(stream);
452+
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
453+
verify(actionListener, times(1)).onFailure(captor.capture());
454+
assert captor.getValue() instanceof OpenSearchStatusException;
455+
System.out.println(captor.getValue().getMessage());
456+
assert captor.getValue().getMessage().contains(REMOTE_SERVICE_ERROR);
457+
}
458+
459+
@Test
460+
public void test_onComplete_throttle_exceptionFirst() {
461+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
462+
String response1 = "{\n"
463+
+ " \"embedding\": [\n"
464+
+ " 0.46484375,\n"
465+
+ " -0.017822266,\n"
466+
+ " 0.17382812,\n"
467+
+ " 0.10595703,\n"
468+
+ " 0.875,\n"
469+
+ " 0.19140625,\n"
470+
+ " -0.36914062,\n"
471+
+ " -0.0011978149\n"
472+
+ " ]\n"
473+
+ "}";
474+
String response2 = "{\"message\": null}";
475+
CountDownLatch count = new CountDownLatch(2);
476+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
477+
new ExecutionContext(0, count, exceptionHolder),
478+
actionListener,
479+
parameters,
480+
tensorOutputs,
481+
connector,
482+
scriptService,
483+
null
484+
);
485+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
486+
new ExecutionContext(1, count, exceptionHolder),
487+
actionListener,
488+
parameters,
489+
tensorOutputs,
490+
connector,
491+
scriptService,
492+
null
493+
);
494+
495+
SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
496+
when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
497+
when(sdkHttpResponse2.headers()).thenReturn(headersMap);
498+
mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
499+
Publisher<ByteBuffer> stream2 = s -> {
500+
try {
501+
s.onSubscribe(mock(Subscription.class));
502+
s.onNext(ByteBuffer.wrap(response2.getBytes()));
503+
s.onComplete();
504+
} catch (Throwable e) {
505+
s.onError(e);
506+
}
507+
};
508+
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
509+
510+
SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
511+
when(sdkHttpResponse1.statusCode()).thenReturn(200);
512+
mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
513+
Publisher<ByteBuffer> stream1 = s -> {
514+
try {
515+
s.onSubscribe(mock(Subscription.class));
516+
s.onNext(ByteBuffer.wrap(response1.getBytes()));
517+
s.onComplete();
518+
} catch (Throwable e) {
519+
s.onError(e);
520+
}
521+
};
522+
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
523+
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
524+
verify(actionListener, times(1)).onFailure(captor.capture());
525+
assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling.");
526+
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
527+
}
528+
529+
@Test
530+
public void test_onComplete_throttle_exceptionSecond() {
531+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
532+
String response1 = "{\n"
533+
+ " \"embedding\": [\n"
534+
+ " 0.46484375,\n"
535+
+ " -0.017822266,\n"
536+
+ " 0.17382812,\n"
537+
+ " 0.10595703,\n"
538+
+ " 0.875,\n"
539+
+ " 0.19140625,\n"
540+
+ " -0.36914062,\n"
541+
+ " -0.0011978149\n"
542+
+ " ]\n"
543+
+ "}";
544+
String response2 = "{\"message\": null}";
545+
CountDownLatch count = new CountDownLatch(2);
546+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
547+
new ExecutionContext(0, count, exceptionHolder),
548+
actionListener,
549+
parameters,
550+
tensorOutputs,
551+
connector,
552+
scriptService,
553+
null
554+
);
555+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
556+
new ExecutionContext(1, count, exceptionHolder),
557+
actionListener,
558+
parameters,
559+
tensorOutputs,
560+
connector,
561+
scriptService,
562+
null
563+
);
564+
SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
565+
when(sdkHttpResponse1.statusCode()).thenReturn(200);
566+
mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
567+
Publisher<ByteBuffer> stream1 = s -> {
568+
try {
569+
s.onSubscribe(mock(Subscription.class));
570+
s.onNext(ByteBuffer.wrap(response1.getBytes()));
571+
s.onComplete();
572+
} catch (Throwable e) {
573+
s.onError(e);
574+
}
575+
};
576+
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
577+
578+
SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
579+
when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
580+
when(sdkHttpResponse2.headers()).thenReturn(headersMap);
581+
mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
582+
Publisher<ByteBuffer> stream2 = s -> {
583+
try {
584+
s.onSubscribe(mock(Subscription.class));
585+
s.onNext(ByteBuffer.wrap(response2.getBytes()));
586+
s.onComplete();
587+
} catch (Throwable e) {
588+
s.onError(e);
589+
}
590+
};
591+
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
592+
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
593+
verify(actionListener, times(1)).onFailure(captor.capture());
594+
assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling.");
595+
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
596+
}
597+
422598
}

0 commit comments

Comments
 (0)