|
102 | 102 | import com.google.common.annotations.VisibleForTesting;
|
103 | 103 |
|
104 | 104 | import lombok.extern.log4j.Log4j2;
|
| 105 | +import software.amazon.awssdk.services.s3.S3Client; |
| 106 | +import software.amazon.awssdk.services.s3.model.S3Exception; |
105 | 107 |
|
106 | 108 | @Log4j2
|
107 | 109 | public class GetTaskTransportAction extends HandledTransportAction<ActionRequest, MLTaskGetResponse> {
|
@@ -129,7 +131,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
|
129 | 131 | volatile Pattern remoteJobFailedStatusRegexPattern;
|
130 | 132 | private final MLEngine mlEngine;
|
131 | 133 |
|
132 |
| - private Map<String, String> decryptedCredential; |
| 134 | + // private Map<String, String> decryptedCredential; |
133 | 135 |
|
134 | 136 | @Inject
|
135 | 137 | public GetTaskTransportAction(
|
@@ -456,19 +458,25 @@ private void executeConnector(
|
456 | 458 | connector.addAction(connectorAction);
|
457 | 459 | }
|
458 | 460 |
|
459 |
| - decryptedCredential = connector.getDecryptedCredential(); |
460 |
| - |
461 |
| - if (decryptedCredential == null || decryptedCredential.isEmpty()) { |
462 |
| - decryptedCredential = mlEngine.getConnectorCredential(connector); |
463 |
| - } |
464 |
| - |
| 461 | + final Map<String, String> decryptedCredential = connector.getDecryptedCredential() != null |
| 462 | + && !connector.getDecryptedCredential().isEmpty() |
| 463 | + ? mlEngine.getConnectorCredential(connector) |
| 464 | + : connector.getDecryptedCredential(); |
465 | 465 | RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
|
466 | 466 | connectorExecutor.setScriptService(scriptService);
|
467 | 467 | connectorExecutor.setClusterService(clusterService);
|
468 | 468 | connectorExecutor.setClient(client);
|
469 | 469 | connectorExecutor.setXContentRegistry(xContentRegistry);
|
470 | 470 | connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
|
471 |
| - processTaskResponse(mlTask, taskId, isUserInitiatedGetTaskRequest, taskResponse, remoteJob, actionListener); |
| 471 | + processTaskResponse( |
| 472 | + mlTask, |
| 473 | + taskId, |
| 474 | + isUserInitiatedGetTaskRequest, |
| 475 | + taskResponse, |
| 476 | + remoteJob, |
| 477 | + decryptedCredential, |
| 478 | + actionListener |
| 479 | + ); |
472 | 480 | }, e -> {
|
473 | 481 | // When the request to remote service fails, we will retry the request for next 10 minutes (10 runs).
|
474 | 482 | // If it fails even then, we mark it as unreachable in task index and send message to DLQ
|
@@ -500,6 +508,7 @@ protected void processTaskResponse(
|
500 | 508 | Boolean isUserInitiatedGetTaskRequest,
|
501 | 509 | MLTaskResponse taskResponse,
|
502 | 510 | Map<String, Object> remoteJob,
|
| 511 | + Map<String, String> decryptedCredential, |
503 | 512 | ActionListener<MLTaskGetResponse> actionListener
|
504 | 513 | ) {
|
505 | 514 | try {
|
@@ -566,15 +575,18 @@ protected void updateDLQ(MLTask mlTask, Map<String, String> decryptedCredential)
|
566 | 575 | log.error("Failed to get the bucket name and region from batch predict request");
|
567 | 576 | }
|
568 | 577 | remoteJobDetails.remove("dlq");
|
569 |
| - |
570 |
| - String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name")); |
571 |
| - String s3ObjectKey = "BatchJobFailure_" + jobName; |
572 |
| - String content = mlTask.getState().equals(UNREACHABLE) |
573 |
| - ? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError()) |
574 |
| - : remoteJobDetails.toString(); |
575 |
| - |
576 |
| - S3Utils.putObject(accessKey, secretKey, sessionToken, region, bucketName, s3ObjectKey, content); |
577 |
| - log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now()); |
| 578 | + try (S3Client s3Client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region)) { |
| 579 | + String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name")); |
| 580 | + String s3ObjectKey = "BatchJobFailure_" + jobName; |
| 581 | + String content = mlTask.getState().equals(UNREACHABLE) |
| 582 | + ? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError()) |
| 583 | + : remoteJobDetails.toString(); |
| 584 | + |
| 585 | + S3Utils.putObject(s3Client, bucketName, s3ObjectKey, content); |
| 586 | + log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now()); |
| 587 | + } |
| 588 | + } catch (S3Exception e) { |
| 589 | + log.error("Failed to update task status for task: {}. S3 Exception: {}", taskId, e.awsErrorDetails().errorMessage()); |
578 | 590 | } catch (Exception e) {
|
579 | 591 | log.error("Failed to update task status for task: " + taskId, e);
|
580 | 592 | }
|
|
0 commit comments