Skip to content

Commit 2f84d8c

Browse files
committed
add unit tests
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent 607edbe commit 2f84d8c

12 files changed

+374
-37
lines changed

ml-algorithms/build.gradle

-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ dependencies {
7070
}
7171
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
7272

73-
74-
7573
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
7674
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
7775
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java

-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ public static void putObject(
5353
String key,
5454
String content
5555
) {
56-
5756
try (S3Client s3Client = initS3Client(accessKey, secretKey, sessionToken, region)) {
5857
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
5958
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();

plugin/build.gradle

-10
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ dependencies {
5959
implementation project(':opensearch-ml-memory')
6060
compileOnly "com.google.guava:guava:32.1.3-jre"
6161

62-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
63-
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
64-
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
65-
66-
implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: '2.29.12'
67-
68-
implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: '2.29.12'
69-
70-
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: '2.29.12'
71-
7262
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
7363
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
7464
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

+29-10
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,22 @@
1111
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1212
import static org.opensearch.ml.common.MLTask.REMOTE_JOB_FIELD;
1313
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
14-
import static org.opensearch.ml.common.MLTaskState.*;
15-
import static org.opensearch.ml.common.connector.AbstractConnector.*;
14+
import static org.opensearch.ml.common.MLTaskState.CANCELLED;
15+
import static org.opensearch.ml.common.MLTaskState.CANCELLING;
16+
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
17+
import static org.opensearch.ml.common.MLTaskState.EXPIRED;
18+
import static org.opensearch.ml.common.MLTaskState.FAILED;
19+
import static org.opensearch.ml.common.MLTaskState.UNREACHABLE;
20+
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
21+
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
22+
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
1623
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
17-
import static org.opensearch.ml.settings.MLCommonsSettings.*;
24+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
25+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
26+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
27+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
28+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX;
29+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
1830
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
1931
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
2032

@@ -49,7 +61,10 @@
4961
import org.opensearch.core.xcontent.NamedXContentRegistry;
5062
import org.opensearch.core.xcontent.XContentParser;
5163
import org.opensearch.index.IndexNotFoundException;
52-
import org.opensearch.ml.common.*;
64+
import org.opensearch.ml.common.FunctionName;
65+
import org.opensearch.ml.common.MLModel;
66+
import org.opensearch.ml.common.MLTask;
67+
import org.opensearch.ml.common.MLTaskType;
5368
import org.opensearch.ml.common.connector.Connector;
5469
import org.opensearch.ml.common.connector.ConnectorAction;
5570
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
@@ -84,6 +99,8 @@
8499
import org.opensearch.tasks.Task;
85100
import org.opensearch.transport.TransportService;
86101

102+
import com.google.common.annotations.VisibleForTesting;
103+
87104
import lombok.extern.log4j.Log4j2;
88105

89106
@Log4j2
@@ -440,6 +457,7 @@ private void executeConnector(
440457
}
441458

442459
decryptedCredential = connector.getDecryptedCredential();
460+
443461
if (decryptedCredential == null || decryptedCredential.isEmpty()) {
444462
decryptedCredential = mlEngine.getConnectorCredential(connector);
445463
}
@@ -466,7 +484,7 @@ private void executeConnector(
466484
updatedTask.put(STATE_FIELD, UNREACHABLE);
467485
mlTask.setState(UNREACHABLE);
468486
mlTask.setError(e.getMessage());
469-
updateDLQ(mlTask);
487+
updateDLQ(mlTask, decryptedCredential);
470488
}
471489
updatedTask.put("remote_job", remoteJob);
472490
mlTaskManager.updateMLTaskDirectly(taskId, updatedTask);
@@ -504,7 +522,7 @@ protected void processTaskResponse(
504522

505523
mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> {
506524
if (mlTask.getState().equals(FAILED) && !isUserInitiatedGetTaskRequest) {
507-
updateDLQ(mlTask);
525+
updateDLQ(mlTask, decryptedCredential);
508526
}
509527
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
510528
}, e -> {
@@ -528,16 +546,17 @@ protected void processTaskResponse(
528546
}
529547
}
530548

531-
protected void updateDLQ(MLTask mlTask) {
549+
@VisibleForTesting
550+
protected void updateDLQ(MLTask mlTask, Map<String, String> decryptedCredential) {
532551
Map<String, Object> remoteJob = mlTask.getRemoteJob();
533552
Map<String, String> dlq = (Map<String, String>) remoteJob.get("dlq");
534553
if (dlq != null && !dlq.isEmpty()) {
535554
String taskId = mlTask.getTaskId();
536555
try {
537556
Map<String, Object> remoteJobDetails = mlTask.getRemoteJob();
538-
String accessKey = this.decryptedCredential.get(ACCESS_KEY_FIELD);
539-
String secretKey = this.decryptedCredential.get(SECRET_KEY_FIELD);
540-
String sessionToken = this.decryptedCredential.get(SESSION_TOKEN_FIELD);
557+
String accessKey = decryptedCredential.get(ACCESS_KEY_FIELD);
558+
String secretKey = decryptedCredential.get(SECRET_KEY_FIELD);
559+
String sessionToken = decryptedCredential.get(SESSION_TOKEN_FIELD);
541560

542561
String bucketName = dlq.get("bucket");
543562
String region = dlq.get("region");

plugin/src/main/java/org/opensearch/ml/jobs/BatchPredictTaskUpdateJob.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.jobs;
27

38
import java.io.IOException;

plugin/src/main/java/org/opensearch/ml/jobs/MLBatchPredictTaskUpdateJobParameter.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
/*
22
* Copyright OpenSearch Contributors
33
* SPDX-License-Identifier: Apache-2.0
4-
*
5-
* The OpenSearch Contributors require contributions made to
6-
* this file be licensed under the Apache-2.0 license or a
7-
* compatible open source license.
84
*/
5+
96
package org.opensearch.ml.jobs;
107

118
import java.io.IOException;

plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunner.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.jobs;
27

38
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
@@ -129,10 +134,7 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont
129134

130135
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(taskResponse -> {
131136
log.info("Updated Task status for taskId: {} at {}", taskId, Instant.now());
132-
}, exception -> {
133-
log.error("Failed to get task status for task: " + taskId, exception);
134-
135-
}));
137+
}, exception -> { log.error("Failed to get task status for task: " + taskId, exception); }));
136138
}
137139
}, e -> {
138140
if (e instanceof IndexNotFoundException) {

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
import org.opensearch.core.xcontent.XContentParser;
4747
import org.opensearch.ml.breaker.MLCircuitBreakerService;
4848
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
49-
import org.opensearch.ml.common.*;
49+
import org.opensearch.ml.common.FunctionName;
50+
import org.opensearch.ml.common.MLModel;
51+
import org.opensearch.ml.common.MLTask;
52+
import org.opensearch.ml.common.MLTaskState;
53+
import org.opensearch.ml.common.MLTaskType;
5054
import org.opensearch.ml.common.connector.ConnectorAction;
5155
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
5256
import org.opensearch.ml.common.dataset.MLInputDataType;

plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java

+84-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
import static org.mockito.Mockito.spy;
1616
import static org.mockito.Mockito.verify;
1717
import static org.mockito.Mockito.when;
18-
import static org.opensearch.ml.settings.MLCommonsSettings.*;
18+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
19+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
20+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
21+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
22+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX;
23+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
1924

2025
import java.io.IOException;
2126
import java.util.Arrays;
@@ -67,7 +72,6 @@
6772
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
6873
import org.opensearch.ml.engine.MLEngine;
6974
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
70-
import org.opensearch.ml.engine.ingest.S3DataIngestion;
7175
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
7276
import org.opensearch.ml.helper.ModelAccessControlHelper;
7377
import org.opensearch.ml.model.MLModelManager;
@@ -120,9 +124,6 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase {
120124
@Mock
121125
private MLModelManager mlModelManager;
122126

123-
@Mock
124-
private S3DataIngestion s3DataIngestion;
125-
126127
@Mock
127128
private MLTaskManager mlTaskManager;
128129

@@ -460,6 +461,10 @@ public void test_processTaskResponse_expired() {
460461
processTaskResponse("status", "expired", MLTaskState.EXPIRED);
461462
}
462463

464+
public void test_processTaskResponse_failed() {
465+
processTaskResponse("status", "failed", MLTaskState.FAILED);
466+
}
467+
463468
public void test_processTaskResponse_WrongStatusField() {
464469
processTaskResponse("wrong_status_field", "expired", null);
465470
}
@@ -500,4 +505,78 @@ private void processTaskResponse(String statusField, String remoteJobResponseSta
500505
assertEquals(remoteJobResponseStatus, updatedRemoteJob.get(statusField));
501506
assertEquals(remoteJobName, updatedRemoteJob.get("name"));
502507
}
508+
509+
public void testUpdateDLQ_Success() throws IOException {
510+
// Setup test data
511+
Map<String, Object> remoteJob = new HashMap<>();
512+
remoteJob.put("TransformJobName", "test-job");
513+
Map<String, String> dlq = new HashMap<>();
514+
dlq.put("bucket", "test-bucket");
515+
dlq.put("region", "us-west-2");
516+
remoteJob.put("dlq", dlq);
517+
518+
MLTask mlTask = MLTask
519+
.builder()
520+
.taskId("test-task")
521+
.state(MLTaskState.FAILED)
522+
.error("Test error message")
523+
.remoteJob(remoteJob)
524+
.build();
525+
526+
// Setup decrypted credentials
527+
Map<String, String> decryptedCredential = new HashMap<>();
528+
decryptedCredential.put("aws_access_key", "test-key");
529+
decryptedCredential.put("aws_secret_key", "test-secret");
530+
decryptedCredential.put("aws_session_token", "test-token");
531+
532+
// Call the method
533+
getTaskTransportAction.updateDLQ(mlTask, decryptedCredential);
534+
535+
// Verify remoteJob DLQ is removed
536+
assertNull(mlTask.getRemoteJob().get("dlq"));
537+
}
538+
539+
public void testUpdateDLQ_MissingBucketOrRegion() {
540+
// Setup test data with missing bucket/region
541+
Map<String, Object> remoteJob = new HashMap<>();
542+
remoteJob.put("TransformJobName", "test-job");
543+
Map<String, String> dlq = new HashMap<>();
544+
// Intentionally missing bucket and region
545+
remoteJob.put("dlq", dlq);
546+
547+
MLTask mlTask = MLTask
548+
.builder()
549+
.taskId("test-task")
550+
.state(MLTaskState.FAILED)
551+
.error("Test error message")
552+
.remoteJob(remoteJob)
553+
.build();
554+
555+
// Call the method - should not throw exception but log error
556+
getTaskTransportAction.updateDLQ(mlTask, Collections.emptyMap());
557+
558+
// Verify DLQ still exists since update failed
559+
assertNotNull(mlTask.getRemoteJob().get("dlq"));
560+
}
561+
562+
public void testUpdateDLQ_NullDLQ() {
563+
// Setup test data with null DLQ
564+
Map<String, Object> remoteJob = new HashMap<>();
565+
remoteJob.put("TransformJobName", "test-job");
566+
// No DLQ configuration
567+
568+
MLTask mlTask = MLTask
569+
.builder()
570+
.taskId("test-task")
571+
.state(MLTaskState.FAILED)
572+
.error("Test error message")
573+
.remoteJob(remoteJob)
574+
.build();
575+
576+
// Call the method - should do nothing
577+
getTaskTransportAction.updateDLQ(mlTask, null);
578+
579+
// Verify remoteJob is unchanged
580+
assertEquals("test-job", mlTask.getRemoteJob().get("TransformJobName"));
581+
}
503582
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.jobs;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
import static org.junit.Assert.assertTrue;
11+
12+
import java.time.Instant;
13+
import java.time.temporal.ChronoUnit;
14+
15+
import org.junit.Before;
16+
import org.junit.Test;
17+
import org.opensearch.common.xcontent.XContentFactory;
18+
import org.opensearch.core.xcontent.XContentBuilder;
19+
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
20+
21+
public class MLBatchPredictTaskUpdateJobParameterTests {
22+
23+
private MLBatchPredictTaskUpdateJobParameter jobParameter;
24+
private String jobName;
25+
private IntervalSchedule schedule;
26+
private Long lockDurationSeconds;
27+
private Double jitter;
28+
29+
@Before
30+
public void setUp() {
31+
jobName = "test-job";
32+
schedule = new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES);
33+
lockDurationSeconds = 20L;
34+
jitter = 0.5;
35+
jobParameter = new MLBatchPredictTaskUpdateJobParameter(jobName, schedule, lockDurationSeconds, jitter);
36+
}
37+
38+
@Test
39+
public void testConstructor() {
40+
assertNotNull(jobParameter);
41+
assertEquals(jobName, jobParameter.getName());
42+
assertEquals(schedule, jobParameter.getSchedule());
43+
assertEquals(lockDurationSeconds, jobParameter.getLockDurationSeconds());
44+
assertEquals(jitter, jobParameter.getJitter());
45+
assertTrue(jobParameter.isEnabled());
46+
assertNotNull(jobParameter.getEnabledTime());
47+
assertNotNull(jobParameter.getLastUpdateTime());
48+
}
49+
50+
@Test
51+
public void testToXContent() throws Exception {
52+
XContentBuilder builder = XContentFactory.jsonBuilder();
53+
jobParameter.toXContent(builder, null);
54+
String jsonString = builder.toString();
55+
56+
assertTrue(jsonString.contains(jobName));
57+
assertTrue(jsonString.contains("enabled"));
58+
assertTrue(jsonString.contains("schedule"));
59+
assertTrue(jsonString.contains("lock_duration_seconds"));
60+
assertTrue(jsonString.contains("jitter"));
61+
}
62+
63+
@Test
64+
public void testSetters() {
65+
String newJobName = "new-job";
66+
jobParameter.setJobName(newJobName);
67+
assertEquals(newJobName, jobParameter.getName());
68+
69+
Instant newTime = Instant.now();
70+
jobParameter.setLastUpdateTime(newTime);
71+
assertEquals(newTime, jobParameter.getLastUpdateTime());
72+
73+
jobParameter.setEnabled(false);
74+
assertEquals(false, jobParameter.isEnabled());
75+
76+
Long newLockDuration = 30L;
77+
jobParameter.setLockDurationSeconds(newLockDuration);
78+
assertEquals(newLockDuration, jobParameter.getLockDurationSeconds());
79+
80+
Double newJitter = 0.7;
81+
jobParameter.setJitter(newJitter);
82+
assertEquals(newJitter, jobParameter.getJitter());
83+
}
84+
}

0 commit comments

Comments
 (0)