Skip to content

Commit 58d8fbb

Browse files
committed
address comments and resolve dependencies to avoid conflicts
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent c7351ca commit 58d8fbb

File tree

11 files changed

+62
-73
lines changed

11 files changed

+62
-73
lines changed

build.gradle

-3
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,6 @@ subprojects {
9595
// Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
9696
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
9797
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
98-
resolutionStrategy.force "software.amazon.awssdk:aws-core:2.29.12"
99-
resolutionStrategy.force "software.amazon.awssdk:s3:2.29.12"
100-
resolutionStrategy.force "software.amazon.awssdk:regions:2.29.12"
10198
}
10299
}
103100

common/build.gradle

-4
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ dependencies {
2323
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
2424
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2525

26-
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
27-
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
28-
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
29-
3026
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
3127
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
3228
compileOnly group: 'org.json', name: 'json', version: '20231013'

ml-algorithms/build.gradle

+6-3
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,12 @@ dependencies {
6969
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
7070
}
7171
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
72-
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
73-
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
74-
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
72+
73+
74+
75+
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
76+
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
77+
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
7578

7679
implementation 'com.jayway.jsonpath:json-path:2.9.0'
7780
implementation group: 'org.json', name: 'json', version: '20231013'

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
package org.opensearch.ml.engine.ingest;
77

8-
import static org.opensearch.ml.common.connector.AbstractConnector.*;
8+
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
9+
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
10+
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
911
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
1012

1113
import java.io.BufferedReader;
@@ -25,8 +27,8 @@
2527
import org.opensearch.client.Client;
2628
import org.opensearch.core.rest.RestStatus;
2729
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
28-
import org.opensearch.ml.common.utils.S3Utils;
2930
import org.opensearch.ml.engine.annotation.Ingester;
31+
import org.opensearch.ml.engine.utils.S3Utils;
3032

3133
import lombok.extern.log4j.Log4j2;
3234
import software.amazon.awssdk.core.ResponseInputStream;

common/src/main/java/org/opensearch/ml/common/utils/S3Utils.java ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/S3Utils.java

+14-6
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.utils;
7-
8-
import static org.opensearch.ml.common.connector.AbstractConnector.*;
6+
package org.opensearch.ml.engine.utils;
97

108
import java.security.AccessController;
119
import java.security.PrivilegedActionException;
@@ -21,7 +19,6 @@
2119
import software.amazon.awssdk.core.sync.RequestBody;
2220
import software.amazon.awssdk.regions.Region;
2321
import software.amazon.awssdk.services.s3.S3Client;
24-
import software.amazon.awssdk.services.s3.model.*;
2522
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
2623

2724
@Log4j2
@@ -47,8 +44,17 @@ public static S3Client initS3Client(String accessKey, String secretKey, String s
4744
}
4845
}
4946

50-
public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
51-
try {
47+
public static void putObject(
48+
String accessKey,
49+
String secretKey,
50+
String sessionToken,
51+
String region,
52+
String bucketName,
53+
String key,
54+
String content
55+
) {
56+
57+
try (S3Client s3Client = initS3Client(accessKey, secretKey, sessionToken, region)) {
5258
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
5359
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();
5460

@@ -58,6 +64,8 @@ public static void putObject(S3Client s3Client, String bucketName, String key, S
5864
});
5965
} catch (PrivilegedActionException e) {
6066
throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
67+
} catch (Exception e) {
68+
log.error("Unexpected error during S3 upload", e);
6169
}
6270
}
6371

plugin/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ compileTestJava {
138138
}
139139

140140
//TODO: check which one should be enabled
141-
licenseHeaders.enabled = false
141+
licenseHeaders.enabled = true
142142
testingConventions.enabled = false
143143
forbiddenApis.ignoreFailures = false
144144
dependencyLicenses.enabled = false

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

+16-30
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@
6262
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
6363
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
6464
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
65-
import org.opensearch.ml.common.utils.S3Utils;
6665
import org.opensearch.ml.engine.MLEngine;
6766
import org.opensearch.ml.engine.MLEngineClassLoader;
6867
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
6968
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
7069
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
70+
import org.opensearch.ml.engine.utils.S3Utils;
7171
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
7272
import org.opensearch.ml.helper.ModelAccessControlHelper;
7373
import org.opensearch.ml.model.MLModelManager;
@@ -85,8 +85,6 @@
8585
import org.opensearch.transport.TransportService;
8686

8787
import lombok.extern.log4j.Log4j2;
88-
import software.amazon.awssdk.services.s3.S3Client;
89-
import software.amazon.awssdk.services.s3.model.S3Exception;
9088

9189
@Log4j2
9290
public class GetTaskTransportAction extends HandledTransportAction<ActionRequest, MLTaskGetResponse> {
@@ -239,7 +237,6 @@ private void handleAsyncResponse(
239237
handleThrowable(throwable, taskId, actionListener);
240238
return;
241239
}
242-
243240
processResponse(response, taskId, isUserInitiatedGetTaskRequest, tenantId, actionListener);
244241
}
245242

@@ -531,7 +528,7 @@ protected void processTaskResponse(
531528
}
532529
}
533530

534-
private void updateDLQ(MLTask mlTask) {
531+
protected void updateDLQ(MLTask mlTask) {
535532
Map<String, Object> remoteJob = mlTask.getRemoteJob();
536533
Map<String, String> dlq = (Map<String, String>) remoteJob.get("dlq");
537534
if (dlq != null && !dlq.isEmpty()) {
@@ -542,33 +539,22 @@ private void updateDLQ(MLTask mlTask) {
542539
String secretKey = this.decryptedCredential.get(SECRET_KEY_FIELD);
543540
String sessionToken = this.decryptedCredential.get(SESSION_TOKEN_FIELD);
544541

545-
if (dlq != null) {
546-
String bucketName = dlq.get("bucket");
547-
String region = dlq.get("region");
542+
String bucketName = dlq.get("bucket");
543+
String region = dlq.get("region");
548544

549-
if (bucketName == null || region == null) {
550-
log.error("Failed to get the bucket name and region from batch predict request");
551-
}
552-
remoteJobDetails.remove("dlq");
553-
S3Client s3Client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region);
554-
try {
555-
556-
String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name"));
557-
String s3ObjectKey = "BatchJobFailure_" + jobName;
558-
String content = mlTask.getState().equals(UNREACHABLE)
559-
? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError())
560-
: remoteJobDetails.toString();
561-
562-
S3Utils.putObject(s3Client, bucketName, s3ObjectKey, content);
563-
log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now());
564-
} catch (S3Exception e) {
565-
log.error("S3 Exception: " + e.awsErrorDetails().errorMessage());
566-
} catch (Exception e) {
567-
e.printStackTrace();
568-
} finally {
569-
s3Client.close();
570-
}
545+
if (bucketName == null || region == null) {
546+
log.error("Failed to get the bucket name and region from batch predict request");
571547
}
548+
remoteJobDetails.remove("dlq");
549+
550+
String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name"));
551+
String s3ObjectKey = "BatchJobFailure_" + jobName;
552+
String content = mlTask.getState().equals(UNREACHABLE)
553+
? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError())
554+
: remoteJobDetails.toString();
555+
556+
S3Utils.putObject(accessKey, secretKey, sessionToken, region, bucketName, s3ObjectKey, content);
557+
log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now());
572558
} catch (Exception e) {
573559
log.error("Failed to update task status for task: " + taskId, e);
574560
}

plugin/src/main/java/org/opensearch/ml/jobs/MLBatchTaskUpdateJobRunner.java

+18-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import org.opensearch.jobscheduler.spi.ScheduledJobParameter;
1818
import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
1919
import org.opensearch.jobscheduler.spi.utils.LockService;
20+
import org.opensearch.ml.common.FunctionName;
21+
import org.opensearch.ml.common.MLTaskState;
22+
import org.opensearch.ml.common.MLTaskType;
2023
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
2124
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
2225
import org.opensearch.ml.task.MLTaskManager;
@@ -64,16 +67,10 @@ public void setClient(Client client) {
6467
this.client = client;
6568
}
6669

67-
public void initialize(
68-
final ClusterService clusterService,
69-
final ThreadPool threadPool,
70-
final Client client,
71-
final MLTaskManager taskManager
72-
) {
70+
public void initialize(final ClusterService clusterService, final ThreadPool threadPool, final Client client) {
7371
this.clusterService = clusterService;
7472
this.threadPool = threadPool;
7573
this.client = client;
76-
this.taskManager = taskManager;
7774
this.initialized = true;
7875
}
7976

@@ -93,19 +90,19 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont
9390
String jobName = scheduledJobParameter.getName();
9491
log.info("Starting job execution for job ID: {} at {}", jobName, Instant.now());
9592

96-
if (taskManager == null) {
97-
log.error("TaskManager not initialized. Cannot run batch task polling job");
98-
return;
99-
}
100-
10193
log.debug("Running batch task polling job");
10294

10395
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
10496
BoolQueryBuilder boolQuery = QueryBuilders
10597
.boolQuery()
106-
.must(QueryBuilders.termQuery("task_type", "BATCH_PREDICTION"))
107-
.must(QueryBuilders.termQuery("function_name", "REMOTE"))
108-
.must(QueryBuilders.termQuery("state", "RUNNING"));
98+
.must(QueryBuilders.termQuery("task_type", MLTaskType.BATCH_PREDICTION))
99+
.must(QueryBuilders.termQuery("function_name", FunctionName.REMOTE))
100+
.must(
101+
QueryBuilders
102+
.boolQuery()
103+
.should(QueryBuilders.termQuery("state", MLTaskState.RUNNING))
104+
.should(QueryBuilders.termQuery("state", MLTaskState.CANCELLING))
105+
);
109106

110107
sourceBuilder.query(boolQuery);
111108
sourceBuilder.size(100);
@@ -124,14 +121,14 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont
124121
for (SearchHit searchHit : searchHits) {
125122
String taskId = searchHit.getId();
126123
log.debug("Starting polling for task: {} at {}", taskId, Instant.now());
127-
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).isUserInitiatedGetRequest(false).build();
124+
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest
125+
.builder()
126+
.taskId(taskId)
127+
.isUserInitiatedGetTaskRequest(false)
128+
.build();
128129

129130
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(taskResponse -> {
130-
try {
131-
log.info("Updated Task status for taskId: {} at {}", taskId, Instant.now());
132-
} catch (Exception e) {
133-
log.error("Failed to update task status for task: " + taskId, e);
134-
}
131+
log.info("Updated Task status for taskId: {} at {}", taskId, Instant.now());
135132
}, exception -> {
136133
log.error("Failed to get task status for task: " + taskId, exception);
137134

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ public Collection<Object> createComponents(
713713
.getClusterSettings()
714714
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it);
715715

716-
MLBatchTaskUpdateJobRunner.getJobRunnerInstance().initialize(clusterService, threadPool, client, mlTaskManager);
716+
MLBatchTaskUpdateJobRunner.getJobRunnerInstance().initialize(clusterService, threadPool, client);
717717

718718
return ImmutableList
719719
.of(

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public void dispatchTask(
156156
}
157157

158158
// TODO: check if we are able to input an object into the s3 bucket.
159-
// Or check permissions to DLQ write access
159+
// Or check permissions to DLQ write access
160160
}
161161
}
162162

plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ public void startTaskPollingJob() throws IOException {
546546
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
547547
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
548548

549-
client.index(indexRequest, ActionListener.wrap(r -> { log.info("Indexed ml task polling job successfully {}"); }, e -> {
549+
client.index(indexRequest, ActionListener.wrap(r -> { log.info("Indexed ml task polling job successfully"); }, e -> {
550550
log.error("Failed to index task polling job", e);
551551
}));
552552
}

0 commit comments

Comments
 (0)