Skip to content

Commit 0165b93

Browse files
committed
add batch prediction to task
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent dd3af2f commit 0165b93

File tree

8 files changed

+427
-13
lines changed

8 files changed

+427
-13
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public class CommonValue {
6565
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
6666
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
6767
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
68-
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
68+
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3;
6969
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
7070
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
7171
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3;
@@ -363,6 +363,10 @@ public class CommonValue {
363363
+ "\" : {\"type\" : \"boolean\"}, \n"
364364
+ USER_FIELD_MAPPING
365365
+ " }\n"
366+
+ "}"
367+
+ MLTask.TRANSFORM_JOB_FIELD
368+
+ "\" : {\"type\": \"flat_object\"}\n"
369+
+ " }\n"
366370
+ "}";
367371

368372
public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n"

common/src/main/java/org/opensearch/ml/common/MLTask.java

+46-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import lombok.EqualsAndHashCode;
1010
import lombok.Getter;
1111
import lombok.Setter;
12+
import org.opensearch.Version;
1213
import org.opensearch.core.common.io.stream.StreamInput;
1314
import org.opensearch.core.common.io.stream.StreamOutput;
1415
import org.opensearch.core.common.io.stream.Writeable;
@@ -17,15 +18,22 @@
1718
import org.opensearch.core.xcontent.XContentBuilder;
1819
import org.opensearch.core.xcontent.XContentParser;
1920
import org.opensearch.ml.common.dataset.MLInputDataType;
21+
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
2022

2123
import java.io.IOException;
24+
import java.security.AccessController;
25+
import java.security.PrivilegedActionException;
26+
import java.security.PrivilegedExceptionAction;
2227
import java.time.Instant;
2328
import java.util.ArrayList;
2429
import java.util.Arrays;
2530
import java.util.List;
31+
import java.util.Map;
2632

2733
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
2834
import static org.opensearch.ml.common.CommonValue.USER;
35+
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
36+
import static org.opensearch.ml.common.utils.StringUtils.gson;
2937

3038
@Getter
3139
@EqualsAndHashCode
@@ -44,6 +52,8 @@ public class MLTask implements ToXContentObject, Writeable {
4452
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
4553
public static final String ERROR_FIELD = "error";
4654
public static final String IS_ASYNC_TASK_FIELD = "is_async";
55+
public static final String TRANSFORM_JOB_FIELD = "transform_job";
56+
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0;
4757

4858
@Setter
4959
private String taskId;
@@ -65,6 +75,8 @@ public class MLTask implements ToXContentObject, Writeable {
6575
private String error;
6676
private User user; // TODO: support document level access control later
6777
private boolean async;
78+
@Setter
79+
private Map<String, Object> transformJob;
6880

6981
@Builder(toBuilder = true)
7082
public MLTask(
@@ -81,7 +93,8 @@ public MLTask(
8193
Instant lastUpdateTime,
8294
String error,
8395
User user,
84-
boolean async
96+
boolean async,
97+
Map<String, Object> transformJob
8598
) {
8699
this.taskId = taskId;
87100
this.modelId = modelId;
@@ -97,9 +110,11 @@ public MLTask(
97110
this.error = error;
98111
this.user = user;
99112
this.async = async;
113+
this.transformJob = transformJob;
100114
}
101115

102116
public MLTask(StreamInput input) throws IOException {
117+
Version streamInputVersion = input.getVersion();
103118
this.taskId = input.readOptionalString();
104119
this.modelId = input.readOptionalString();
105120
this.taskType = input.readEnum(MLTaskType.class);
@@ -122,10 +137,17 @@ public MLTask(StreamInput input) throws IOException {
122137
this.user = null;
123138
}
124139
this.async = input.readBoolean();
140+
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
141+
if (input.readBoolean()) {
142+
String mapStr = input.readString();
143+
this.transformJob = gson.fromJson(mapStr, Map.class);
144+
}
145+
}
125146
}
126147

127148
@Override
128149
public void writeTo(StreamOutput out) throws IOException {
150+
Version streamOutputVersion = out.getVersion();
129151
out.writeOptionalString(taskId);
130152
out.writeOptionalString(modelId);
131153
out.writeEnum(taskType);
@@ -149,6 +171,21 @@ public void writeTo(StreamOutput out) throws IOException {
149171
out.writeBoolean(false);
150172
}
151173
out.writeBoolean(async);
174+
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
175+
if (transformJob != null) {
176+
out.writeBoolean(true);
177+
try {
178+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
179+
out.writeString(gson.toJson(transformJob));
180+
return null;
181+
});
182+
} catch (PrivilegedActionException e) {
183+
throw new RuntimeException(e);
184+
}
185+
} else {
186+
out.writeBoolean(false);
187+
}
188+
}
152189
}
153190

154191
@Override
@@ -194,6 +231,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
194231
builder.field(USER, user);
195232
}
196233
builder.field(IS_ASYNC_TASK_FIELD, async);
234+
if (transformJob != null) {
235+
builder.field(TRANSFORM_JOB_FIELD, transformJob);
236+
}
197237
return builder.endObject();
198238
}
199239

@@ -217,6 +257,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
217257
String error = null;
218258
User user = null;
219259
boolean async = false;
260+
Map<String, Object> transformJob = null;
220261

221262
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
222263
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -274,6 +315,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
274315
case IS_ASYNC_TASK_FIELD:
275316
async = parser.booleanValue();
276317
break;
318+
case TRANSFORM_JOB_FIELD:
319+
transformJob = parser.map();
320+
break;
277321
default:
278322
parser.skipChildren();
279323
break;
@@ -294,6 +338,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
294338
.error(error)
295339
.user(user)
296340
.async(async)
341+
.transformJob(transformJob)
297342
.build();
298343
}
299344
}

common/src/main/java/org/opensearch/ml/common/MLTaskType.java

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
public enum MLTaskType {
99
TRAINING,
1010
PREDICTION,
11+
BATCH_PREDICTION,
1112
TRAINING_AND_PREDICTION,
1213
EXECUTION,
1314
@Deprecated

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
187187
public enum ActionType {
188188
PREDICT,
189189
EXECUTE,
190-
BATCH_PREDICT;
190+
BATCH_PREDICT,
191+
CANCEL_BATCH,
192+
BATCH_STATUS;
191193

192194
public static ActionType from(String value) {
193195
try {

0 commit comments

Comments
 (0)