Skip to content

Commit 4a726f2

Browse files
committed
more UTs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 35d2ea4 commit 4a726f2

File tree

1 file changed

+143
-1
lines changed

1 file changed

+143
-1
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java

+143-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ public class MLSdkAsyncHttpResponseHandlerTest {
5353

5454
private Connector noProcessFunctionConnector;
5555

56+
private Map<String, List<String>> headersMap;
57+
5658
@Mock
5759
private SdkHttpFullResponse sdkHttpResponse;
5860
@Mock
@@ -106,6 +108,7 @@ public void setup() {
106108
null
107109
);
108110
responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber();
111+
headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!"));
109112
}
110113

111114
@Test
@@ -434,7 +437,6 @@ public void test_onComplete_throttle_error_headers() {
434437
String error = "{\"message\": null}";
435438
SdkHttpResponse response = mock(SdkHttpFullResponse.class);
436439
when(response.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
437-
Map<String, List<String>> headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!"));
438440
when(response.headers()).thenReturn(headersMap);
439441
mlSdkAsyncHttpResponseHandler.onHeaders(response);
440442
Publisher<ByteBuffer> stream = s -> {
@@ -453,4 +455,144 @@ public void test_onComplete_throttle_error_headers() {
453455
System.out.println(captor.getValue().getMessage());
454456
assert captor.getValue().getMessage().contains(REMOTE_SERVICE_ERROR);
455457
}
458+
459+
@Test
460+
public void test_onComplete_throttle_exceptionFirst() {
461+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
462+
String response1 = "{\n"
463+
+ " \"embedding\": [\n"
464+
+ " 0.46484375,\n"
465+
+ " -0.017822266,\n"
466+
+ " 0.17382812,\n"
467+
+ " 0.10595703,\n"
468+
+ " 0.875,\n"
469+
+ " 0.19140625,\n"
470+
+ " -0.36914062,\n"
471+
+ " -0.0011978149\n"
472+
+ " ]\n"
473+
+ "}";
474+
String response2 = "{\"message\": null}";
475+
CountDownLatch count = new CountDownLatch(2);
476+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
477+
new ExecutionContext(0, count, exceptionHolder),
478+
actionListener,
479+
parameters,
480+
tensorOutputs,
481+
connector,
482+
scriptService,
483+
null
484+
);
485+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
486+
new ExecutionContext(1, count, exceptionHolder),
487+
actionListener,
488+
parameters,
489+
tensorOutputs,
490+
connector,
491+
scriptService,
492+
null
493+
);
494+
495+
SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
496+
when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
497+
when(sdkHttpResponse2.headers()).thenReturn(headersMap);
498+
mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
499+
Publisher<ByteBuffer> stream2 = s -> {
500+
try {
501+
s.onSubscribe(mock(Subscription.class));
502+
s.onNext(ByteBuffer.wrap(response2.getBytes()));
503+
s.onComplete();
504+
} catch (Throwable e) {
505+
s.onError(e);
506+
}
507+
};
508+
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
509+
510+
SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
511+
when(sdkHttpResponse1.statusCode()).thenReturn(200);
512+
mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
513+
Publisher<ByteBuffer> stream1 = s -> {
514+
try {
515+
s.onSubscribe(mock(Subscription.class));
516+
s.onNext(ByteBuffer.wrap(response1.getBytes()));
517+
s.onComplete();
518+
} catch (Throwable e) {
519+
s.onError(e);
520+
}
521+
};
522+
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
523+
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
524+
verify(actionListener, times(1)).onFailure(captor.capture());
525+
assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling.");
526+
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
527+
}
528+
529+
@Test
530+
public void test_onComplete_throttle_exceptionSecond() {
531+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
532+
String response1 = "{\n"
533+
+ " \"embedding\": [\n"
534+
+ " 0.46484375,\n"
535+
+ " -0.017822266,\n"
536+
+ " 0.17382812,\n"
537+
+ " 0.10595703,\n"
538+
+ " 0.875,\n"
539+
+ " 0.19140625,\n"
540+
+ " -0.36914062,\n"
541+
+ " -0.0011978149\n"
542+
+ " ]\n"
543+
+ "}";
544+
String response2 = "{\"message\": null}";
545+
CountDownLatch count = new CountDownLatch(2);
546+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
547+
new ExecutionContext(0, count, exceptionHolder),
548+
actionListener,
549+
parameters,
550+
tensorOutputs,
551+
connector,
552+
scriptService,
553+
null
554+
);
555+
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
556+
new ExecutionContext(1, count, exceptionHolder),
557+
actionListener,
558+
parameters,
559+
tensorOutputs,
560+
connector,
561+
scriptService,
562+
null
563+
);
564+
SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
565+
when(sdkHttpResponse1.statusCode()).thenReturn(200);
566+
mlSdkAsyncHttpResponseHandler1.onHeaders(sdkHttpResponse1);
567+
Publisher<ByteBuffer> stream1 = s -> {
568+
try {
569+
s.onSubscribe(mock(Subscription.class));
570+
s.onNext(ByteBuffer.wrap(response1.getBytes()));
571+
s.onComplete();
572+
} catch (Throwable e) {
573+
s.onError(e);
574+
}
575+
};
576+
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
577+
578+
SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
579+
when(sdkHttpResponse2.statusCode()).thenReturn(HttpStatusCode.BAD_REQUEST);
580+
when(sdkHttpResponse2.headers()).thenReturn(headersMap);
581+
mlSdkAsyncHttpResponseHandler2.onHeaders(sdkHttpResponse2);
582+
Publisher<ByteBuffer> stream2 = s -> {
583+
try {
584+
s.onSubscribe(mock(Subscription.class));
585+
s.onNext(ByteBuffer.wrap(response2.getBytes()));
586+
s.onComplete();
587+
} catch (Throwable e) {
588+
s.onError(e);
589+
}
590+
};
591+
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
592+
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
593+
verify(actionListener, times(1)).onFailure(captor.capture());
594+
assert captor.getValue().getMessage().equals("Error from remote service: The request was denied due to remote server throttling.");
595+
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
596+
}
597+
456598
}

0 commit comments

Comments
 (0)