Skip to content

Commit 600990c

Browse files
[Backport 2.x] support batch task management by periodically polling the remote task via a cron job (#3458)
* support batch task management by periodically bolling the remote task via a cron job (#3421) * support batch task management by periocially bolling the remote task via a cron job Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * address comments and resolve dependencies to avoid conflicts Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add unit tests Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * renamed files and added more tests Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> --------- Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> (cherry picked from commit 161d789) * fix failing BWC tests Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * fix missing path in failing BWC tests Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * fix failing BWC tests Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add missing braces Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add missing braces Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add missing braces Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add missing braces Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add missing braces Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add to yml file Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add to yml file Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * add to yml file Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * refactored code Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> --------- Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> Co-authored-by: Bhavana Goud Ramaram <rbhavna@amazon.com> (cherry picked from commit f083b7e)
1 parent 7552412 commit 600990c

27 files changed

+1544
-101
lines changed

.github/workflows/test_bwc.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ jobs:
3333
echo plugin_version $plugin_version
3434
./gradlew assemble
3535
echo "Creating ./plugin/src/test/resources/org/opensearch/ml/bwc..."
36-
mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc
36+
mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc/job-scheduler
37+
mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc/ml
3738
- name: Run MLCommons Backwards Compatibility Tests
3839
run: |
3940
echo "Running backwards compatibility tests ..."

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class CommonValue {
4444
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
4545
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
4646
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
47+
public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job";
4748
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
4849
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";
4950

common/src/main/java/org/opensearch/ml/common/MLTaskState.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ public enum MLTaskState {
3030
CANCELLED,
3131
COMPLETED_WITH_ERROR,
3232
CANCELLING,
33-
EXPIRED
33+
EXPIRED,
34+
UNREACHABLE
3435
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+25-2
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,28 @@
2626
@InputDataSet(MLInputDataType.REMOTE)
2727
public class RemoteInferenceInputDataSet extends MLInputDataset {
2828
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
29+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG = CommonValue.VERSION_2_19_0;
2930
@Setter
3031
private Map<String, String> parameters;
3132
@Setter
3233
private ActionType actionType;
34+
@Setter
35+
private Map<String, String> dlq;
3336

3437
@Builder(toBuilder = true)
35-
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
38+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType, Map<String, String> dlq) {
3639
super(MLInputDataType.REMOTE);
3740
this.parameters = parameters;
3841
this.actionType = actionType;
42+
this.dlq = dlq;
43+
}
44+
45+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
46+
this(parameters, actionType, null);
3947
}
4048

4149
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
42-
this(parameters, null);
50+
this(parameters, null, null);
4351
}
4452

4553
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
@@ -55,6 +63,13 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
5563
this.actionType = null;
5664
}
5765
}
66+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
67+
if (streamInput.readBoolean()) {
68+
dlq = streamInput.readMap(s -> s.readString(), s -> s.readString());
69+
} else {
70+
this.dlq = null;
71+
}
72+
}
5873
}
5974

6075
@Override
@@ -75,6 +90,14 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
7590
streamOutput.writeBoolean(false);
7691
}
7792
}
93+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
94+
if (dlq != null) {
95+
streamOutput.writeBoolean(true);
96+
streamOutput.writeMap(dlq, StreamOutput::writeString, StreamOutput::writeString);
97+
} else {
98+
streamOutput.writeBoolean(false);
99+
}
100+
}
78101
}
79102

80103
}

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
public class RemoteInferenceMLInput extends MLInput {
2424
public static final String PARAMETERS_FIELD = "parameters";
2525
public static final String ACTION_TYPE_FIELD = "action_type";
26+
public static final String DLQ_FIELD = "dlq";
2627

2728
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2829
super(in);
@@ -37,6 +38,7 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
3738
super();
3839
this.algorithm = functionName;
3940
Map<String, String> parameters = null;
41+
Map<String, String> dlq = null;
4042
ActionType actionType = null;
4143
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
4244
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -50,12 +52,15 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
5052
case ACTION_TYPE_FIELD:
5153
actionType = ActionType.from(parser.text());
5254
break;
55+
case DLQ_FIELD:
56+
dlq = StringUtils.getParameterMap(parser.map());
57+
break;
5358
default:
5459
parser.skipChildren();
5560
break;
5661
}
5762
}
58-
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
63+
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType, dlq);
5964
}
6065

6166
}

common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java

+15
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,34 @@
2727
public class MLTaskGetRequest extends ActionRequest {
2828
@Getter
2929
String taskId;
30+
3031
@Getter
3132
String tenantId;
3233

34+
// This is to identify if the get request is initiated by user or not. During batch task polling job,
35+
// we also perform get operation. This field is to distinguish between
36+
// these two situations.
37+
@Getter
38+
boolean isUserInitiatedGetTaskRequest;
39+
3340
@Builder
3441
public MLTaskGetRequest(String taskId, String tenantId) {
42+
this(taskId, tenantId, true);
43+
}
44+
45+
@Builder
46+
public MLTaskGetRequest(String taskId, String tenantId, Boolean isUserInitiatedGetTaskRequest) {
3547
this.taskId = taskId;
3648
this.tenantId = tenantId;
49+
this.isUserInitiatedGetTaskRequest = isUserInitiatedGetTaskRequest;
3750
}
3851

3952
public MLTaskGetRequest(StreamInput in) throws IOException {
4053
super(in);
4154
Version streamInputVersion = in.getVersion();
4255
this.taskId = in.readString();
4356
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
57+
this.isUserInitiatedGetTaskRequest = in.readBoolean();
4458
}
4559

4660
@Override
@@ -51,6 +65,7 @@ public void writeTo(StreamOutput out) throws IOException {
5165
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
5266
out.writeOptionalString(tenantId);
5367
}
68+
out.writeBoolean(isUserInitiatedGetTaskRequest);
5469
}
5570

5671
@Override

ml-algorithms/build.gradle

+5-3
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,11 @@ dependencies {
7272
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
7373
}
7474
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
75-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
76-
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
77-
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
75+
76+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
77+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
78+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
79+
7880
implementation 'com.jayway.jsonpath:json-path:2.9.0'
7981
implementation group: 'org.json', name: 'json', version: '20231013'
8082
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.29.12'

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+9-61
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,10 @@
2828
import org.opensearch.core.rest.RestStatus;
2929
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
3030
import org.opensearch.ml.engine.annotation.Ingester;
31-
32-
import com.google.common.annotations.VisibleForTesting;
31+
import org.opensearch.ml.engine.utils.S3Utils;
3332

3433
import lombok.extern.log4j.Log4j2;
35-
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
36-
import software.amazon.awssdk.auth.credentials.AwsCredentials;
37-
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
38-
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
3934
import software.amazon.awssdk.core.ResponseInputStream;
40-
import software.amazon.awssdk.regions.Region;
4135
import software.amazon.awssdk.services.s3.S3Client;
4236
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
4337
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
@@ -54,7 +48,12 @@ public S3DataIngestion(Client client) {
5448

5549
@Override
5650
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
57-
S3Client s3 = initS3Client(mlBatchIngestionInput);
51+
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
52+
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
53+
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
54+
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);
55+
56+
S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, region, sessionToken);
5857

5958
List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
6059
if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
@@ -77,8 +76,8 @@ public double ingestSingleSource(
7776
boolean isSoleSource,
7877
int bulkSize
7978
) {
80-
String bucketName = getS3BucketName(s3Uri);
81-
String keyName = getS3KeyName(s3Uri);
79+
String bucketName = S3Utils.getS3BucketName(s3Uri);
80+
String keyName = S3Utils.getS3KeyName(s3Uri);
8281
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
8382
double successRate = 0;
8483

@@ -153,55 +152,4 @@ public double ingestSingleSource(
153152

154153
return successRate;
155154
}
156-
157-
private String getS3BucketName(String s3Uri) {
158-
// Remove the "s3://" prefix
159-
String uriWithoutPrefix = s3Uri.substring(5);
160-
// Find the first slash after the bucket name
161-
int slashIndex = uriWithoutPrefix.indexOf('/');
162-
// If there is no slash, the entire remaining string is the bucket name
163-
if (slashIndex == -1) {
164-
return uriWithoutPrefix;
165-
}
166-
// Otherwise, the bucket name is the substring up to the first slash
167-
return uriWithoutPrefix.substring(0, slashIndex);
168-
}
169-
170-
private String getS3KeyName(String s3Uri) {
171-
String uriWithoutPrefix = s3Uri.substring(5);
172-
// Find the first slash after the bucket name
173-
int slashIndex = uriWithoutPrefix.indexOf('/');
174-
// If there is no slash, it means there is no key, return an empty string or handle as needed
175-
if (slashIndex == -1) {
176-
return "";
177-
}
178-
// The key name is the substring after the first slash
179-
return uriWithoutPrefix.substring(slashIndex + 1);
180-
}
181-
182-
@VisibleForTesting
183-
public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
184-
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
185-
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
186-
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
187-
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);
188-
189-
AwsCredentials credentials = sessionToken == null
190-
? AwsBasicCredentials.create(accessKey, secretKey)
191-
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
192-
193-
try {
194-
S3Client s3 = AccessController
195-
.doPrivileged(
196-
(PrivilegedExceptionAction<S3Client>) () -> S3Client
197-
.builder()
198-
.region(Region.of(region)) // Specify the region here
199-
.credentialsProvider(StaticCredentialsProvider.create(credentials))
200-
.build()
201-
);
202-
return s3;
203-
} catch (PrivilegedActionException e) {
204-
throw new RuntimeException("Can't load credentials", e);
205-
}
206-
}
207155
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.utils;
7+
8+
import java.security.AccessController;
9+
import java.security.PrivilegedActionException;
10+
import java.security.PrivilegedExceptionAction;
11+
12+
import com.google.common.annotations.VisibleForTesting;
13+
14+
import lombok.extern.log4j.Log4j2;
15+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
16+
import software.amazon.awssdk.auth.credentials.AwsCredentials;
17+
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
18+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
19+
import software.amazon.awssdk.core.sync.RequestBody;
20+
import software.amazon.awssdk.regions.Region;
21+
import software.amazon.awssdk.services.s3.S3Client;
22+
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
23+
24+
@Log4j2
25+
public class S3Utils {
26+
@VisibleForTesting
27+
public static S3Client initS3Client(String accessKey, String secretKey, String sessionToken, String region) {
28+
AwsCredentials credentials = sessionToken == null
29+
? AwsBasicCredentials.create(accessKey, secretKey)
30+
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
31+
32+
try {
33+
S3Client s3 = AccessController
34+
.doPrivileged(
35+
(PrivilegedExceptionAction<S3Client>) () -> S3Client
36+
.builder()
37+
.region(Region.of(region)) // Specify the region here
38+
.credentialsProvider(StaticCredentialsProvider.create(credentials))
39+
.build()
40+
);
41+
return s3;
42+
} catch (PrivilegedActionException e) {
43+
throw new RuntimeException("Can't load credentials", e);
44+
}
45+
}
46+
47+
public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
48+
try {
49+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
50+
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();
51+
52+
s3Client.putObject(request, RequestBody.fromString(content));
53+
log.debug("Successfully uploaded file to S3: s3://{}/{}", bucketName, key);
54+
return null; // Void return type for doPrivileged
55+
});
56+
} catch (PrivilegedActionException e) {
57+
throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
58+
}
59+
}
60+
61+
public static String getS3BucketName(String s3Uri) {
62+
// Remove the "s3://" prefix
63+
String uriWithoutPrefix = s3Uri.substring(5);
64+
// Find the first slash after the bucket name
65+
int slashIndex = uriWithoutPrefix.indexOf('/');
66+
// If there is no slash, the entire remaining string is the bucket name
67+
if (slashIndex == -1) {
68+
return uriWithoutPrefix;
69+
}
70+
// Otherwise, the bucket name is the substring up to the first slash
71+
return uriWithoutPrefix.substring(0, slashIndex);
72+
}
73+
74+
public static String getS3KeyName(String s3Uri) {
75+
String uriWithoutPrefix = s3Uri.substring(5);
76+
// Find the first slash after the bucket name
77+
int slashIndex = uriWithoutPrefix.indexOf('/');
78+
// If there is no slash, it means there is no key, return an empty string or handle as needed
79+
if (slashIndex == -1) {
80+
return "";
81+
}
82+
// The key name is the substring after the first slash
83+
return uriWithoutPrefix.substring(slashIndex + 1);
84+
}
85+
86+
}

0 commit comments

Comments
 (0)