Skip to content

Commit f45c18f

Browse files
committed
support batch task management by periodically polling the remote task via a cron job
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent 1659a60 commit f45c18f

23 files changed

+828
-112
lines changed

build.gradle

+3
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ subprojects {
9595
// Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
9696
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
9797
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
98+
resolutionStrategy.force "software.amazon.awssdk:aws-core:2.29.12"
99+
resolutionStrategy.force "software.amazon.awssdk:s3:2.29.12"
100+
resolutionStrategy.force "software.amazon.awssdk:regions:2.29.12"
98101
}
99102
}
100103

common/build.gradle

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ dependencies {
2323
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
2424
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2525

26+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
27+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
28+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
29+
2630
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
2731
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
2832
compileOnly group: 'org.json', name: 'json', version: '20231013'

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class CommonValue {
4444
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
4545
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
4646
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
47+
public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job";
4748
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
4849

4950
// Index mapping paths

common/src/main/java/org/opensearch/ml/common/MLTaskState.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ public enum MLTaskState {
3030
CANCELLED,
3131
COMPLETED_WITH_ERROR,
3232
CANCELLING,
33-
EXPIRED
33+
EXPIRED,
34+
UNREACHABLE
3435
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+25-2
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,28 @@
2525
@InputDataSet(MLInputDataType.REMOTE)
2626
public class RemoteInferenceInputDataSet extends MLInputDataset {
2727
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
28+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG = CommonValue.VERSION_2_19_0;
2829
@Setter
2930
private Map<String, String> parameters;
3031
@Setter
3132
private ActionType actionType;
33+
@Setter
34+
private Map<String, String> dlq;
3235

3336
@Builder(toBuilder = true)
34-
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
37+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType, Map<String, String> dlq) {
3538
super(MLInputDataType.REMOTE);
3639
this.parameters = parameters;
3740
this.actionType = actionType;
41+
this.dlq = dlq;
42+
}
43+
44+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
45+
this(parameters, actionType, null);
3846
}
3947

4048
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
41-
this(parameters, null);
49+
this(parameters, null, null);
4250
}
4351

4452
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
@@ -54,6 +62,13 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
5462
this.actionType = null;
5563
}
5664
}
65+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
66+
if (streamInput.readBoolean()) {
67+
dlq = streamInput.readMap(s -> s.readString(), s -> s.readString());
68+
} else {
69+
this.dlq = null;
70+
}
71+
}
5772
}
5873

5974
@Override
@@ -74,6 +89,14 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
7489
streamOutput.writeBoolean(false);
7590
}
7691
}
92+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
93+
if (dlq != null) {
94+
streamOutput.writeBoolean(true);
95+
streamOutput.writeMap(dlq, StreamOutput::writeString, StreamOutput::writeString);
96+
} else {
97+
streamOutput.writeBoolean(false);
98+
}
99+
}
77100
}
78101

79102
}

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
public class RemoteInferenceMLInput extends MLInput {
2424
public static final String PARAMETERS_FIELD = "parameters";
2525
public static final String ACTION_TYPE_FIELD = "action_type";
26+
public static final String DLQ_FIELD = "dlq";
2627

2728
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2829
super(in);
@@ -37,6 +38,7 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
3738
super();
3839
this.algorithm = functionName;
3940
Map<String, String> parameters = null;
41+
Map<String, String> dlq = null;
4042
ActionType actionType = null;
4143
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
4244
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -50,12 +52,15 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
5052
case ACTION_TYPE_FIELD:
5153
actionType = ActionType.from(parser.text());
5254
break;
55+
case DLQ_FIELD:
56+
dlq = StringUtils.getParameterMap(parser.map());
57+
break;
5358
default:
5459
parser.skipChildren();
5560
break;
5661
}
5762
}
58-
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
63+
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType, dlq);
5964
}
6065

6166
}

common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java

+15
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,34 @@
2727
public class MLTaskGetRequest extends ActionRequest {
2828
@Getter
2929
String taskId;
30+
3031
@Getter
3132
String tenantId;
3233

34+
// This is to identify if the get request is initiated by user or not. During batch task polling job,
35+
// we also perform get operation. This field is to distinguish between
36+
// these two situations.
37+
@Getter
38+
boolean isUserInitiatedGetTaskRequest;
39+
3340
@Builder
3441
public MLTaskGetRequest(String taskId, String tenantId) {
42+
this(taskId, tenantId, true);
43+
}
44+
45+
/**
 * Creates a get-task request, recording whether a user issued it.
 *
 * @param taskId id of the task to fetch
 * @param tenantId tenant id associated with the request; may be null
 * @param isUserInitiatedGetTaskRequest whether a user issued the request; a null value is
 *        treated as {@code true}, the same default the two-argument constructor applies
 */
@Builder
public MLTaskGetRequest(String taskId, String tenantId, Boolean isUserInitiatedGetTaskRequest) {
    this.taskId = taskId;
    this.tenantId = tenantId;
    // The field is a primitive boolean: assigning a null Boolean directly would throw on unboxing.
    this.isUserInitiatedGetTaskRequest = isUserInitiatedGetTaskRequest == null || isUserInitiatedGetTaskRequest;
}
3851

3952
public MLTaskGetRequest(StreamInput in) throws IOException {
4053
super(in);
4154
Version streamInputVersion = in.getVersion();
4255
this.taskId = in.readString();
4356
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
57+
this.isUserInitiatedGetTaskRequest = in.readBoolean();
4458
}
4559

4660
@Override
@@ -51,6 +65,7 @@ public void writeTo(StreamOutput out) throws IOException {
5165
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
5266
out.writeOptionalString(tenantId);
5367
}
68+
// Only emit the flag to receivers new enough to expect it; an older node would
// otherwise find an unexpected trailing byte in the request.
// NOTE(review): gate assumes the flag first ships in 2.19.0, like tenantId - confirm.
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
    out.writeBoolean(isUserInitiatedGetTaskRequest);
}
5469
}
5570

5671
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import static org.opensearch.ml.common.connector.AbstractConnector.*;
9+
10+
import java.security.AccessController;
11+
import java.security.PrivilegedActionException;
12+
import java.security.PrivilegedExceptionAction;
13+
14+
import com.google.common.annotations.VisibleForTesting;
15+
16+
import lombok.extern.log4j.Log4j2;
17+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
18+
import software.amazon.awssdk.auth.credentials.AwsCredentials;
19+
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
20+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
21+
import software.amazon.awssdk.core.sync.RequestBody;
22+
import software.amazon.awssdk.regions.Region;
23+
import software.amazon.awssdk.services.s3.S3Client;
24+
import software.amazon.awssdk.services.s3.model.*;
25+
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
26+
27+
@Log4j2
28+
public class S3Utils {
29+
@VisibleForTesting
30+
public static S3Client initS3Client(String accessKey, String secretKey, String sessionToken, String region) {
31+
AwsCredentials credentials = sessionToken == null
32+
? AwsBasicCredentials.create(accessKey, secretKey)
33+
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
34+
35+
try {
36+
S3Client s3 = AccessController
37+
.doPrivileged(
38+
(PrivilegedExceptionAction<S3Client>) () -> S3Client
39+
.builder()
40+
.region(Region.of(region)) // Specify the region here
41+
.credentialsProvider(StaticCredentialsProvider.create(credentials))
42+
.build()
43+
);
44+
return s3;
45+
} catch (PrivilegedActionException e) {
46+
throw new RuntimeException("Can't load credentials", e);
47+
}
48+
}
49+
50+
public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
51+
try {
52+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
53+
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();
54+
55+
s3Client.putObject(request, RequestBody.fromString(content));
56+
log.debug("Successfully uploaded file to S3: s3://{}/{}", bucketName, key);
57+
return null; // Void return type for doPrivileged
58+
});
59+
} catch (PrivilegedActionException e) {
60+
throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
61+
}
62+
}
63+
64+
public static String getS3BucketName(String s3Uri) {
65+
// Remove the "s3://" prefix
66+
String uriWithoutPrefix = s3Uri.substring(5);
67+
// Find the first slash after the bucket name
68+
int slashIndex = uriWithoutPrefix.indexOf('/');
69+
// If there is no slash, the entire remaining string is the bucket name
70+
if (slashIndex == -1) {
71+
return uriWithoutPrefix;
72+
}
73+
// Otherwise, the bucket name is the substring up to the first slash
74+
return uriWithoutPrefix.substring(0, slashIndex);
75+
}
76+
77+
public static String getS3KeyName(String s3Uri) {
78+
String uriWithoutPrefix = s3Uri.substring(5);
79+
// Find the first slash after the bucket name
80+
int slashIndex = uriWithoutPrefix.indexOf('/');
81+
// If there is no slash, it means there is no key, return an empty string or handle as needed
82+
if (slashIndex == -1) {
83+
return "";
84+
}
85+
// The key name is the substring after the first slash
86+
return uriWithoutPrefix.substring(slashIndex + 1);
87+
}
88+
89+
}

ml-algorithms/build.gradle

+4-3
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,10 @@ dependencies {
6969
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
7070
}
7171
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
72-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
73-
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
74-
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
72+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
73+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
74+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
75+
7576
implementation 'com.jayway.jsonpath:json-path:2.9.0'
7677
implementation group: 'org.json', name: 'json', version: '20231013'
7778
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.29.12'

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+10-64
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
package org.opensearch.ml.engine.ingest;
77

8-
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
9-
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
10-
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
8+
import static org.opensearch.ml.common.connector.AbstractConnector.*;
119
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
1210

1311
import java.io.BufferedReader;
@@ -27,17 +25,11 @@
2725
import org.opensearch.client.Client;
2826
import org.opensearch.core.rest.RestStatus;
2927
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
28+
import org.opensearch.ml.common.utils.S3Utils;
3029
import org.opensearch.ml.engine.annotation.Ingester;
3130

32-
import com.google.common.annotations.VisibleForTesting;
33-
3431
import lombok.extern.log4j.Log4j2;
35-
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
36-
import software.amazon.awssdk.auth.credentials.AwsCredentials;
37-
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
38-
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
3932
import software.amazon.awssdk.core.ResponseInputStream;
40-
import software.amazon.awssdk.regions.Region;
4133
import software.amazon.awssdk.services.s3.S3Client;
4234
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
4335
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
@@ -54,7 +46,12 @@ public S3DataIngestion(Client client) {
5446

5547
@Override
5648
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
57-
S3Client s3 = initS3Client(mlBatchIngestionInput);
49+
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
50+
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
51+
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
52+
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);
53+
54+
// S3Utils.initS3Client is declared as (accessKey, secretKey, sessionToken, region);
// pass the session token before the region so the two are not swapped.
S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region);
5855

5956
List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
6057
if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
@@ -77,8 +74,8 @@ public double ingestSingleSource(
7774
boolean isSoleSource,
7875
int bulkSize
7976
) {
80-
String bucketName = getS3BucketName(s3Uri);
81-
String keyName = getS3KeyName(s3Uri);
77+
String bucketName = S3Utils.getS3BucketName(s3Uri);
78+
String keyName = S3Utils.getS3KeyName(s3Uri);
8279
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
8380
double successRate = 0;
8481

@@ -153,55 +150,4 @@ public double ingestSingleSource(
153150

154151
return successRate;
155152
}
156-
157-
private String getS3BucketName(String s3Uri) {
158-
// Remove the "s3://" prefix
159-
String uriWithoutPrefix = s3Uri.substring(5);
160-
// Find the first slash after the bucket name
161-
int slashIndex = uriWithoutPrefix.indexOf('/');
162-
// If there is no slash, the entire remaining string is the bucket name
163-
if (slashIndex == -1) {
164-
return uriWithoutPrefix;
165-
}
166-
// Otherwise, the bucket name is the substring up to the first slash
167-
return uriWithoutPrefix.substring(0, slashIndex);
168-
}
169-
170-
private String getS3KeyName(String s3Uri) {
171-
String uriWithoutPrefix = s3Uri.substring(5);
172-
// Find the first slash after the bucket name
173-
int slashIndex = uriWithoutPrefix.indexOf('/');
174-
// If there is no slash, it means there is no key, return an empty string or handle as needed
175-
if (slashIndex == -1) {
176-
return "";
177-
}
178-
// The key name is the substring after the first slash
179-
return uriWithoutPrefix.substring(slashIndex + 1);
180-
}
181-
182-
@VisibleForTesting
183-
public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
184-
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
185-
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
186-
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
187-
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);
188-
189-
AwsCredentials credentials = sessionToken == null
190-
? AwsBasicCredentials.create(accessKey, secretKey)
191-
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
192-
193-
try {
194-
S3Client s3 = AccessController
195-
.doPrivileged(
196-
(PrivilegedExceptionAction<S3Client>) () -> S3Client
197-
.builder()
198-
.region(Region.of(region)) // Specify the region here
199-
.credentialsProvider(StaticCredentialsProvider.create(credentials))
200-
.build()
201-
);
202-
return s3;
203-
} catch (PrivilegedActionException e) {
204-
throw new RuntimeException("Can't load credentials", e);
205-
}
206-
}
207153
}

0 commit comments

Comments
 (0)