Skip to content

Commit e537ea5

Browse files
committed
ml task management job changes
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent 4617dc3 commit e537ea5

23 files changed

+970
-108
lines changed

build.gradle

+3
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ subprojects {
7474
// Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
7575
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
7676
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
77+
resolutionStrategy.force "software.amazon.awssdk:aws-core:2.29.12"
78+
resolutionStrategy.force "software.amazon.awssdk:s3:2.29.12"
79+
resolutionStrategy.force "software.amazon.awssdk:regions:2.29.12"
7780
}
7881
}
7982

common/build.gradle

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ dependencies {
2323
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
2424
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2525

26+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
27+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
28+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
29+
2630
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
2731
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
2832
compileOnly group: 'org.json', name: 'json', version: '20231013'

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class CommonValue {
4444
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
4545
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
4646
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
47+
public static final String TASK_POLLING_JOB_INDEX_NAME = ".ml_commons_task_polling_job";
4748
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
4849

4950
// Index mapping paths

common/src/main/java/org/opensearch/ml/common/MLTaskState.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ public enum MLTaskState {
3030
CANCELLED,
3131
COMPLETED_WITH_ERROR,
3232
CANCELLING,
33-
EXPIRED
33+
EXPIRED,
34+
UNREACHABLE
3435
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@ public class RemoteInferenceInputDataSet extends MLInputDataset {
2929
private Map<String, String> parameters;
3030
@Setter
3131
private ActionType actionType;
32+
@Setter
33+
private Map<String, String> dlq;
3234

3335
@Builder(toBuilder = true)
34-
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
36+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType, Map<String, String> dlq) {
3537
super(MLInputDataType.REMOTE);
3638
this.parameters = parameters;
3739
this.actionType = actionType;
40+
this.dlq = dlq;
3841
}
3942

4043
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
41-
this(parameters, null);
44+
this(parameters, null, null);
4245
}
4346

4447
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
public class RemoteInferenceMLInput extends MLInput {
2424
public static final String PARAMETERS_FIELD = "parameters";
2525
public static final String ACTION_TYPE_FIELD = "action_type";
26+
public static final String DLQ_FIELD = "dlq";
2627

2728
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2829
super(in);
@@ -37,6 +38,7 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
3738
super();
3839
this.algorithm = functionName;
3940
Map<String, String> parameters = null;
41+
Map<String, String> dlq = null;
4042
ActionType actionType = null;
4143
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
4244
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -50,12 +52,15 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
5052
case ACTION_TYPE_FIELD:
5153
actionType = ActionType.from(parser.text());
5254
break;
55+
case DLQ_FIELD:
56+
dlq = StringUtils.getParameterMap(parser.map());
57+
break;
5358
default:
5459
parser.skipChildren();
5560
break;
5661
}
5762
}
58-
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
63+
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType, dlq);
5964
}
6065

6166
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import static org.opensearch.ml.common.connector.AbstractConnector.*;
9+
10+
import java.security.AccessController;
11+
import java.security.PrivilegedActionException;
12+
import java.security.PrivilegedExceptionAction;
13+
14+
import com.google.common.annotations.VisibleForTesting;
15+
16+
import lombok.extern.log4j.Log4j2;
17+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
18+
import software.amazon.awssdk.auth.credentials.AwsCredentials;
19+
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
20+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
21+
import software.amazon.awssdk.core.sync.RequestBody;
22+
import software.amazon.awssdk.regions.Region;
23+
import software.amazon.awssdk.services.s3.S3Client;
24+
import software.amazon.awssdk.services.s3.model.*;
25+
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
26+
27+
@Log4j2
28+
public class S3Utils {
29+
@VisibleForTesting
30+
public static S3Client initS3Client(String accessKey, String secretKey, String sessionToken, String region) {
31+
AwsCredentials credentials = sessionToken == null
32+
? AwsBasicCredentials.create(accessKey, secretKey)
33+
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
34+
35+
try {
36+
S3Client s3 = AccessController
37+
.doPrivileged(
38+
(PrivilegedExceptionAction<S3Client>) () -> S3Client
39+
.builder()
40+
.region(Region.of(region)) // Specify the region here
41+
.credentialsProvider(StaticCredentialsProvider.create(credentials))
42+
.build()
43+
);
44+
return s3;
45+
} catch (PrivilegedActionException e) {
46+
throw new RuntimeException("Can't load credentials", e);
47+
}
48+
}
49+
50+
/**
51+
* Constructor that takes a role ARN and assumes the role.
52+
*/
53+
// @VisibleForTesting
54+
// public static S3Client initS3Client(String roleArn, String region) {
55+
// try {
56+
// S3Client s3Client = AccessController.doPrivileged((PrivilegedExceptionAction<S3Client>) () -> {
57+
// // Use the default credentials provider to assume the role
58+
// StsClient stsClient = StsClient.builder().region(Region.of(region)).build();
59+
//
60+
// AssumeRoleRequest assumeRoleRequest = AssumeRoleRequest.builder().roleArn(roleArn).build();
61+
//
62+
// AssumeRoleResponse assumeRoleResponse = stsClient.assumeRole(assumeRoleRequest);
63+
// Credentials stsCredentials = assumeRoleResponse.credentials();
64+
//
65+
// AwsSessionCredentials sessionCredentials = AwsSessionCredentials
66+
// .create(stsCredentials.accessKeyId(), stsCredentials.secretAccessKey(), stsCredentials.sessionToken());
67+
//
68+
// return S3Client
69+
// .builder()
70+
// .region(Region.of(region))
71+
// .credentialsProvider(StaticCredentialsProvider.create(sessionCredentials))
72+
// .build();
73+
// });
74+
//
75+
// return s3Client;
76+
// } catch (PrivilegedActionException e) {
77+
// throw new RuntimeException("Failed to assume role and initialize S3 client", e);
78+
// }
79+
// }
80+
81+
public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
82+
try {
83+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
84+
PutObjectRequest request = PutObjectRequest.builder()
85+
.bucket(bucketName)
86+
.key(key)
87+
.build();
88+
89+
s3Client.putObject(request, RequestBody.fromString(content));
90+
log.debug("Successfully uploaded file to S3: s3://{}/{}", bucketName, key);
91+
return null; // Void return type for doPrivileged
92+
});
93+
} catch (PrivilegedActionException e) {
94+
throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
95+
}
96+
}
97+
98+
public static String getS3BucketName(String s3Uri) {
99+
// Remove the "s3://" prefix
100+
String uriWithoutPrefix = s3Uri.substring(5);
101+
// Find the first slash after the bucket name
102+
int slashIndex = uriWithoutPrefix.indexOf('/');
103+
// If there is no slash, the entire remaining string is the bucket name
104+
if (slashIndex == -1) {
105+
return uriWithoutPrefix;
106+
}
107+
// Otherwise, the bucket name is the substring up to the first slash
108+
return uriWithoutPrefix.substring(0, slashIndex);
109+
}
110+
111+
public static String getS3KeyName(String s3Uri) {
112+
String uriWithoutPrefix = s3Uri.substring(5);
113+
// Find the first slash after the bucket name
114+
int slashIndex = uriWithoutPrefix.indexOf('/');
115+
// If there is no slash, it means there is no key, return an empty string or handle as needed
116+
if (slashIndex == -1) {
117+
return "";
118+
}
119+
// The key name is the substring after the first slash
120+
return uriWithoutPrefix.substring(slashIndex + 1);
121+
}
122+
123+
}

ml-algorithms/build.gradle

+4-3
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,10 @@ dependencies {
6969
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
7070
}
7171
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
72-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
73-
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
74-
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
72+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
73+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
74+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
75+
7576
implementation 'com.jayway.jsonpath:json-path:2.9.0'
7677
implementation group: 'org.json', name: 'json', version: '20231013'
7778
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.29.12'

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+12-66
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@
55

66
package org.opensearch.ml.engine.ingest;
77

8-
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
9-
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
10-
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
11-
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
12-
138
import java.io.BufferedReader;
149
import java.io.InputStreamReader;
1510
import java.nio.charset.StandardCharsets;
@@ -28,21 +23,18 @@
2823
import org.opensearch.core.rest.RestStatus;
2924
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
3025
import org.opensearch.ml.engine.annotation.Ingester;
31-
32-
import com.google.common.annotations.VisibleForTesting;
26+
import org.opensearch.ml.common.utils.S3Utils;
3327

3428
import lombok.extern.log4j.Log4j2;
35-
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
36-
import software.amazon.awssdk.auth.credentials.AwsCredentials;
37-
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
38-
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
3929
import software.amazon.awssdk.core.ResponseInputStream;
40-
import software.amazon.awssdk.regions.Region;
4130
import software.amazon.awssdk.services.s3.S3Client;
4231
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
4332
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
4433
import software.amazon.awssdk.services.s3.model.S3Exception;
4534

35+
import static org.opensearch.ml.common.connector.AbstractConnector.*;
36+
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
37+
4638
@Log4j2
4739
@Ingester("s3")
4840
public class S3DataIngestion extends AbstractIngestion {
@@ -54,7 +46,12 @@ public S3DataIngestion(Client client) {
5446

5547
@Override
5648
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
57-
S3Client s3 = initS3Client(mlBatchIngestionInput);
49+
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
50+
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
51+
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
52+
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);
53+
54+
S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, region, sessionToken);
5855

5956
List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
6057
if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
@@ -77,8 +74,8 @@ public double ingestSingleSource(
7774
boolean isSoleSource,
7875
int bulkSize
7976
) {
80-
String bucketName = getS3BucketName(s3Uri);
81-
String keyName = getS3KeyName(s3Uri);
77+
String bucketName = S3Utils.getS3BucketName(s3Uri);
78+
String keyName = S3Utils.getS3KeyName(s3Uri);
8279
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
8380
double successRate = 0;
8481

@@ -153,55 +150,4 @@ public double ingestSingleSource(
153150

154151
return successRate;
155152
}
156-
157-
private String getS3BucketName(String s3Uri) {
158-
// Remove the "s3://" prefix
159-
String uriWithoutPrefix = s3Uri.substring(5);
160-
// Find the first slash after the bucket name
161-
int slashIndex = uriWithoutPrefix.indexOf('/');
162-
// If there is no slash, the entire remaining string is the bucket name
163-
if (slashIndex == -1) {
164-
return uriWithoutPrefix;
165-
}
166-
// Otherwise, the bucket name is the substring up to the first slash
167-
return uriWithoutPrefix.substring(0, slashIndex);
168-
}
169-
170-
private String getS3KeyName(String s3Uri) {
171-
String uriWithoutPrefix = s3Uri.substring(5);
172-
// Find the first slash after the bucket name
173-
int slashIndex = uriWithoutPrefix.indexOf('/');
174-
// If there is no slash, it means there is no key, return an empty string or handle as needed
175-
if (slashIndex == -1) {
176-
return "";
177-
}
178-
// The key name is the substring after the first slash
179-
return uriWithoutPrefix.substring(slashIndex + 1);
180-
}
181-
182-
@VisibleForTesting
183-
public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
184-
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
185-
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
186-
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
187-
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);
188-
189-
AwsCredentials credentials = sessionToken == null
190-
? AwsBasicCredentials.create(accessKey, secretKey)
191-
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
192-
193-
try {
194-
S3Client s3 = AccessController
195-
.doPrivileged(
196-
(PrivilegedExceptionAction<S3Client>) () -> S3Client
197-
.builder()
198-
.region(Region.of(region)) // Specify the region here
199-
.credentialsProvider(StaticCredentialsProvider.create(credentials))
200-
.build()
201-
);
202-
return s3;
203-
} catch (PrivilegedActionException e) {
204-
throw new RuntimeException("Can't load credentials", e);
205-
}
206-
}
207153
}

0 commit comments

Comments
 (0)