Skip to content

Commit 2e616f0

Browse files
committed
add cancel batch prediction job API for offline inference
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 115a6f7 commit 2e616f0

File tree

13 files changed

+531
-49
lines changed

13 files changed

+531
-49
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ public class CommonValue {
394394
+ USER_FIELD_MAPPING
395395
+ " }\n"
396396
+ "}"
397-
+ MLTask.TRANSFORM_JOB_FIELD
397+
+ MLTask.REMOTE_JOB_FIELD
398398
+ "\" : {\"type\": \"flat_object\"}\n"
399399
+ " }\n"
400400
+ "}";

common/src/main/java/org/opensearch/ml/common/MLTask.java

+13-13
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class MLTask implements ToXContentObject, Writeable {
5151
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
5252
public static final String ERROR_FIELD = "error";
5353
public static final String IS_ASYNC_TASK_FIELD = "is_async";
54-
public static final String TRANSFORM_JOB_FIELD = "transform_job";
54+
public static final String REMOTE_JOB_FIELD = "remote_job";
5555
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0;
5656

5757
@Setter
@@ -75,7 +75,7 @@ public class MLTask implements ToXContentObject, Writeable {
7575
private User user; // TODO: support document level access control later
7676
private boolean async;
7777
@Setter
78-
private Map<String, Object> transformJob;
78+
private Map<String, Object> remoteJob;
7979

8080
@Builder(toBuilder = true)
8181
public MLTask(
@@ -93,7 +93,7 @@ public MLTask(
9393
String error,
9494
User user,
9595
boolean async,
96-
Map<String, Object> transformJob
96+
Map<String, Object> remoteJob
9797
) {
9898
this.taskId = taskId;
9999
this.modelId = modelId;
@@ -109,7 +109,7 @@ public MLTask(
109109
this.error = error;
110110
this.user = user;
111111
this.async = async;
112-
this.transformJob = transformJob;
112+
this.remoteJob = remoteJob;
113113
}
114114

115115
public MLTask(StreamInput input) throws IOException {
@@ -139,7 +139,7 @@ public MLTask(StreamInput input) throws IOException {
139139
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
140140
if (input.readBoolean()) {
141141
String mapStr = input.readString();
142-
this.transformJob = gson.fromJson(mapStr, Map.class);
142+
this.remoteJob = gson.fromJson(mapStr, Map.class);
143143
}
144144
}
145145
}
@@ -171,11 +171,11 @@ public void writeTo(StreamOutput out) throws IOException {
171171
}
172172
out.writeBoolean(async);
173173
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
174-
if (transformJob != null) {
174+
if (remoteJob != null) {
175175
out.writeBoolean(true);
176176
try {
177177
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
178-
out.writeString(gson.toJson(transformJob));
178+
out.writeString(gson.toJson(remoteJob));
179179
return null;
180180
});
181181
} catch (PrivilegedActionException e) {
@@ -230,8 +230,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
230230
builder.field(USER, user);
231231
}
232232
builder.field(IS_ASYNC_TASK_FIELD, async);
233-
if (transformJob != null) {
234-
builder.field(TRANSFORM_JOB_FIELD, transformJob);
233+
if (remoteJob != null) {
234+
builder.field(REMOTE_JOB_FIELD, remoteJob);
235235
}
236236
return builder.endObject();
237237
}
@@ -256,7 +256,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
256256
String error = null;
257257
User user = null;
258258
boolean async = false;
259-
Map<String, Object> transformJob = null;
259+
Map<String, Object> remoteJob = null;
260260

261261
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
262262
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -314,8 +314,8 @@ public static MLTask parse(XContentParser parser) throws IOException {
314314
case IS_ASYNC_TASK_FIELD:
315315
async = parser.booleanValue();
316316
break;
317-
case TRANSFORM_JOB_FIELD:
318-
transformJob = parser.map();
317+
case REMOTE_JOB_FIELD:
318+
remoteJob = parser.map();
319319
break;
320320
default:
321321
parser.skipChildren();
@@ -338,7 +338,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
338338
.error(error)
339339
.user(user)
340340
.async(async)
341-
.transformJob(transformJob)
341+
.remoteJob(remoteJob)
342342
.build();
343343
}
344344
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ public enum ActionType {
189189
PREDICT,
190190
EXECUTE,
191191
BATCH_PREDICT,
192-
CANCEL_BATCH,
193-
BATCH_STATUS;
192+
CANCEL_BATCH_PREDICT,
193+
BATCH_PREDICT_STATUS;
194194

195195
public static ActionType from(String value) {
196196
try {

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

+5
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ public ModelTensors(List<ModelTensor> mlModelTensors) {
3636
this.mlModelTensors = mlModelTensors;
3737
}
3838

39+
@Builder
40+
public ModelTensors(Integer statusCode) {
41+
this.statusCode = statusCode;
42+
}
43+
3944
@Override
4045
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
4146
builder.startObject();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.task;
7+
8+
import org.opensearch.action.ActionType;
9+
10+
public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
11+
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
12+
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";
13+
14+
private MLCancelBatchJobAction() {
15+
super(NAME, MLCancelBatchJobResponse::new);
16+
}
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.task;
7+
8+
import static org.opensearch.action.ValidateActions.addValidationError;
9+
10+
import java.io.ByteArrayInputStream;
11+
import java.io.ByteArrayOutputStream;
12+
import java.io.IOException;
13+
import java.io.UncheckedIOException;
14+
15+
import org.opensearch.action.ActionRequest;
16+
import org.opensearch.action.ActionRequestValidationException;
17+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
18+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
19+
import org.opensearch.core.common.io.stream.StreamInput;
20+
import org.opensearch.core.common.io.stream.StreamOutput;
21+
22+
import lombok.Builder;
23+
import lombok.Getter;
24+
25+
public class MLCancelBatchJobRequest extends ActionRequest {
26+
@Getter
27+
String taskId;
28+
29+
@Builder
30+
public MLCancelBatchJobRequest(String taskId) {
31+
this.taskId = taskId;
32+
}
33+
34+
public MLCancelBatchJobRequest(StreamInput in) throws IOException {
35+
super(in);
36+
this.taskId = in.readString();
37+
}
38+
39+
@Override
40+
public void writeTo(StreamOutput out) throws IOException {
41+
super.writeTo(out);
42+
out.writeString(this.taskId);
43+
}
44+
45+
@Override
46+
public ActionRequestValidationException validate() {
47+
ActionRequestValidationException exception = null;
48+
49+
if (this.taskId == null) {
50+
exception = addValidationError("ML task id can't be null", exception);
51+
}
52+
53+
return exception;
54+
}
55+
56+
public static MLCancelBatchJobRequest fromActionRequest(ActionRequest actionRequest) {
57+
if (actionRequest instanceof MLCancelBatchJobRequest) {
58+
return (MLCancelBatchJobRequest) actionRequest;
59+
}
60+
61+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
62+
actionRequest.writeTo(osso);
63+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
64+
return new MLCancelBatchJobRequest(input);
65+
}
66+
} catch (IOException e) {
67+
throw new UncheckedIOException("failed to parse ActionRequest into MLCancelBatchJobRequest", e);
68+
}
69+
}
70+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.task;
7+
8+
import java.io.ByteArrayInputStream;
9+
import java.io.ByteArrayOutputStream;
10+
import java.io.IOException;
11+
import java.io.UncheckedIOException;
12+
13+
import org.opensearch.core.action.ActionResponse;
14+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
15+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
16+
import org.opensearch.core.common.io.stream.StreamInput;
17+
import org.opensearch.core.common.io.stream.StreamOutput;
18+
import org.opensearch.core.rest.RestStatus;
19+
import org.opensearch.core.xcontent.ToXContentObject;
20+
import org.opensearch.core.xcontent.XContentBuilder;
21+
22+
import lombok.Builder;
23+
import lombok.Getter;
24+
25+
@Getter
26+
public class MLCancelBatchJobResponse extends ActionResponse implements ToXContentObject {
27+
28+
RestStatus status;
29+
30+
@Builder
31+
public MLCancelBatchJobResponse(RestStatus status) {
32+
this.status = status;
33+
}
34+
35+
public MLCancelBatchJobResponse(StreamInput in) throws IOException {
36+
super(in);
37+
status = in.readEnum(RestStatus.class);
38+
}
39+
40+
@Override
41+
public void writeTo(StreamOutput out) throws IOException {
42+
out.writeEnum(status);
43+
}
44+
45+
public static MLCancelBatchJobResponse fromActionResponse(ActionResponse actionResponse) {
46+
if (actionResponse instanceof MLCancelBatchJobResponse) {
47+
return (MLCancelBatchJobResponse) actionResponse;
48+
}
49+
50+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
51+
actionResponse.writeTo(osso);
52+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
53+
return new MLCancelBatchJobResponse(input);
54+
}
55+
} catch (IOException e) {
56+
throw new UncheckedIOException("failed to parse ActionResponse into MLTaskGetResponse", e);
57+
}
58+
}
59+
60+
@Override
61+
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
62+
return xContentBuilder.startObject().field("status", status).endObject();
63+
}
64+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.opensearch.ml.engine.algorithms.remote;
99

1010
import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
11+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
1112
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
1213

1314
import java.nio.ByteBuffer;
@@ -169,13 +170,14 @@ public void onComplete() {
169170
}
170171

171172
private void response() {
173+
String body = responseBody.toString();
174+
172175
if (exceptionHolder.get() != null) {
173176
actionListener.onFailure(exceptionHolder.get());
174177
return;
175178
}
176179

177-
String body = responseBody.toString();
178-
if (Strings.isBlank(body)) {
180+
if (Strings.isBlank(body) && !action.equals(CANCEL_BATCH_PREDICT.toString())) {
179181
log.error("Remote model response body is empty!");
180182
actionListener.onFailure(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
181183
return;
@@ -187,6 +189,13 @@ private void response() {
187189
return;
188190
}
189191

192+
if (action.equals(CANCEL_BATCH_PREDICT.toString())) {
193+
ModelTensors tensors = ModelTensors.builder().statusCode(statusCode).build();
194+
tensors.setStatusCode(statusCode);
195+
actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors));
196+
return;
197+
}
198+
190199
try {
191200
ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);
192201
tensors.setStatusCode(statusCode);

0 commit comments

Comments
 (0)