11
11
import static org .opensearch .ml .common .CommonValue .ML_TASK_INDEX ;
12
12
import static org .opensearch .ml .common .MLTask .REMOTE_JOB_FIELD ;
13
13
import static org .opensearch .ml .common .MLTask .STATE_FIELD ;
14
- import static org .opensearch .ml .common .MLTaskState .*;
15
- import static org .opensearch .ml .common .connector .AbstractConnector .*;
14
+ import static org .opensearch .ml .common .MLTaskState .CANCELLED ;
15
+ import static org .opensearch .ml .common .MLTaskState .CANCELLING ;
16
+ import static org .opensearch .ml .common .MLTaskState .COMPLETED ;
17
+ import static org .opensearch .ml .common .MLTaskState .EXPIRED ;
18
+ import static org .opensearch .ml .common .MLTaskState .FAILED ;
19
+ import static org .opensearch .ml .common .MLTaskState .UNREACHABLE ;
20
+ import static org .opensearch .ml .common .connector .AbstractConnector .ACCESS_KEY_FIELD ;
21
+ import static org .opensearch .ml .common .connector .AbstractConnector .SESSION_TOKEN_FIELD ;
22
+ import static org .opensearch .ml .common .connector .AbstractConnector .SECRET_KEY_FIELD ;
16
23
import static org .opensearch .ml .common .connector .ConnectorAction .ActionType .BATCH_PREDICT_STATUS ;
17
- import static org .opensearch .ml .settings .MLCommonsSettings .*;
24
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX ;
25
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX ;
26
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX ;
27
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX ;
28
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX ;
29
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_FIELD ;
18
30
import static org .opensearch .ml .utils .MLExceptionUtils .BATCH_INFERENCE_DISABLED_ERR_MSG ;
19
31
import static org .opensearch .ml .utils .MLExceptionUtils .logException ;
20
32
27
39
import java .util .regex .Matcher ;
28
40
import java .util .regex .Pattern ;
29
41
42
+ import com .google .common .annotations .VisibleForTesting ;
30
43
import org .opensearch .ExceptionsHelper ;
31
44
import org .opensearch .OpenSearchException ;
32
45
import org .opensearch .OpenSearchStatusException ;
@@ -440,6 +453,7 @@ private void executeConnector(
440
453
}
441
454
442
455
decryptedCredential = connector .getDecryptedCredential ();
456
+
443
457
if (decryptedCredential == null || decryptedCredential .isEmpty ()) {
444
458
decryptedCredential = mlEngine .getConnectorCredential (connector );
445
459
}
@@ -449,7 +463,7 @@ private void executeConnector(
449
463
connectorExecutor .setClient (client );
450
464
connectorExecutor .setXContentRegistry (xContentRegistry );
451
465
connectorExecutor .executeAction (BATCH_PREDICT_STATUS .name (), mlInput , ActionListener .wrap (taskResponse -> {
452
- processTaskResponse (mlTask , taskId , isUserInitiatedGetTaskRequest , taskResponse , remoteJob , actionListener );
466
+ processTaskResponse (mlTask , taskId , isUserInitiatedGetTaskRequest , taskResponse , remoteJob , actionListener );
453
467
}, e -> {
454
468
// When the request to remote service fails, we will retry the request for next 10 minutes (10 runs).
455
469
// If it fails even then, we mark it as unreachable in task index and send message to DLQ
@@ -466,7 +480,7 @@ private void executeConnector(
466
480
updatedTask .put (STATE_FIELD , UNREACHABLE );
467
481
mlTask .setState (UNREACHABLE );
468
482
mlTask .setError (e .getMessage ());
469
- updateDLQ (mlTask );
483
+ updateDLQ (mlTask , decryptedCredential );
470
484
}
471
485
updatedTask .put ("remote_job" , remoteJob );
472
486
mlTaskManager .updateMLTaskDirectly (taskId , updatedTask );
@@ -504,7 +518,7 @@ protected void processTaskResponse(
504
518
505
519
mlTaskManager .updateMLTaskDirectly (taskId , updatedTask , ActionListener .wrap (response -> {
506
520
if (mlTask .getState ().equals (FAILED ) && !isUserInitiatedGetTaskRequest ) {
507
- updateDLQ (mlTask );
521
+ updateDLQ (mlTask , decryptedCredential );
508
522
}
509
523
actionListener .onResponse (MLTaskGetResponse .builder ().mlTask (mlTask ).build ());
510
524
}, e -> {
@@ -528,16 +542,17 @@ protected void processTaskResponse(
528
542
}
529
543
}
530
544
531
- protected void updateDLQ (MLTask mlTask ) {
545
+ @ VisibleForTesting
546
+ protected void updateDLQ (MLTask mlTask , Map <String , String > decryptedCredential ) {
532
547
Map <String , Object > remoteJob = mlTask .getRemoteJob ();
533
548
Map <String , String > dlq = (Map <String , String >) remoteJob .get ("dlq" );
534
549
if (dlq != null && !dlq .isEmpty ()) {
535
550
String taskId = mlTask .getTaskId ();
536
551
try {
537
552
Map <String , Object > remoteJobDetails = mlTask .getRemoteJob ();
538
- String accessKey = this . decryptedCredential .get (ACCESS_KEY_FIELD );
539
- String secretKey = this . decryptedCredential .get (SECRET_KEY_FIELD );
540
- String sessionToken = this . decryptedCredential .get (SESSION_TOKEN_FIELD );
553
+ String accessKey = decryptedCredential .get (ACCESS_KEY_FIELD );
554
+ String secretKey = decryptedCredential .get (SECRET_KEY_FIELD );
555
+ String sessionToken = decryptedCredential .get (SESSION_TOKEN_FIELD );
541
556
542
557
String bucketName = dlq .get ("bucket" );
543
558
String region = dlq .get ("region" );
0 commit comments