11
11
import static org .mockito .Mockito .times ;
12
12
import static org .mockito .Mockito .verify ;
13
13
import static org .mockito .Mockito .when ;
14
+ import static org .opensearch .ml .common .CommonValue .REMOTE_SERVICE_ERROR ;
15
+ import static org .opensearch .ml .engine .algorithms .remote .MLSdkAsyncHttpResponseHandler .AMZ_ERROR_HEADER ;
14
16
15
17
import java .nio .ByteBuffer ;
16
18
import java .util .Arrays ;
@@ -51,6 +53,8 @@ public class MLSdkAsyncHttpResponseHandlerTest {
51
53
52
54
private Connector noProcessFunctionConnector ;
53
55
56
+ private Map <String , List <String >> headersMap ;
57
+
54
58
@ Mock
55
59
private SdkHttpFullResponse sdkHttpResponse ;
56
60
@ Mock
@@ -104,6 +108,7 @@ public void setup() {
104
108
null
105
109
);
106
110
responseSubscriber = mlSdkAsyncHttpResponseHandler .new MLResponseSubscriber ();
111
+ headersMap = Map .of (AMZ_ERROR_HEADER , Arrays .asList ("ThrottlingException:request throttled!" ));
107
112
}
108
113
109
114
@ Test
@@ -112,6 +117,13 @@ public void test_OnHeaders() {
112
117
assert mlSdkAsyncHttpResponseHandler .getStatusCode () == 200 ;
113
118
}
114
119
120
+ @ Test
121
+ public void test_OnHeaders_withError () {
122
+ when (sdkHttpResponse .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
123
+ mlSdkAsyncHttpResponseHandler .onHeaders (sdkHttpResponse );
124
+ assert mlSdkAsyncHttpResponseHandler .getStatusCode () == 400 ;
125
+ }
126
+
115
127
@ Test
116
128
public void test_OnStream_with_postProcessFunction_bedRock () {
117
129
String response = "{\n "
@@ -419,4 +431,168 @@ public void test_onComplete_error_http_status() {
419
431
System .out .println (captor .getValue ().getMessage ());
420
432
assert captor .getValue ().getMessage ().contains ("runtime error" );
421
433
}
434
+
435
+ @ Test
436
+ public void test_onComplete_throttle_error_headers () {
437
+ String error = "{\" message\" : null}" ;
438
+ SdkHttpResponse response = mock (SdkHttpFullResponse .class );
439
+ when (response .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
440
+ when (response .headers ()).thenReturn (headersMap );
441
+ mlSdkAsyncHttpResponseHandler .onHeaders (response );
442
+ Publisher <ByteBuffer > stream = s -> {
443
+ try {
444
+ s .onSubscribe (mock (Subscription .class ));
445
+ s .onNext (ByteBuffer .wrap (error .getBytes ()));
446
+ s .onComplete ();
447
+ } catch (Throwable e ) {
448
+ s .onError (e );
449
+ }
450
+ };
451
+ mlSdkAsyncHttpResponseHandler .onStream (stream );
452
+ ArgumentCaptor <Exception > captor = ArgumentCaptor .forClass (Exception .class );
453
+ verify (actionListener , times (1 )).onFailure (captor .capture ());
454
+ assert captor .getValue () instanceof OpenSearchStatusException ;
455
+ System .out .println (captor .getValue ().getMessage ());
456
+ assert captor .getValue ().getMessage ().contains (REMOTE_SERVICE_ERROR );
457
+ }
458
+
459
+ @ Test
460
+ public void test_onComplete_throttle_exceptionFirst () {
461
+ AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
462
+ String response1 = "{\n "
463
+ + " \" embedding\" : [\n "
464
+ + " 0.46484375,\n "
465
+ + " -0.017822266,\n "
466
+ + " 0.17382812,\n "
467
+ + " 0.10595703,\n "
468
+ + " 0.875,\n "
469
+ + " 0.19140625,\n "
470
+ + " -0.36914062,\n "
471
+ + " -0.0011978149\n "
472
+ + " ]\n "
473
+ + "}" ;
474
+ String response2 = "{\" message\" : null}" ;
475
+ CountDownLatch count = new CountDownLatch (2 );
476
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler (
477
+ new ExecutionContext (0 , count , exceptionHolder ),
478
+ actionListener ,
479
+ parameters ,
480
+ tensorOutputs ,
481
+ connector ,
482
+ scriptService ,
483
+ null
484
+ );
485
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler (
486
+ new ExecutionContext (1 , count , exceptionHolder ),
487
+ actionListener ,
488
+ parameters ,
489
+ tensorOutputs ,
490
+ connector ,
491
+ scriptService ,
492
+ null
493
+ );
494
+
495
+ SdkHttpFullResponse sdkHttpResponse2 = mock (SdkHttpFullResponse .class );
496
+ when (sdkHttpResponse2 .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
497
+ when (sdkHttpResponse2 .headers ()).thenReturn (headersMap );
498
+ mlSdkAsyncHttpResponseHandler2 .onHeaders (sdkHttpResponse2 );
499
+ Publisher <ByteBuffer > stream2 = s -> {
500
+ try {
501
+ s .onSubscribe (mock (Subscription .class ));
502
+ s .onNext (ByteBuffer .wrap (response2 .getBytes ()));
503
+ s .onComplete ();
504
+ } catch (Throwable e ) {
505
+ s .onError (e );
506
+ }
507
+ };
508
+ mlSdkAsyncHttpResponseHandler2 .onStream (stream2 );
509
+
510
+ SdkHttpFullResponse sdkHttpResponse1 = mock (SdkHttpFullResponse .class );
511
+ when (sdkHttpResponse1 .statusCode ()).thenReturn (200 );
512
+ mlSdkAsyncHttpResponseHandler1 .onHeaders (sdkHttpResponse1 );
513
+ Publisher <ByteBuffer > stream1 = s -> {
514
+ try {
515
+ s .onSubscribe (mock (Subscription .class ));
516
+ s .onNext (ByteBuffer .wrap (response1 .getBytes ()));
517
+ s .onComplete ();
518
+ } catch (Throwable e ) {
519
+ s .onError (e );
520
+ }
521
+ };
522
+ mlSdkAsyncHttpResponseHandler1 .onStream (stream1 );
523
+ ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
524
+ verify (actionListener , times (1 )).onFailure (captor .capture ());
525
+ assert captor .getValue ().getMessage ().equals ("Error from remote service: The request was denied due to remote server throttling." );
526
+ assert captor .getValue ().status ().getStatus () == HttpStatusCode .BAD_REQUEST ;
527
+ }
528
+
529
+ @ Test
530
+ public void test_onComplete_throttle_exceptionSecond () {
531
+ AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
532
+ String response1 = "{\n "
533
+ + " \" embedding\" : [\n "
534
+ + " 0.46484375,\n "
535
+ + " -0.017822266,\n "
536
+ + " 0.17382812,\n "
537
+ + " 0.10595703,\n "
538
+ + " 0.875,\n "
539
+ + " 0.19140625,\n "
540
+ + " -0.36914062,\n "
541
+ + " -0.0011978149\n "
542
+ + " ]\n "
543
+ + "}" ;
544
+ String response2 = "{\" message\" : null}" ;
545
+ CountDownLatch count = new CountDownLatch (2 );
546
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler (
547
+ new ExecutionContext (0 , count , exceptionHolder ),
548
+ actionListener ,
549
+ parameters ,
550
+ tensorOutputs ,
551
+ connector ,
552
+ scriptService ,
553
+ null
554
+ );
555
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler (
556
+ new ExecutionContext (1 , count , exceptionHolder ),
557
+ actionListener ,
558
+ parameters ,
559
+ tensorOutputs ,
560
+ connector ,
561
+ scriptService ,
562
+ null
563
+ );
564
+ SdkHttpFullResponse sdkHttpResponse1 = mock (SdkHttpFullResponse .class );
565
+ when (sdkHttpResponse1 .statusCode ()).thenReturn (200 );
566
+ mlSdkAsyncHttpResponseHandler1 .onHeaders (sdkHttpResponse1 );
567
+ Publisher <ByteBuffer > stream1 = s -> {
568
+ try {
569
+ s .onSubscribe (mock (Subscription .class ));
570
+ s .onNext (ByteBuffer .wrap (response1 .getBytes ()));
571
+ s .onComplete ();
572
+ } catch (Throwable e ) {
573
+ s .onError (e );
574
+ }
575
+ };
576
+ mlSdkAsyncHttpResponseHandler1 .onStream (stream1 );
577
+
578
+ SdkHttpFullResponse sdkHttpResponse2 = mock (SdkHttpFullResponse .class );
579
+ when (sdkHttpResponse2 .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
580
+ when (sdkHttpResponse2 .headers ()).thenReturn (headersMap );
581
+ mlSdkAsyncHttpResponseHandler2 .onHeaders (sdkHttpResponse2 );
582
+ Publisher <ByteBuffer > stream2 = s -> {
583
+ try {
584
+ s .onSubscribe (mock (Subscription .class ));
585
+ s .onNext (ByteBuffer .wrap (response2 .getBytes ()));
586
+ s .onComplete ();
587
+ } catch (Throwable e ) {
588
+ s .onError (e );
589
+ }
590
+ };
591
+ mlSdkAsyncHttpResponseHandler2 .onStream (stream2 );
592
+ ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
593
+ verify (actionListener , times (1 )).onFailure (captor .capture ());
594
+ assert captor .getValue ().getMessage ().equals ("Error from remote service: The request was denied due to remote server throttling." );
595
+ assert captor .getValue ().status ().getStatus () == HttpStatusCode .BAD_REQUEST ;
596
+ }
597
+
422
598
}
0 commit comments