Skip to content

Commit 489ee29

Browse files
committed
address more comments
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent a88775f commit 489ee29

File tree

9 files changed

+28
-19
lines changed

9 files changed

+28
-19
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
187187
public enum ActionType {
188188
PREDICT,
189189
EXECUTE,
190-
BATCH;
190+
BATCH_PREDICT;
191191

192192
public static ActionType from(String value) {
193193
try {
@@ -199,7 +199,7 @@ public static ActionType from(String value) {
199199

200200
private static final HashSet<ActionType> MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of(
201201
PREDICT,
202-
BATCH
202+
BATCH_PREDICT
203203
));
204204

205205
public static boolean isValidActionInModelPrediction(ActionType actionType) {

common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,15 @@ public void parse_Remote_Model_With_ActionType() throws IOException {
186186
Map<String, String> parameters = Map.of("TransformJobName", "new name");
187187
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
188188
.parameters(parameters)
189-
.actionType(ConnectorAction.ActionType.BATCH)
189+
.actionType(ConnectorAction.ActionType.BATCH_PREDICT)
190190
.build();
191191

192-
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH\"}";
192+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}";
193193

194-
testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH, expectedInputStr, parsedInput -> {
194+
testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> {
195195
assertNotNull(parsedInput.getInputDataset());
196196
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
197-
assertEquals(ConnectorAction.ActionType.BATCH, parsedInputDataSet.getActionType());
197+
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType());
198198
});
199199
}
200200

common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ public void constructor_stream() throws IOException {
3939
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
4040
Assert.assertEquals(1, inputDataSet.getParameters().size());
4141
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
42-
Assert.assertEquals("BATCH", inputDataSet.getActionType().toString());
42+
Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString());
4343
}
4444

4545
private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
46-
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"BATCH\" }";
46+
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }";
4747
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
4848
Collections.emptyList()).getNamedXContents()), null, jsonStr);
4949
parser.nextToken();

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ public List<Route> routes() {
7272
String.format(Locale.ROOT, "%s/_predict/{%s}/{%s}", ML_BASE_URI, PARAMETER_ALGORITHM, PARAMETER_MODEL_ID)
7373
),
7474
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_predict", ML_BASE_URI, PARAMETER_MODEL_ID)),
75-
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/models/{%s}/_batch", ML_BASE_URI, PARAMETER_MODEL_ID))
75+
new Route(
76+
RestRequest.Method.POST,
77+
String.format(Locale.ROOT, "%s/models/{%s}/_batch_predict", ML_BASE_URI, PARAMETER_MODEL_ID)
78+
)
7679
);
7780
}
7881

@@ -124,11 +127,13 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
124127
@VisibleForTesting
125128
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
126129
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
130+
System.out.println("actionType is " + actionType);
127131
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
128132
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
129133
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
130134
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
131135
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
136+
System.out.println(actionType.toString());
132137
throw new IllegalArgumentException("Wrong action type in the rest request path!");
133138
}
134139

plugin/src/main/java/org/opensearch/ml/stats/ActionName.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public enum ActionName {
1313
REGISTER,
1414
DEPLOY,
1515
UNDEPLOY,
16-
BATCH;
16+
BATCH_PREDICT;
1717

1818
public static ActionName from(String value) {
1919
try {

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-5
Original file line numberDiff line numberDiff line change
@@ -464,11 +464,7 @@ private ActionName getActionNameFromInput(MLInput mlInput) {
464464
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
465465
actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
466466
}
467-
if (actionType == null) {
468-
return ActionName.PREDICT;
469-
} else {
470-
return ActionName.from(actionType.toString());
471-
}
467+
return (actionType == null) ? ActionName.PREDICT : ActionName.from(actionType.toString());
472468
}
473469

474470
public void validateOutputSchema(String modelId, ModelTensorOutput output) {

plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,19 @@ public static void wrapListenerToHandleSearchIndexNotFound(Exception e, ActionLi
310310
}
311311
}
312312

313+
/**
314+
* Determine the ActionType from the restful request by checking the url path and method name so there's no need
315+
* to specify the ActionType in the request body. For example, /_plugins/_ml/models/{model_id}/_predict will return
316+
* PREDICT as the ActionType, and /_plugins/_ml/models/{model_id}/_batch_predict will return BATCH_PREDICT.
317+
* @param request A Restful request that needs to determine the ActionType from the path.
318+
* @return parsed user object
319+
*/
313320
public static String getActionTypeFromRestRequest(RestRequest request) {
314321
String path = request.path();
322+
System.out.println("path is " + path);
315323
String[] segments = path.split("/");
316324
String methodName = segments[segments.length - 1];
317-
methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName;
325+
methodName = methodName.startsWith("_") ? methodName.substring(1) : methodName;
318326

319327
// find the action type for "/_plugins/_ml/_predict/<algorithm>/<model_id>"
320328
if (!ActionType.isValidAction(methodName) && segments.length > 3) {

plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public void testRoutes_Batch() {
116116
assertFalse(routes.isEmpty());
117117
RestHandler.Route route = routes.get(2);
118118
assertEquals(RestRequest.Method.POST, route.getMethod());
119-
assertEquals("/_plugins/_ml/models/{model_id}/_batch", route.getPath());
119+
assertEquals("/_plugins/_ml/models/{model_id}/_batch_predict", route.getPath());
120120
}
121121

122122
public void testGetRequest() throws IOException {

plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ public static RestRequest getBatchRestRequest() {
225225
params.put(PARAMETER_ALGORITHM, "remote");
226226
final String requestContent = "{\"parameters\":{\"TransformJobName\":\"SM-offline-batch-transform-07-17-14-30\"}}";
227227
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
228-
.withPath("/_plugins/_ml/models/{model_id}}/_batch")
228+
.withPath("/_plugins/_ml/models/{model_id}/_batch_predict")
229229
.withParams(params)
230230
.withContent(new BytesArray(requestContent), XContentType.JSON)
231231
.build();
@@ -388,7 +388,7 @@ public static void verifyParsedBatchMLInput(MLInput mlInput) {
388388
assertEquals(FunctionName.REMOTE, mlInput.getAlgorithm());
389389
assertEquals(MLInputDataType.REMOTE, mlInput.getInputDataset().getInputDataType());
390390
RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
391-
assertEquals(ConnectorAction.ActionType.BATCH, inputDataset.getActionType());
391+
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, inputDataset.getActionType());
392392
}
393393

394394
private static NamedXContentRegistry getXContentRegistry() {

0 commit comments

Comments
 (0)