Skip to content

Commit 0d26931

Browse files
authored
add bedrock batch job post process function; enhance remote job status parsing (opensearch-project#2955)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 091f5df commit 0d26931

File tree

12 files changed

+344
-30
lines changed

12 files changed

+344
-30
lines changed

common/src/main/java/org/opensearch/ml/common/MLTaskState.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,7 @@ public enum MLTaskState {
2828
COMPLETED,
2929
FAILED,
3030
CANCELLED,
31-
COMPLETED_WITH_ERROR
31+
COMPLETED_WITH_ERROR,
32+
CANCELLING,
33+
EXPIRED
3234
}

common/src/main/java/org/opensearch/ml/common/MLTaskType.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
public enum MLTaskType {
99
TRAINING,
1010
PREDICTION,
11-
BATCH_PREDICTION,
1211
TRAINING_AND_PREDICTION,
1312
EXECUTION,
1413
@Deprecated
@@ -17,5 +16,6 @@ public enum MLTaskType {
1716
LOAD_MODEL,
1817
REGISTER_MODEL,
1918
DEPLOY_MODEL,
20-
BATCH_INGEST
19+
BATCH_INGEST,
20+
BATCH_PREDICTION
2121
}

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

+5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Map;
1111
import java.util.function.Function;
1212

13+
import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction;
1314
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
1415
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
1516
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
@@ -20,6 +21,7 @@ public class MLPostProcessFunction {
2021
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
2122
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
2223
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
24+
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
2325
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
2426
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
2527
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
@@ -31,17 +33,20 @@ public class MLPostProcessFunction {
3133
static {
3234
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
3335
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
36+
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
3437
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
3538
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
3639
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
3740
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
3841
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
42+
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
3943
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
4044
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
4145
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
4246
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
4347
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
4448
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
49+
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
4550
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
4651
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
4752
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
import org.opensearch.ml.common.output.model.ModelTensor;
14+
15+
public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<Map<String, String>> {
16+
public static final String JOB_ARN = "jobArn";
17+
public static final String PROCESSED_JOB_ARN = "processedJobArn";
18+
19+
@Override
20+
public void validate(Object input) {
21+
if (!(input instanceof Map)) {
22+
throw new IllegalArgumentException("Post process function input is not a Map.");
23+
}
24+
Map<String, String> jobInfo = (Map<String, String>) input;
25+
if (!(jobInfo.containsKey(JOB_ARN))) {
26+
throw new IllegalArgumentException("job arn is missing.");
27+
}
28+
}
29+
30+
@Override
31+
public List<ModelTensor> process(Map<String, String> jobInfo) {
32+
List<ModelTensor> modelTensors = new ArrayList<>();
33+
Map<String, String> processedResult = new HashMap<>();
34+
processedResult.putAll(jobInfo);
35+
String jobArn = jobInfo.get(JOB_ARN);
36+
processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F"));
37+
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(processedResult).build());
38+
return modelTensors;
39+
}
40+
}

common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
1111
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
12-
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";
12+
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel";
1313

1414
private MLCancelBatchJobAction() {
1515
super(NAME, MLCancelBatchJobResponse::new);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction.JOB_ARN;
10+
import static org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction.PROCESSED_JOB_ARN;
11+
12+
import java.util.List;
13+
import java.util.Map;
14+
15+
import org.junit.Before;
16+
import org.junit.Rule;
17+
import org.junit.Test;
18+
import org.junit.rules.ExpectedException;
19+
import org.opensearch.ml.common.output.model.ModelTensor;
20+
21+
public class BedrockBatchJobArnPostProcessFunctionTest {
22+
23+
@Rule
24+
public ExpectedException exceptionRule = ExpectedException.none();
25+
26+
BedrockBatchJobArnPostProcessFunction function;
27+
28+
@Before
29+
public void setUp() {
30+
function = new BedrockBatchJobArnPostProcessFunction();
31+
}
32+
33+
@Test
34+
public void process_WrongInput_NotMap() {
35+
exceptionRule.expect(IllegalArgumentException.class);
36+
exceptionRule.expectMessage("Post process function input is not a Map.");
37+
function.apply("abc");
38+
}
39+
40+
@Test
41+
public void process_WrongInput_NotContainJobArn() {
42+
exceptionRule.expect(IllegalArgumentException.class);
43+
exceptionRule.expectMessage("job arn is missing.");
44+
function.apply(Map.of("test", "value"));
45+
}
46+
47+
@Test
48+
public void process_CorrectInput() {
49+
String jobArn = "arn:aws:bedrock:us-east-1:12345678912:model-invocation-job/w1xtlm0ik3e1";
50+
List<ModelTensor> result = function.apply(Map.of(JOB_ARN, jobArn));
51+
assertEquals(1, result.size());
52+
assertEquals(jobArn, result.get(0).getDataAsMap().get(JOB_ARN));
53+
assertEquals(
54+
"arn:aws:bedrock:us-east-1:12345678912:model-invocation-job%2Fw1xtlm0ik3e1",
55+
result.get(0).getDataAsMap().get(PROCESSED_JOB_ARN)
56+
);
57+
}
58+
}

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

+88-13
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,25 @@
1111
import static org.opensearch.ml.common.MLTask.REMOTE_JOB_FIELD;
1212
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1313
import static org.opensearch.ml.common.MLTaskState.CANCELLED;
14+
import static org.opensearch.ml.common.MLTaskState.CANCELLING;
1415
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
16+
import static org.opensearch.ml.common.MLTaskState.EXPIRED;
1517
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
18+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
19+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
20+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
21+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
22+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
1623
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
1724
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
1825

1926
import java.util.HashMap;
27+
import java.util.List;
2028
import java.util.Map;
2129
import java.util.Optional;
30+
import java.util.function.Consumer;
31+
import java.util.regex.Matcher;
32+
import java.util.regex.Pattern;
2233

2334
import org.opensearch.OpenSearchException;
2435
import org.opensearch.OpenSearchStatusException;
@@ -30,6 +41,8 @@
3041
import org.opensearch.client.Client;
3142
import org.opensearch.cluster.service.ClusterService;
3243
import org.opensearch.common.inject.Inject;
44+
import org.opensearch.common.settings.Setting;
45+
import org.opensearch.common.settings.Settings;
3346
import org.opensearch.common.util.concurrent.ThreadContext;
3447
import org.opensearch.core.action.ActionListener;
3548
import org.opensearch.core.rest.RestStatus;
@@ -80,6 +93,12 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
8093
MLTaskManager mlTaskManager;
8194
MLModelCacheHelper modelCacheHelper;
8295

96+
volatile List<String> remoteJobStatusFields;
97+
volatile Pattern remoteJobCompletedStatusRegexPattern;
98+
volatile Pattern remoteJobCancelledStatusRegexPattern;
99+
volatile Pattern remoteJobCancellingStatusRegexPattern;
100+
volatile Pattern remoteJobExpiredStatusRegexPattern;
101+
83102
@Inject
84103
public GetTaskTransportAction(
85104
TransportService transportService,
@@ -91,7 +110,8 @@ public GetTaskTransportAction(
91110
ConnectorAccessControlHelper connectorAccessControlHelper,
92111
EncryptorImpl encryptor,
93112
MLTaskManager mlTaskManager,
94-
MLModelManager mlModelManager
113+
MLModelManager mlModelManager,
114+
Settings settings
95115
) {
96116
super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new);
97117
this.client = client;
@@ -102,6 +122,44 @@ public GetTaskTransportAction(
102122
this.encryptor = encryptor;
103123
this.mlTaskManager = mlTaskManager;
104124
this.mlModelManager = mlModelManager;
125+
126+
remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings);
127+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it);
128+
initializeRegexPattern(
129+
ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX,
130+
settings,
131+
clusterService,
132+
(regex) -> remoteJobCompletedStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
133+
);
134+
initializeRegexPattern(
135+
ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
136+
settings,
137+
clusterService,
138+
(regex) -> remoteJobCancelledStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
139+
);
140+
initializeRegexPattern(
141+
ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
142+
settings,
143+
clusterService,
144+
(regex) -> remoteJobCancellingStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
145+
);
146+
initializeRegexPattern(
147+
ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
148+
settings,
149+
clusterService,
150+
(regex) -> remoteJobExpiredStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)
151+
);
152+
}
153+
154+
private void initializeRegexPattern(
155+
Setting<String> setting,
156+
Settings settings,
157+
ClusterService clusterService,
158+
Consumer<String> patternInitializer
159+
) {
160+
String regex = setting.get(settings);
161+
patternInitializer.accept(regex);
162+
clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> patternInitializer.accept(it));
105163
}
106164

107165
@Override
@@ -210,7 +268,7 @@ private void executeConnector(
210268
MLInput mlInput,
211269
String taskId,
212270
MLTask mlTask,
213-
Map<String, Object> transformJob,
271+
Map<String, Object> remoteJob,
214272
ActionListener<MLTaskGetResponse> actionListener
215273
) {
216274
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
@@ -222,15 +280,15 @@ private void executeConnector(
222280
connectorExecutor.setClient(client);
223281
connectorExecutor.setXContentRegistry(xContentRegistry);
224282
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
225-
processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener);
283+
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
226284
}, e -> { actionListener.onFailure(e); }));
227285
} else {
228286
actionListener
229287
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
230288
}
231289
}
232290

233-
private void processTaskResponse(
291+
protected void processTaskResponse(
234292
MLTask mlTask,
235293
String taskId,
236294
MLTaskResponse taskResponse,
@@ -248,15 +306,11 @@ private void processTaskResponse(
248306
Map<String, Object> updatedTask = new HashMap<>();
249307
updatedTask.put(REMOTE_JOB_FIELD, remoteJob);
250308

251-
if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("completed"))
252-
|| (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Completed"))) {
253-
updatedTask.put(STATE_FIELD, COMPLETED);
254-
mlTask.setState(COMPLETED);
255-
256-
} else if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("cancelled"))
257-
|| (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Stopped"))) {
258-
updatedTask.put(STATE_FIELD, CANCELLED);
259-
mlTask.setState(CANCELLED);
309+
for (String statusField : remoteJobStatusFields) {
310+
String statusValue = String.valueOf(remoteJob.get(statusField));
311+
if (remoteJob.containsKey(statusField)) {
312+
updateTaskState(updatedTask, mlTask, statusValue);
313+
}
260314
}
261315
mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> {
262316
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
@@ -280,4 +334,25 @@ private void processTaskResponse(
280334
log.error("Unable to fetch status for ml task ", e);
281335
}
282336
}
337+
338+
private void updateTaskState(Map<String, Object> updatedTask, MLTask mlTask, String statusValue) {
339+
if (matchesPattern(remoteJobCancellingStatusRegexPattern, statusValue)) {
340+
updatedTask.put(STATE_FIELD, CANCELLING);
341+
mlTask.setState(CANCELLING);
342+
} else if (matchesPattern(remoteJobCancelledStatusRegexPattern, statusValue)) {
343+
updatedTask.put(STATE_FIELD, CANCELLED);
344+
mlTask.setState(CANCELLED);
345+
} else if (matchesPattern(remoteJobCompletedStatusRegexPattern, statusValue)) {
346+
updatedTask.put(STATE_FIELD, COMPLETED);
347+
mlTask.setState(COMPLETED);
348+
} else if (matchesPattern(remoteJobExpiredStatusRegexPattern, statusValue)) {
349+
updatedTask.put(STATE_FIELD, EXPIRED);
350+
mlTask.setState(EXPIRED);
351+
}
352+
}
353+
354+
private boolean matchesPattern(Pattern pattern, String input) {
355+
Matcher matcher = pattern.matcher(input);
356+
return matcher.find();
357+
}
283358
}

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,12 @@ public List<Setting<?>> getSettings() {
964964
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
965965
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
966966
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
967-
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED
967+
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED,
968+
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD,
969+
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX,
970+
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
971+
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
972+
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX
968973
);
969974
return settings;
970975
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java

+4-8
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
import com.google.common.annotations.VisibleForTesting;
2424
import com.google.common.collect.ImmutableList;
2525

26+
//TODO: Rename class and support cancelling more tasks. Now only support cancelling remote job
2627
public class RestMLCancelBatchJobAction extends BaseRestHandler {
27-
private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action";
28+
private static final String ML_CANCEL_TASK_ACTION = "ml_cancel_task_action";
2829

2930
/**
3031
* Constructor
@@ -33,18 +34,13 @@ public RestMLCancelBatchJobAction() {}
3334

3435
@Override
3536
public String getName() {
36-
return ML_CANCEL_BATCH_ACTION;
37+
return ML_CANCEL_TASK_ACTION;
3738
}
3839

3940
@Override
4041
public List<Route> routes() {
4142
return ImmutableList
42-
.of(
43-
new Route(
44-
RestRequest.Method.POST,
45-
String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID)
46-
)
47-
);
43+
.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel", ML_BASE_URI, PARAMETER_TASK_ID)));
4844
}
4945

5046
@Override

0 commit comments

Comments
 (0)