Skip to content

Commit b553d8e

Browse files
committed
renamed files and added more tests
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent d0a9973 commit b553d8e

15 files changed

+483
-125
lines changed

ml-algorithms/build.gradle

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ dependencies {
7373
}
7474
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
7575

76-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
77-
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
78-
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
76+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
77+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
78+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
7979

8080
implementation 'com.jayway.jsonpath:json-path:2.9.0'
8181
implementation group: 'org.json', name: 'json', version: '20231013'

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java

+2-12
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,8 @@ public static S3Client initS3Client(String accessKey, String secretKey, String s
4444
}
4545
}
4646

47-
public static void putObject(
48-
String accessKey,
49-
String secretKey,
50-
String sessionToken,
51-
String region,
52-
String bucketName,
53-
String key,
54-
String content
55-
) {
56-
try (S3Client s3Client = initS3Client(accessKey, secretKey, sessionToken, region)) {
47+
public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
48+
try {
5749
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
5850
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();
5951

@@ -63,8 +55,6 @@ public static void putObject(
6355
});
6456
} catch (PrivilegedActionException e) {
6557
throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
66-
} catch (Exception e) {
67-
log.error("Unexpected error during S3 upload", e);
6858
}
6959
}
7060

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.utils;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
import static org.mockito.ArgumentMatchers.any;
11+
import static org.mockito.Mockito.*;
12+
13+
import org.junit.Before;
14+
import org.junit.Rule;
15+
import org.junit.Test;
16+
import org.junit.rules.ExpectedException;
17+
import org.mockito.Mock;
18+
import org.mockito.MockitoAnnotations;
19+
20+
import software.amazon.awssdk.core.sync.RequestBody;
21+
import software.amazon.awssdk.services.s3.S3Client;
22+
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
23+
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
24+
25+
public class S3UtilsTest {
26+
27+
@Mock
28+
private S3Client s3Client;
29+
30+
@Rule
31+
public ExpectedException exceptionRule = ExpectedException.none();
32+
33+
@Before
34+
public void setUp() {
35+
MockitoAnnotations.openMocks(this);
36+
}
37+
38+
@Test
39+
public void testInitS3Client() {
40+
String accessKey = "test-access-key";
41+
String secretKey = "test-secret-key";
42+
String sessionToken = "test-session-token";
43+
String region = "us-west-2";
44+
try (S3Client client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region)) { // close the real client so the test does not leak it
45+
assertNotNull(client);
46+
}
47+
}
48+
49+
@Test
50+
public void testInitS3ClientWithoutSessionToken() {
51+
String accessKey = "test-access-key";
52+
String secretKey = "test-secret-key";
53+
String region = "us-west-2";
54+
try (S3Client client = S3Utils.initS3Client(accessKey, secretKey, null, region)) { // null session token; close the real client afterwards
55+
assertNotNull(client);
56+
}
57+
}
58+
59+
@Test
60+
public void testPutObject() {
61+
String bucketName = "test-bucket";
62+
String key = "test-key";
63+
String content = "test-content";
64+
65+
when(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))).thenReturn(PutObjectResponse.builder().build());
66+
67+
S3Utils.putObject(s3Client, bucketName, key, content);
68+
69+
verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class));
70+
}
71+
72+
@Test
73+
public void testGetS3BucketName() {
74+
String s3Uri = "s3://test-bucket/path/to/file";
75+
assertEquals("test-bucket", S3Utils.getS3BucketName(s3Uri));
76+
77+
s3Uri = "s3://test-bucket";
78+
assertEquals("test-bucket", S3Utils.getS3BucketName(s3Uri));
79+
}
80+
81+
@Test
82+
public void testGetS3KeyName() {
83+
String s3Uri = "s3://test-bucket/path/to/file";
84+
assertEquals("path/to/file", S3Utils.getS3KeyName(s3Uri));
85+
86+
s3Uri = "s3://test-bucket";
87+
assertEquals("", S3Utils.getS3KeyName(s3Uri));
88+
}
89+
}

plugin/build.gradle

+13-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ dependencies {
5959
implementation project(':opensearch-ml-memory')
6060
compileOnly "com.google.guava:guava:32.1.3-jre"
6161

62+
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
63+
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
64+
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
65+
66+
implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: '2.29.12'
67+
68+
implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: '2.29.12'
69+
70+
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: '2.29.12'
71+
6272
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
6373
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
6474
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
@@ -353,7 +363,9 @@ List<String> jacocoExclusions = [
353363
'org.opensearch.ml.action.models.DeleteModelTransportAction.2',
354364
'org.opensearch.ml.model.MLModelCacheHelper',
355365
'org.opensearch.ml.model.MLModelCacheHelper.1',
356-
'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction'
366+
'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction',
367+
'org.opensearch.ml.jobs.MLBatchTaskUpdateExtension',
368+
'org.opensearch.ml.jobs.MLBatchTaskUpdateJobRunner'
357369

358370
]
359371

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

+29-17
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
import com.google.common.annotations.VisibleForTesting;
103103

104104
import lombok.extern.log4j.Log4j2;
105+
import software.amazon.awssdk.services.s3.S3Client;
106+
import software.amazon.awssdk.services.s3.model.S3Exception;
105107

106108
@Log4j2
107109
public class GetTaskTransportAction extends HandledTransportAction<ActionRequest, MLTaskGetResponse> {
@@ -129,7 +131,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
129131
volatile Pattern remoteJobFailedStatusRegexPattern;
130132
private final MLEngine mlEngine;
131133

132-
private Map<String, String> decryptedCredential;
134+
// Decrypted credentials are not kept in a field: executeConnector resolves them per request and passes them to processTaskResponse.
133135

134136
@Inject
135137
public GetTaskTransportAction(
@@ -456,19 +458,25 @@ private void executeConnector(
456458
connector.addAction(connectorAction);
457459
}
458460

459-
decryptedCredential = connector.getDecryptedCredential();
460-
461-
if (decryptedCredential == null || decryptedCredential.isEmpty()) {
462-
decryptedCredential = mlEngine.getConnectorCredential(connector);
463-
}
464-
461+
final Map<String, String> decryptedCredential = connector.getDecryptedCredential() != null
462+
&& !connector.getDecryptedCredential().isEmpty()
463+
? connector.getDecryptedCredential() // reuse credentials already decrypted on the connector
464+
: mlEngine.getConnectorCredential(connector); // none present (null or empty): resolve them through MLEngine
465465
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
466466
connectorExecutor.setScriptService(scriptService);
467467
connectorExecutor.setClusterService(clusterService);
468468
connectorExecutor.setClient(client);
469469
connectorExecutor.setXContentRegistry(xContentRegistry);
470470
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
471-
processTaskResponse(mlTask, taskId, isUserInitiatedGetTaskRequest, taskResponse, remoteJob, actionListener);
471+
processTaskResponse(
472+
mlTask,
473+
taskId,
474+
isUserInitiatedGetTaskRequest,
475+
taskResponse,
476+
remoteJob,
477+
decryptedCredential,
478+
actionListener
479+
);
472480
}, e -> {
473481
// When the request to remote service fails, we will retry the request for next 10 minutes (10 runs).
474482
// If it fails even then, we mark it as unreachable in task index and send message to DLQ
@@ -500,6 +508,7 @@ protected void processTaskResponse(
500508
Boolean isUserInitiatedGetTaskRequest,
501509
MLTaskResponse taskResponse,
502510
Map<String, Object> remoteJob,
511+
Map<String, String> decryptedCredential,
503512
ActionListener<MLTaskGetResponse> actionListener
504513
) {
505514
try {
@@ -566,15 +575,18 @@ protected void updateDLQ(MLTask mlTask, Map<String, String> decryptedCredential)
566575
log.error("Failed to get the bucket name and region from batch predict request");
567576
}
568577
remoteJobDetails.remove("dlq");
569-
570-
String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name"));
571-
String s3ObjectKey = "BatchJobFailure_" + jobName;
572-
String content = mlTask.getState().equals(UNREACHABLE)
573-
? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError())
574-
: remoteJobDetails.toString();
575-
576-
S3Utils.putObject(accessKey, secretKey, sessionToken, region, bucketName, s3ObjectKey, content);
577-
log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now());
578+
try (S3Client s3Client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region)) {
579+
String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name"));
580+
String s3ObjectKey = "BatchJobFailure_" + jobName;
581+
String content = mlTask.getState().equals(UNREACHABLE)
582+
? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError())
583+
: remoteJobDetails.toString();
584+
585+
S3Utils.putObject(s3Client, bucketName, s3ObjectKey, content);
586+
log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now());
587+
}
588+
} catch (S3Exception e) {
589+
log.error("Failed to update task status for task: {}. S3 Exception: {}", taskId, e.awsErrorDetails().errorMessage());
578590
} catch (Exception e) {
579591
log.error("Failed to update task status for task: " + taskId, e);
580592
}

plugin/src/main/java/org/opensearch/ml/jobs/BatchPredictTaskUpdateJob.java plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateExtension.java

+12-10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.ml.jobs;
77

8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
810
import java.io.IOException;
911
import java.time.Instant;
1012

@@ -16,7 +18,7 @@
1618
import org.opensearch.jobscheduler.spi.schedule.ScheduleParser;
1719
import org.opensearch.ml.common.CommonValue;
1820

19-
public class BatchPredictTaskUpdateJob implements JobSchedulerExtension {
21+
public class MLBatchTaskUpdateExtension implements JobSchedulerExtension {
2022

2123
@Override
2224
public String getJobType() {
@@ -31,32 +33,32 @@ public ScheduledJobRunner getJobRunner() {
3133
@Override
3234
public ScheduledJobParser getJobParser() {
3335
return (parser, id, jobDocVersion) -> {
34-
MLBatchPredictTaskUpdateJobParameter jobParameter = new MLBatchPredictTaskUpdateJobParameter();
35-
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
36+
MLBatchTaskUpdateJobParameter jobParameter = new MLBatchTaskUpdateJobParameter();
37+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
3638

3739
while (!parser.nextToken().equals(XContentParser.Token.END_OBJECT)) {
3840
String fieldName = parser.currentName();
3941
parser.nextToken();
4042
switch (fieldName) {
41-
case MLBatchPredictTaskUpdateJobParameter.NAME_FIELD:
43+
case MLBatchTaskUpdateJobParameter.NAME_FIELD:
4244
jobParameter.setJobName(parser.text());
4345
break;
44-
case MLBatchPredictTaskUpdateJobParameter.ENABLED_FILED:
46+
case MLBatchTaskUpdateJobParameter.ENABLED_FILED:
4547
jobParameter.setEnabled(parser.booleanValue());
4648
break;
47-
case MLBatchPredictTaskUpdateJobParameter.ENABLED_TIME_FILED:
49+
case MLBatchTaskUpdateJobParameter.ENABLED_TIME_FILED:
4850
jobParameter.setEnabledTime(parseInstantValue(parser));
4951
break;
50-
case MLBatchPredictTaskUpdateJobParameter.LAST_UPDATE_TIME_FIELD:
52+
case MLBatchTaskUpdateJobParameter.LAST_UPDATE_TIME_FIELD:
5153
jobParameter.setLastUpdateTime(parseInstantValue(parser));
5254
break;
53-
case MLBatchPredictTaskUpdateJobParameter.SCHEDULE_FIELD:
55+
case MLBatchTaskUpdateJobParameter.SCHEDULE_FIELD:
5456
jobParameter.setSchedule(ScheduleParser.parse(parser));
5557
break;
56-
case MLBatchPredictTaskUpdateJobParameter.LOCK_DURATION_SECONDS:
58+
case MLBatchTaskUpdateJobParameter.LOCK_DURATION_SECONDS:
5759
jobParameter.setLockDurationSeconds(parser.longValue());
5860
break;
59-
case MLBatchPredictTaskUpdateJobParameter.JITTER:
61+
case MLBatchTaskUpdateJobParameter.JITTER:
6062
jobParameter.setJitter(parser.doubleValue());
6163
break;
6264
default:

plugin/src/main/java/org/opensearch/ml/jobs/MLBatchPredictTaskUpdateJobParameter.java plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobParameter.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
* It adds an additional "indexToWatch" field to {@link ScheduledJobParameter}, which stores the index
1919
* the job runner will watch.
2020
*/
21-
public class MLBatchPredictTaskUpdateJobParameter implements ScheduledJobParameter {
21+
public class MLBatchTaskUpdateJobParameter implements ScheduledJobParameter {
2222
public static final String NAME_FIELD = "name";
2323
public static final String ENABLED_FILED = "enabled";
2424
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
@@ -38,9 +38,9 @@ public class MLBatchPredictTaskUpdateJobParameter implements ScheduledJobParamet
3838
private Long lockDurationSeconds;
3939
private Double jitter;
4040

41-
public MLBatchPredictTaskUpdateJobParameter() {}
41+
public MLBatchTaskUpdateJobParameter() {}
4242

43-
public MLBatchPredictTaskUpdateJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter) {
43+
public MLBatchTaskUpdateJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter) {
4444
this.jobName = name;
4545
this.schedule = schedule;
4646
this.lockDurationSeconds = lockDurationSeconds;

0 commit comments

Comments
 (0)