Skip to content

Commit 2c2dc11

Browse files
committed
add unit tests
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent 607edbe commit 2c2dc11

12 files changed

+460
-133
lines changed

ml-algorithms/build.gradle

-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ dependencies {
7070
}
7171
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
7272

73-
74-
7573
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
7674
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
7775
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java

-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ public static void putObject(
5353
String key,
5454
String content
5555
) {
56-
5756
try (S3Client s3Client = initS3Client(accessKey, secretKey, sessionToken, region)) {
5857
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
5958
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();

plugin/build.gradle

-10
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ dependencies {
5959
implementation project(':opensearch-ml-memory')
6060
compileOnly "com.google.guava:guava:32.1.3-jre"
6161

62-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
63-
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
64-
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
65-
66-
implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: '2.29.12'
67-
68-
implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: '2.29.12'
69-
70-
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: '2.29.12'
71-
7262
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
7363
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
7464
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

+25-10
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,22 @@
1111
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1212
import static org.opensearch.ml.common.MLTask.REMOTE_JOB_FIELD;
1313
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
14-
import static org.opensearch.ml.common.MLTaskState.*;
15-
import static org.opensearch.ml.common.connector.AbstractConnector.*;
14+
import static org.opensearch.ml.common.MLTaskState.CANCELLED;
15+
import static org.opensearch.ml.common.MLTaskState.CANCELLING;
16+
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
17+
import static org.opensearch.ml.common.MLTaskState.EXPIRED;
18+
import static org.opensearch.ml.common.MLTaskState.FAILED;
19+
import static org.opensearch.ml.common.MLTaskState.UNREACHABLE;
20+
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
21+
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
22+
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
1623
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
17-
import static org.opensearch.ml.settings.MLCommonsSettings.*;
24+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
25+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
26+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
27+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
28+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FAILED_REGEX;
29+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
1830
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
1931
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
2032

@@ -27,6 +39,7 @@
2739
import java.util.regex.Matcher;
2840
import java.util.regex.Pattern;
2941

42+
import com.google.common.annotations.VisibleForTesting;
3043
import org.opensearch.ExceptionsHelper;
3144
import org.opensearch.OpenSearchException;
3245
import org.opensearch.OpenSearchStatusException;
@@ -440,6 +453,7 @@ private void executeConnector(
440453
}
441454

442455
decryptedCredential = connector.getDecryptedCredential();
456+
443457
if (decryptedCredential == null || decryptedCredential.isEmpty()) {
444458
decryptedCredential = mlEngine.getConnectorCredential(connector);
445459
}
@@ -449,7 +463,7 @@ private void executeConnector(
449463
connectorExecutor.setClient(client);
450464
connectorExecutor.setXContentRegistry(xContentRegistry);
451465
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
452-
processTaskResponse(mlTask, taskId, isUserInitiatedGetTaskRequest, taskResponse, remoteJob, actionListener);
466+
processTaskResponse(mlTask, taskId, isUserInitiatedGetTaskRequest, taskResponse, remoteJob ,actionListener);
453467
}, e -> {
454468
// When the request to remote service fails, we will retry the request for next 10 minutes (10 runs).
455469
// If it fails even then, we mark it as unreachable in task index and send message to DLQ
@@ -466,7 +480,7 @@ private void executeConnector(
466480
updatedTask.put(STATE_FIELD, UNREACHABLE);
467481
mlTask.setState(UNREACHABLE);
468482
mlTask.setError(e.getMessage());
469-
updateDLQ(mlTask);
483+
updateDLQ(mlTask, decryptedCredential);
470484
}
471485
updatedTask.put("remote_job", remoteJob);
472486
mlTaskManager.updateMLTaskDirectly(taskId, updatedTask);
@@ -504,7 +518,7 @@ protected void processTaskResponse(
504518

505519
mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> {
506520
if (mlTask.getState().equals(FAILED) && !isUserInitiatedGetTaskRequest) {
507-
updateDLQ(mlTask);
521+
updateDLQ(mlTask, decryptedCredential);
508522
}
509523
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
510524
}, e -> {
@@ -528,16 +542,17 @@ protected void processTaskResponse(
528542
}
529543
}
530544

531-
protected void updateDLQ(MLTask mlTask) {
545+
@VisibleForTesting
546+
protected void updateDLQ(MLTask mlTask, Map<String, String> decryptedCredential) {
532547
Map<String, Object> remoteJob = mlTask.getRemoteJob();
533548
Map<String, String> dlq = (Map<String, String>) remoteJob.get("dlq");
534549
if (dlq != null && !dlq.isEmpty()) {
535550
String taskId = mlTask.getTaskId();
536551
try {
537552
Map<String, Object> remoteJobDetails = mlTask.getRemoteJob();
538-
String accessKey = this.decryptedCredential.get(ACCESS_KEY_FIELD);
539-
String secretKey = this.decryptedCredential.get(SECRET_KEY_FIELD);
540-
String sessionToken = this.decryptedCredential.get(SESSION_TOKEN_FIELD);
553+
String accessKey = decryptedCredential.get(ACCESS_KEY_FIELD);
554+
String secretKey = decryptedCredential.get(SECRET_KEY_FIELD);
555+
String sessionToken = decryptedCredential.get(SESSION_TOKEN_FIELD);
541556

542557
String bucketName = dlq.get("bucket");
543558
String region = dlq.get("region");

plugin/src/main/java/org/opensearch/ml/jobs/BatchPredictTaskUpdateJob.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.jobs;
27

38
import java.io.IOException;

plugin/src/main/java/org/opensearch/ml/jobs/MLBatchPredictTaskUpdateJobParameter.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
/*
22
* Copyright OpenSearch Contributors
33
* SPDX-License-Identifier: Apache-2.0
4-
*
5-
* The OpenSearch Contributors require contributions made to
6-
* this file be licensed under the Apache-2.0 license or a
7-
* compatible open source license.
84
*/
5+
96
package org.opensearch.ml.jobs;
107

118
import java.io.IOException;

plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunner.java

+8-11
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ public void setClient(Client client) {
6767
this.client = client;
6868
}
6969

70-
public void initialize(final ClusterService clusterService, final ThreadPool threadPool, final Client client) {
70+
public void initialize(
71+
final ClusterService clusterService,
72+
final ThreadPool threadPool,
73+
final Client client
74+
) {
7175
this.clusterService = clusterService;
7276
this.threadPool = threadPool;
7377
this.client = client;
@@ -97,12 +101,10 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont
97101
.boolQuery()
98102
.must(QueryBuilders.termQuery("task_type", MLTaskType.BATCH_PREDICTION))
99103
.must(QueryBuilders.termQuery("function_name", FunctionName.REMOTE))
100-
.must(
101-
QueryBuilders
104+
.must(QueryBuilders
102105
.boolQuery()
103106
.should(QueryBuilders.termQuery("state", MLTaskState.RUNNING))
104-
.should(QueryBuilders.termQuery("state", MLTaskState.CANCELLING))
105-
);
107+
.should(QueryBuilders.termQuery("state", MLTaskState.CANCELLING)));
106108

107109
sourceBuilder.query(boolQuery);
108110
sourceBuilder.size(100);
@@ -121,17 +123,12 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont
121123
for (SearchHit searchHit : searchHits) {
122124
String taskId = searchHit.getId();
123125
log.debug("Starting polling for task: {} at {}", taskId, Instant.now());
124-
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest
125-
.builder()
126-
.taskId(taskId)
127-
.isUserInitiatedGetTaskRequest(false)
128-
.build();
126+
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).isUserInitiatedGetTaskRequest(false).build();
129127

130128
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(taskResponse -> {
131129
log.info("Updated Task status for taskId: {} at {}", taskId, Instant.now());
132130
}, exception -> {
133131
log.error("Failed to get task status for task: " + taskId, exception);
134-
135132
}));
136133
}
137134
}, e -> {

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
import org.opensearch.core.xcontent.XContentParser;
4747
import org.opensearch.ml.breaker.MLCircuitBreakerService;
4848
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
49-
import org.opensearch.ml.common.*;
49+
import org.opensearch.ml.common.FunctionName;
50+
import org.opensearch.ml.common.MLModel;
51+
import org.opensearch.ml.common.MLTask;
52+
import org.opensearch.ml.common.MLTaskState;
53+
import org.opensearch.ml.common.MLTaskType;
5054
import org.opensearch.ml.common.connector.ConnectorAction;
5155
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
5256
import org.opensearch.ml.common.dataset.MLInputDataType;

0 commit comments

Comments
 (0)