Skip to content

Commit b98fe08

Browse files
committed
add stats for actionType and more UTs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent c859832 commit b98fe08

File tree

8 files changed

+193
-19
lines changed

8 files changed

+193
-19
lines changed

common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java

+15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.HashMap;
2525
import java.util.Map;
2626

27+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;
28+
2729
public class ConnectorActionTest {
2830
@Rule
2931
public ExpectedException exceptionRule = ExpectedException.none();
@@ -140,4 +142,17 @@ public void parse() throws IOException {
140142
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
141143
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
142144
}
145+
146+
@Test
147+
public void test_wrongActionType() {
148+
exceptionRule.expect(IllegalArgumentException.class);
149+
exceptionRule.expectMessage("Wrong Action Type");
150+
ConnectorAction.ActionType.from("badAction");
151+
}
152+
153+
@Test
154+
public void test_invalidActionInModelPrediction() {
155+
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
156+
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
157+
}
143158
}

common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java

+21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.junit.Test;
55
import org.opensearch.common.io.stream.BytesStreamOutput;
66
import org.opensearch.core.common.io.stream.StreamInput;
7+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
78
import org.opensearch.ml.common.dataset.MLInputDataset;
89

910
import java.io.IOException;
@@ -45,4 +46,24 @@ public void writeTo() throws IOException {
4546
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
4647
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
4748
}
49+
50+
@Test
51+
public void writeTo_withActionType() throws IOException {
52+
Map<String, String> parameters = new HashMap<>();
53+
parameters.put("key1", "test value1");
54+
parameters.put("key2", "test value2");
55+
ActionType actionType = ActionType.from("predict");
56+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build();
57+
58+
BytesStreamOutput output = new BytesStreamOutput();
59+
inputDataSet.writeTo(output);
60+
StreamInput streamInput = output.bytes().streamInput();
61+
62+
RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
63+
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
64+
Assert.assertEquals(2, inputDataSet2.getParameters().size());
65+
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
66+
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
67+
Assert.assertEquals("PREDICT", inputDataSet2.getActionType().toString());
68+
}
4869
}

common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

+56
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
import org.opensearch.core.xcontent.XContentParser;
2222
import org.opensearch.index.query.MatchAllQueryBuilder;
2323
import org.opensearch.ml.common.FunctionName;
24+
import org.opensearch.ml.common.connector.ConnectorAction;
2425
import org.opensearch.ml.common.dataframe.*;
2526
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
2627
import org.opensearch.ml.common.dataset.MLInputDataset;
2728
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
2829
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
2930
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
31+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
3032
import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput;
3133
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
3234
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -37,7 +39,9 @@
3739
import java.util.ArrayList;
3840
import java.util.Arrays;
3941
import java.util.Collections;
42+
import java.util.HashMap;
4043
import java.util.List;
44+
import java.util.Map;
4145
import java.util.function.Consumer;
4246
import java.util.function.Function;
4347

@@ -160,6 +164,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
160164
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
161165
}
162166

167+
@Test
168+
public void parse_Remote_Model() throws IOException {
169+
Map<String, String> parameters = Map.of("TransformJobName", "new name");
170+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
171+
.parameters(parameters)
172+
.actionType(ConnectorAction.ActionType.PREDICT)
173+
.build();
174+
175+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}";
176+
177+
testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> {
178+
assertNotNull(parsedInput.getInputDataset());
179+
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
180+
assertEquals(ConnectorAction.ActionType.PREDICT, parsedInputDataSet.getActionType());
181+
});
182+
}
183+
184+
@Test
185+
public void parse_Remote_Model_With_ActionType() throws IOException {
186+
Map<String, String> parameters = Map.of("TransformJobName", "new name");
187+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
188+
.parameters(parameters)
189+
.actionType(ConnectorAction.ActionType.BATCH)
190+
.build();
191+
192+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH\"}";
193+
194+
testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH, expectedInputStr, parsedInput -> {
195+
assertNotNull(parsedInput.getInputDataset());
196+
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
197+
assertEquals(ConnectorAction.ActionType.BATCH, parsedInputDataSet.getActionType());
198+
});
199+
}
200+
163201
private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
164202
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
165203
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
@@ -178,6 +216,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
178216
verify.accept(parsedInput);
179217
}
180218

219+
private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
220+
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
221+
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
222+
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
223+
assertNotNull(builder);
224+
String jsonStr = builder.toString();
225+
assertEquals(expectedInputStr, jsonStr);
226+
227+
XContentParser parser = XContentType.JSON.xContent()
228+
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
229+
Collections.emptyList()).getNamedXContents()), null, jsonStr);
230+
parser.nextToken();
231+
MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType);
232+
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
233+
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
234+
verify.accept(parsedInput);
235+
}
236+
181237
@Test
182238
public void readInputStream_Success() throws IOException {
183239
readInputStream(input, parsedInput -> {

common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ public void constructor_stream() throws IOException {
3939
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
4040
Assert.assertEquals(1, inputDataSet.getParameters().size());
4141
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
42+
Assert.assertEquals("BATCH", inputDataSet.getActionType().toString());
4243
}
4344

4445
private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
45-
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
46+
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"BATCH\" }";
4647
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
4748
Collections.emptyList()).getNamedXContents()), null, jsonStr);
4849
parser.nextToken();

plugin/src/main/java/org/opensearch/ml/stats/ActionName.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ public enum ActionName {
1212
EXECUTE,
1313
REGISTER,
1414
DEPLOY,
15-
UNDEPLOY;
15+
UNDEPLOY,
16+
BATCH;
1617

1718
public static ActionName from(String value) {
1819
try {

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+32-17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import java.util.Arrays;
2020
import java.util.UUID;
2121

22+
import javax.swing.*;
23+
2224
import org.opensearch.OpenSearchException;
2325
import org.opensearch.OpenSearchStatusException;
2426
import org.opensearch.ResourceNotFoundException;
@@ -47,8 +49,10 @@
4749
import org.opensearch.ml.common.MLTask;
4850
import org.opensearch.ml.common.MLTaskState;
4951
import org.opensearch.ml.common.MLTaskType;
52+
import org.opensearch.ml.common.connector.ConnectorAction;
5053
import org.opensearch.ml.common.dataset.MLInputDataType;
5154
import org.opensearch.ml.common.dataset.MLInputDataset;
55+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5256
import org.opensearch.ml.common.input.MLInput;
5357
import org.opensearch.ml.common.output.MLOutput;
5458
import org.opensearch.ml.common.output.MLPredictionOutput;
@@ -276,13 +280,12 @@ private String getPredictThreadPool(FunctionName functionName) {
276280
private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
277281
ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
278282
// track ML task count and add ML task into cache
283+
ActionName actionName = getActionNameFromInput(mlInput);
279284
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
280285
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
281-
mlStats
282-
.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT)
283-
.increment();
286+
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), actionName, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
284287
if (modelId != null) {
285-
mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
288+
mlStats.createModelCounterStatIfAbsent(modelId, actionName, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
286289
}
287290
mlTask.setState(MLTaskState.RUNNING);
288291
mlTaskManager.add(mlTask);
@@ -305,22 +308,23 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
305308
.workerNodes(Arrays.asList(clusterService.localNode().getId()))
306309
.build();
307310
mlModelManager.deployModel(modelId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> {
308-
runPredict(modelId, mlTask, mlInput, functionName, internalListener);
311+
runPredict(modelId, mlTask, mlInput, functionName, actionName, internalListener);
309312
}, e -> {
310313
log.error("Failed to auto deploy model " + modelId, e);
311314
internalListener.onFailure(e);
312315
}));
313316
return;
314317
}
315318

316-
runPredict(modelId, mlTask, mlInput, functionName, internalListener);
319+
runPredict(modelId, mlTask, mlInput, functionName, actionName, internalListener);
317320
}
318321

319322
private void runPredict(
320323
String modelId,
321324
MLTask mlTask,
322325
MLInput mlInput,
323326
FunctionName algorithm,
327+
ActionName actionName,
324328
ActionListener<MLTaskResponse> internalListener
325329
) {
326330
// run predict
@@ -340,7 +344,7 @@ private void runPredict(
340344
handleAsyncMLTaskComplete(mlTask);
341345
mlModelManager.trackPredictDuration(modelId, startTime);
342346
internalListener.onResponse(output);
343-
}, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId));
347+
}, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName));
344348
predictor.asyncPredict(mlInput, trackPredictDurationListener);
345349
} else {
346350
MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
@@ -357,7 +361,7 @@ private void runPredict(
357361
return;
358362
} catch (Exception e) {
359363
log.error("Failed to predict model " + modelId, e);
360-
handlePredictFailure(mlTask, internalListener, e, false, modelId);
364+
handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
361365
return;
362366
}
363367
} else if (FunctionName.needDeployFirst(algorithm)) {
@@ -388,7 +392,7 @@ private void runPredict(
388392
OpenSearchException e = new OpenSearchException(
389393
"User: " + requestUser.getName() + " does not have permissions to run predict by model: " + modelId
390394
);
391-
handlePredictFailure(mlTask, internalListener, e, false, modelId);
395+
handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
392396
return;
393397
}
394398
// run predict
@@ -413,7 +417,7 @@ private void runPredict(
413417

414418
}, e -> {
415419
log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
416-
handlePredictFailure(mlTask, internalListener, e, true, modelId);
420+
handlePredictFailure(mlTask, internalListener, e, true, modelId, actionName);
417421
});
418422
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
419423
client
@@ -426,12 +430,12 @@ private void runPredict(
426430
);
427431
} catch (Exception e) {
428432
log.error("Failed to get model " + mlTask.getModelId(), e);
429-
handlePredictFailure(mlTask, internalListener, e, true, modelId);
433+
handlePredictFailure(mlTask, internalListener, e, true, modelId, actionName);
430434
}
431435
} else {
432436
IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
433437
log.error("ModelId is invalid", e);
434-
handlePredictFailure(mlTask, internalListener, e, false, modelId);
438+
handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
435439
}
436440
}
437441

@@ -445,19 +449,30 @@ private void handlePredictFailure(
445449
ActionListener<MLTaskResponse> listener,
446450
Exception e,
447451
boolean trackFailure,
448-
String modelId
452+
String modelId,
453+
ActionName actionName
449454
) {
450455
if (trackFailure) {
451-
mlStats
452-
.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT)
453-
.increment();
454-
mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT);
456+
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), actionName, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
457+
mlStats.createModelCounterStatIfAbsent(modelId, actionName, MLActionLevelStat.ML_ACTION_FAILURE_COUNT);
455458
mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
456459
}
457460
handleAsyncMLTaskFailure(mlTask, e);
458461
listener.onFailure(e);
459462
}
460463

464+
private ActionName getActionNameFromInput(MLInput mlInput) {
465+
ConnectorAction.ActionType actionType = null;
466+
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
467+
actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
468+
}
469+
if (actionType == null) {
470+
return ActionName.PREDICT;
471+
} else {
472+
return ActionName.from(actionType.toString());
473+
}
474+
}
475+
461476
public void validateOutputSchema(String modelId, ModelTensorOutput output) {
462477
if (mlModelManager.getModelInterface(modelId) != null && mlModelManager.getModelInterface(modelId).get("output") != null) {
463478
String outputSchemaString = mlModelManager.getModelInterface(modelId).get("output");

plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java

+29
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
1212
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1313
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
14+
import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest;
15+
import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest_WrongActionType;
1416
import static org.opensearch.ml.utils.TestHelper.getKMeansRestRequest;
17+
import static org.opensearch.ml.utils.TestHelper.verifyParsedBatchMLInput;
1518
import static org.opensearch.ml.utils.TestHelper.verifyParsedKMeansMLInput;
1619

1720
import java.io.IOException;
@@ -107,6 +110,15 @@ public void testRoutes() {
107110
assertEquals("/_plugins/_ml/_predict/{algorithm}/{model_id}", route.getPath());
108111
}
109112

113+
public void testRoutes_Batch() {
114+
List<RestHandler.Route> routes = restMLPredictionAction.routes();
115+
assertNotNull(routes);
116+
assertFalse(routes.isEmpty());
117+
RestHandler.Route route = routes.get(2);
118+
assertEquals(RestRequest.Method.POST, route.getMethod());
119+
assertEquals("/_plugins/_ml/models/{model_id}/_batch", route.getPath());
120+
}
121+
110122
public void testGetRequest() throws IOException {
111123
RestRequest request = getRestRequest_PredictModel();
112124
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
@@ -143,6 +155,23 @@ public void testPrepareRequest() throws Exception {
143155
verifyParsedKMeansMLInput(mlInput);
144156
}
145157

158+
public void testPrepareBatchRequest() throws Exception {
159+
RestRequest request = getBatchRestRequest();
160+
restMLPredictionAction.handleRequest(request, channel, client);
161+
ArgumentCaptor<MLPredictionTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
162+
verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argumentCaptor.capture(), any());
163+
MLInput mlInput = argumentCaptor.getValue().getMlInput();
164+
verifyParsedBatchMLInput(mlInput);
165+
}
166+
167+
public void testPrepareBatchRequest_WrongActionType() throws Exception {
168+
thrown.expect(IllegalArgumentException.class);
169+
thrown.expectMessage("Wrong Action Type");
170+
171+
RestRequest request = getBatchRestRequest_WrongActionType();
172+
restMLPredictionAction.getRequest("model id", "remote", request);
173+
}
174+
146175
@Ignore
147176
public void testPrepareRequest_EmptyAlgorithm() throws Exception {
148177
MLModel model = MLModel.builder().algorithm(FunctionName.BATCH_RCF).build();

0 commit comments

Comments
 (0)