@@ -53,6 +53,8 @@ public class MLSdkAsyncHttpResponseHandlerTest {
53
53
54
54
private Connector noProcessFunctionConnector ;
55
55
56
+ private Map <String , List <String >> headersMap ;
57
+
56
58
@ Mock
57
59
private SdkHttpFullResponse sdkHttpResponse ;
58
60
@ Mock
@@ -106,6 +108,7 @@ public void setup() {
106
108
null
107
109
);
108
110
responseSubscriber = mlSdkAsyncHttpResponseHandler .new MLResponseSubscriber ();
111
+ headersMap = Map .of (AMZ_ERROR_HEADER , Arrays .asList ("ThrottlingException:request throttled!" ));
109
112
}
110
113
111
114
@ Test
@@ -434,7 +437,6 @@ public void test_onComplete_throttle_error_headers() {
434
437
String error = "{\" message\" : null}" ;
435
438
SdkHttpResponse response = mock (SdkHttpFullResponse .class );
436
439
when (response .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
437
- Map <String , List <String >> headersMap = Map .of (AMZ_ERROR_HEADER , Arrays .asList ("ThrottlingException:request throttled!" ));
438
440
when (response .headers ()).thenReturn (headersMap );
439
441
mlSdkAsyncHttpResponseHandler .onHeaders (response );
440
442
Publisher <ByteBuffer > stream = s -> {
@@ -453,4 +455,144 @@ public void test_onComplete_throttle_error_headers() {
453
455
System .out .println (captor .getValue ().getMessage ());
454
456
assert captor .getValue ().getMessage ().contains (REMOTE_SERVICE_ERROR );
455
457
}
458
+
459
+ @ Test
460
+ public void test_onComplete_throttle_exceptionFirst () {
461
+ AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
462
+ String response1 = "{\n "
463
+ + " \" embedding\" : [\n "
464
+ + " 0.46484375,\n "
465
+ + " -0.017822266,\n "
466
+ + " 0.17382812,\n "
467
+ + " 0.10595703,\n "
468
+ + " 0.875,\n "
469
+ + " 0.19140625,\n "
470
+ + " -0.36914062,\n "
471
+ + " -0.0011978149\n "
472
+ + " ]\n "
473
+ + "}" ;
474
+ String response2 = "{\" message\" : null}" ;
475
+ CountDownLatch count = new CountDownLatch (2 );
476
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler (
477
+ new ExecutionContext (0 , count , exceptionHolder ),
478
+ actionListener ,
479
+ parameters ,
480
+ tensorOutputs ,
481
+ connector ,
482
+ scriptService ,
483
+ null
484
+ );
485
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler (
486
+ new ExecutionContext (1 , count , exceptionHolder ),
487
+ actionListener ,
488
+ parameters ,
489
+ tensorOutputs ,
490
+ connector ,
491
+ scriptService ,
492
+ null
493
+ );
494
+
495
+ SdkHttpFullResponse sdkHttpResponse2 = mock (SdkHttpFullResponse .class );
496
+ when (sdkHttpResponse2 .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
497
+ when (sdkHttpResponse2 .headers ()).thenReturn (headersMap );
498
+ mlSdkAsyncHttpResponseHandler2 .onHeaders (sdkHttpResponse2 );
499
+ Publisher <ByteBuffer > stream2 = s -> {
500
+ try {
501
+ s .onSubscribe (mock (Subscription .class ));
502
+ s .onNext (ByteBuffer .wrap (response2 .getBytes ()));
503
+ s .onComplete ();
504
+ } catch (Throwable e ) {
505
+ s .onError (e );
506
+ }
507
+ };
508
+ mlSdkAsyncHttpResponseHandler2 .onStream (stream2 );
509
+
510
+ SdkHttpFullResponse sdkHttpResponse1 = mock (SdkHttpFullResponse .class );
511
+ when (sdkHttpResponse1 .statusCode ()).thenReturn (200 );
512
+ mlSdkAsyncHttpResponseHandler1 .onHeaders (sdkHttpResponse1 );
513
+ Publisher <ByteBuffer > stream1 = s -> {
514
+ try {
515
+ s .onSubscribe (mock (Subscription .class ));
516
+ s .onNext (ByteBuffer .wrap (response1 .getBytes ()));
517
+ s .onComplete ();
518
+ } catch (Throwable e ) {
519
+ s .onError (e );
520
+ }
521
+ };
522
+ mlSdkAsyncHttpResponseHandler1 .onStream (stream1 );
523
+ ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
524
+ verify (actionListener , times (1 )).onFailure (captor .capture ());
525
+ assert captor .getValue ().getMessage ().equals ("Error from remote service: The request was denied due to remote server throttling." );
526
+ assert captor .getValue ().status ().getStatus () == HttpStatusCode .BAD_REQUEST ;
527
+ }
528
+
529
+ @ Test
530
+ public void test_onComplete_throttle_exceptionSecond () {
531
+ AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
532
+ String response1 = "{\n "
533
+ + " \" embedding\" : [\n "
534
+ + " 0.46484375,\n "
535
+ + " -0.017822266,\n "
536
+ + " 0.17382812,\n "
537
+ + " 0.10595703,\n "
538
+ + " 0.875,\n "
539
+ + " 0.19140625,\n "
540
+ + " -0.36914062,\n "
541
+ + " -0.0011978149\n "
542
+ + " ]\n "
543
+ + "}" ;
544
+ String response2 = "{\" message\" : null}" ;
545
+ CountDownLatch count = new CountDownLatch (2 );
546
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler (
547
+ new ExecutionContext (0 , count , exceptionHolder ),
548
+ actionListener ,
549
+ parameters ,
550
+ tensorOutputs ,
551
+ connector ,
552
+ scriptService ,
553
+ null
554
+ );
555
+ MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler (
556
+ new ExecutionContext (1 , count , exceptionHolder ),
557
+ actionListener ,
558
+ parameters ,
559
+ tensorOutputs ,
560
+ connector ,
561
+ scriptService ,
562
+ null
563
+ );
564
+ SdkHttpFullResponse sdkHttpResponse1 = mock (SdkHttpFullResponse .class );
565
+ when (sdkHttpResponse1 .statusCode ()).thenReturn (200 );
566
+ mlSdkAsyncHttpResponseHandler1 .onHeaders (sdkHttpResponse1 );
567
+ Publisher <ByteBuffer > stream1 = s -> {
568
+ try {
569
+ s .onSubscribe (mock (Subscription .class ));
570
+ s .onNext (ByteBuffer .wrap (response1 .getBytes ()));
571
+ s .onComplete ();
572
+ } catch (Throwable e ) {
573
+ s .onError (e );
574
+ }
575
+ };
576
+ mlSdkAsyncHttpResponseHandler1 .onStream (stream1 );
577
+
578
+ SdkHttpFullResponse sdkHttpResponse2 = mock (SdkHttpFullResponse .class );
579
+ when (sdkHttpResponse2 .statusCode ()).thenReturn (HttpStatusCode .BAD_REQUEST );
580
+ when (sdkHttpResponse2 .headers ()).thenReturn (headersMap );
581
+ mlSdkAsyncHttpResponseHandler2 .onHeaders (sdkHttpResponse2 );
582
+ Publisher <ByteBuffer > stream2 = s -> {
583
+ try {
584
+ s .onSubscribe (mock (Subscription .class ));
585
+ s .onNext (ByteBuffer .wrap (response2 .getBytes ()));
586
+ s .onComplete ();
587
+ } catch (Throwable e ) {
588
+ s .onError (e );
589
+ }
590
+ };
591
+ mlSdkAsyncHttpResponseHandler2 .onStream (stream2 );
592
+ ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
593
+ verify (actionListener , times (1 )).onFailure (captor .capture ());
594
+ assert captor .getValue ().getMessage ().equals ("Error from remote service: The request was denied due to remote server throttling." );
595
+ assert captor .getValue ().status ().getStatus () == HttpStatusCode .BAD_REQUEST ;
596
+ }
597
+
456
598
}
0 commit comments