Skip to content

Commit 224f8fc

Browse files
support get batch transform job status in get task API (opensearch-project#2825) (opensearch-project#2893)
* support get batch transform job status in get task API Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> * add cancel batch prediction job API for offline inference Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> * add unit tests and address comments Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> * stash context for get model Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> * apply spotlessJava and exclude from test coverage Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> --------- Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> (cherry picked from commit 8da7bd2) Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent a91843c commit 224f8fc

22 files changed

+1650
-18
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public class CommonValue {
6666
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
6767
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
6868
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
69-
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
69+
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3;
7070
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
7171
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
7272
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3;
@@ -393,6 +393,10 @@ public class CommonValue {
393393
+ "\" : {\"type\" : \"boolean\"}, \n"
394394
+ USER_FIELD_MAPPING
395395
+ " }\n"
396+
+ "}"
397+
+ MLTask.REMOTE_JOB_FIELD
398+
+ "\" : {\"type\": \"flat_object\"}\n"
399+
+ " }\n"
396400
+ "}";
397401

398402
public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n"

common/src/main/java/org/opensearch/ml/common/MLTask.java

+32-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
import java.util.ArrayList;
1414
import java.util.Arrays;
1515
import java.util.List;
16+
import java.util.Map;
1617

18+
import org.opensearch.Version;
1719
import org.opensearch.commons.authuser.User;
1820
import org.opensearch.core.common.io.stream.StreamInput;
1921
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -45,6 +47,8 @@ public class MLTask implements ToXContentObject, Writeable {
4547
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
4648
public static final String ERROR_FIELD = "error";
4749
public static final String IS_ASYNC_TASK_FIELD = "is_async";
50+
public static final String REMOTE_JOB_FIELD = "remote_job";
51+
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB = CommonValue.VERSION_2_17_0;
4852

4953
@Setter
5054
private String taskId;
@@ -66,6 +70,8 @@ public class MLTask implements ToXContentObject, Writeable {
6670
private String error;
6771
private User user; // TODO: support document level access control later
6872
private boolean async;
73+
@Setter
74+
private Map<String, Object> remoteJob;
6975

7076
@Builder(toBuilder = true)
7177
public MLTask(
@@ -82,7 +88,8 @@ public MLTask(
8288
Instant lastUpdateTime,
8389
String error,
8490
User user,
85-
boolean async
91+
boolean async,
92+
Map<String, Object> remoteJob
8693
) {
8794
this.taskId = taskId;
8895
this.modelId = modelId;
@@ -98,9 +105,11 @@ public MLTask(
98105
this.error = error;
99106
this.user = user;
100107
this.async = async;
108+
this.remoteJob = remoteJob;
101109
}
102110

103111
public MLTask(StreamInput input) throws IOException {
112+
Version streamInputVersion = input.getVersion();
104113
this.taskId = input.readOptionalString();
105114
this.modelId = input.readOptionalString();
106115
this.taskType = input.readEnum(MLTaskType.class);
@@ -123,10 +132,16 @@ public MLTask(StreamInput input) throws IOException {
123132
this.user = null;
124133
}
125134
this.async = input.readBoolean();
135+
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) {
136+
if (input.readBoolean()) {
137+
this.remoteJob = input.readMap(s -> s.readString(), s -> s.readGenericValue());
138+
}
139+
}
126140
}
127141

128142
@Override
129143
public void writeTo(StreamOutput out) throws IOException {
144+
Version streamOutputVersion = out.getVersion();
130145
out.writeOptionalString(taskId);
131146
out.writeOptionalString(modelId);
132147
out.writeEnum(taskType);
@@ -150,6 +165,14 @@ public void writeTo(StreamOutput out) throws IOException {
150165
out.writeBoolean(false);
151166
}
152167
out.writeBoolean(async);
168+
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB)) {
169+
if (remoteJob != null) {
170+
out.writeBoolean(true);
171+
out.writeMap(remoteJob, StreamOutput::writeString, StreamOutput::writeGenericValue);
172+
} else {
173+
out.writeBoolean(false);
174+
}
175+
}
153176
}
154177

155178
@Override
@@ -195,6 +218,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
195218
builder.field(USER, user);
196219
}
197220
builder.field(IS_ASYNC_TASK_FIELD, async);
221+
if (remoteJob != null) {
222+
builder.field(REMOTE_JOB_FIELD, remoteJob);
223+
}
198224
return builder.endObject();
199225
}
200226

@@ -218,6 +244,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
218244
String error = null;
219245
User user = null;
220246
boolean async = false;
247+
Map<String, Object> remoteJob = null;
221248

222249
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
223250
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -275,6 +302,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
275302
case IS_ASYNC_TASK_FIELD:
276303
async = parser.booleanValue();
277304
break;
305+
case REMOTE_JOB_FIELD:
306+
remoteJob = parser.map();
307+
break;
278308
default:
279309
parser.skipChildren();
280310
break;
@@ -296,6 +326,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
296326
.error(error)
297327
.user(user)
298328
.async(async)
329+
.remoteJob(remoteJob)
299330
.build();
300331
}
301332
}

common/src/main/java/org/opensearch/ml/common/MLTaskType.java

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
public enum MLTaskType {
99
TRAINING,
1010
PREDICTION,
11+
BATCH_PREDICTION,
1112
TRAINING_AND_PREDICTION,
1213
EXECUTION,
1314
@Deprecated

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
188188
public enum ActionType {
189189
PREDICT,
190190
EXECUTE,
191-
BATCH_PREDICT;
191+
BATCH_PREDICT,
192+
CANCEL_BATCH_PREDICT,
193+
BATCH_PREDICT_STATUS;
192194

193195
public static ActionType from(String value) {
194196
try {

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

+5
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ public ModelTensors(List<ModelTensor> mlModelTensors) {
3636
this.mlModelTensors = mlModelTensors;
3737
}
3838

39+
@Builder
40+
public ModelTensors(Integer statusCode) {
41+
this.statusCode = statusCode;
42+
}
43+
3944
@Override
4045
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
4146
builder.startObject();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.task;
7+
8+
import org.opensearch.action.ActionType;
9+
10+
public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
11+
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
12+
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";
13+
14+
private MLCancelBatchJobAction() {
15+
super(NAME, MLCancelBatchJobResponse::new);
16+
}
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.task;
7+
8+
import static org.opensearch.action.ValidateActions.addValidationError;
9+
10+
import java.io.ByteArrayInputStream;
11+
import java.io.ByteArrayOutputStream;
12+
import java.io.IOException;
13+
import java.io.UncheckedIOException;
14+
15+
import org.opensearch.action.ActionRequest;
16+
import org.opensearch.action.ActionRequestValidationException;
17+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
18+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
19+
import org.opensearch.core.common.io.stream.StreamInput;
20+
import org.opensearch.core.common.io.stream.StreamOutput;
21+
22+
import lombok.Builder;
23+
import lombok.Getter;
24+
25+
public class MLCancelBatchJobRequest extends ActionRequest {
26+
@Getter
27+
String taskId;
28+
29+
@Builder
30+
public MLCancelBatchJobRequest(String taskId) {
31+
this.taskId = taskId;
32+
}
33+
34+
public MLCancelBatchJobRequest(StreamInput in) throws IOException {
35+
super(in);
36+
this.taskId = in.readString();
37+
}
38+
39+
@Override
40+
public void writeTo(StreamOutput out) throws IOException {
41+
super.writeTo(out);
42+
out.writeString(this.taskId);
43+
}
44+
45+
@Override
46+
public ActionRequestValidationException validate() {
47+
ActionRequestValidationException exception = null;
48+
49+
if (this.taskId == null) {
50+
exception = addValidationError("ML task id can't be null", exception);
51+
}
52+
53+
return exception;
54+
}
55+
56+
public static MLCancelBatchJobRequest fromActionRequest(ActionRequest actionRequest) {
57+
if (actionRequest instanceof MLCancelBatchJobRequest) {
58+
return (MLCancelBatchJobRequest) actionRequest;
59+
}
60+
61+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
62+
actionRequest.writeTo(osso);
63+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
64+
return new MLCancelBatchJobRequest(input);
65+
}
66+
} catch (IOException e) {
67+
throw new UncheckedIOException("failed to parse ActionRequest into MLCancelBatchJobRequest", e);
68+
}
69+
}
70+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.task;
7+
8+
import java.io.ByteArrayInputStream;
9+
import java.io.ByteArrayOutputStream;
10+
import java.io.IOException;
11+
import java.io.UncheckedIOException;
12+
13+
import org.opensearch.core.action.ActionResponse;
14+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
15+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
16+
import org.opensearch.core.common.io.stream.StreamInput;
17+
import org.opensearch.core.common.io.stream.StreamOutput;
18+
import org.opensearch.core.rest.RestStatus;
19+
import org.opensearch.core.xcontent.ToXContentObject;
20+
import org.opensearch.core.xcontent.XContentBuilder;
21+
22+
import lombok.Builder;
23+
import lombok.Getter;
24+
25+
@Getter
26+
public class MLCancelBatchJobResponse extends ActionResponse implements ToXContentObject {
27+
28+
RestStatus status;
29+
30+
@Builder
31+
public MLCancelBatchJobResponse(RestStatus status) {
32+
this.status = status;
33+
}
34+
35+
public MLCancelBatchJobResponse(StreamInput in) throws IOException {
36+
super(in);
37+
status = in.readEnum(RestStatus.class);
38+
}
39+
40+
@Override
41+
public void writeTo(StreamOutput out) throws IOException {
42+
out.writeEnum(status);
43+
}
44+
45+
public static MLCancelBatchJobResponse fromActionResponse(ActionResponse actionResponse) {
46+
if (actionResponse instanceof MLCancelBatchJobResponse) {
47+
return (MLCancelBatchJobResponse) actionResponse;
48+
}
49+
50+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
51+
actionResponse.writeTo(osso);
52+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
53+
return new MLCancelBatchJobResponse(input);
54+
}
55+
} catch (IOException e) {
56+
throw new UncheckedIOException("failed to parse ActionResponse into MLTaskGetResponse", e);
57+
}
58+
}
59+
60+
@Override
61+
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
62+
return xContentBuilder.startObject().field("status", status).endObject();
63+
}
64+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package org.opensearch.ml.common.transport.task;
2+
3+
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertNotSame;
5+
6+
import java.io.IOException;
7+
import java.io.UncheckedIOException;
8+
9+
import org.junit.Before;
10+
import org.junit.Test;
11+
import org.opensearch.action.ActionRequest;
12+
import org.opensearch.action.ActionRequestValidationException;
13+
import org.opensearch.common.io.stream.BytesStreamOutput;
14+
import org.opensearch.core.common.io.stream.StreamOutput;
15+
16+
public class MLCancelBatchJobRequestTest {
17+
private String taskId;
18+
19+
@Before
20+
public void setUp() {
21+
taskId = "test_id";
22+
}
23+
24+
@Test
25+
public void writeTo_Success() throws IOException {
26+
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build();
27+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
28+
mlCancelBatchJobRequest.writeTo(bytesStreamOutput);
29+
MLCancelBatchJobRequest parsedTask = new MLCancelBatchJobRequest(bytesStreamOutput.bytes().streamInput());
30+
assertEquals(parsedTask.getTaskId(), taskId);
31+
}
32+
33+
@Test
34+
public void validate_Exception_NullTaskId() {
35+
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().build();
36+
37+
ActionRequestValidationException exception = mlCancelBatchJobRequest.validate();
38+
assertEquals("Validation Failed: 1: ML task id can't be null;", exception.getMessage());
39+
}
40+
41+
@Test
42+
public void fromActionRequest_Success() {
43+
MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.builder().taskId(taskId).build();
44+
ActionRequest actionRequest = new ActionRequest() {
45+
@Override
46+
public ActionRequestValidationException validate() {
47+
return null;
48+
}
49+
50+
@Override
51+
public void writeTo(StreamOutput out) throws IOException {
52+
mlCancelBatchJobRequest.writeTo(out);
53+
}
54+
};
55+
MLCancelBatchJobRequest result = MLCancelBatchJobRequest.fromActionRequest(actionRequest);
56+
assertNotSame(result, mlCancelBatchJobRequest);
57+
assertEquals(result.getTaskId(), mlCancelBatchJobRequest.getTaskId());
58+
}
59+
60+
@Test(expected = UncheckedIOException.class)
61+
public void fromActionRequest_IOException() {
62+
ActionRequest actionRequest = new ActionRequest() {
63+
@Override
64+
public ActionRequestValidationException validate() {
65+
return null;
66+
}
67+
68+
@Override
69+
public void writeTo(StreamOutput out) throws IOException {
70+
throw new IOException("test");
71+
}
72+
};
73+
MLCancelBatchJobRequest.fromActionRequest(actionRequest);
74+
}
75+
}

0 commit comments

Comments
 (0)