19
19
import java .util .Arrays ;
20
20
import java .util .UUID ;
21
21
22
+ import javax .swing .*;
23
+
22
24
import org .opensearch .OpenSearchException ;
23
25
import org .opensearch .OpenSearchStatusException ;
24
26
import org .opensearch .ResourceNotFoundException ;
47
49
import org .opensearch .ml .common .MLTask ;
48
50
import org .opensearch .ml .common .MLTaskState ;
49
51
import org .opensearch .ml .common .MLTaskType ;
52
+ import org .opensearch .ml .common .connector .ConnectorAction ;
50
53
import org .opensearch .ml .common .dataset .MLInputDataType ;
51
54
import org .opensearch .ml .common .dataset .MLInputDataset ;
55
+ import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
52
56
import org .opensearch .ml .common .input .MLInput ;
53
57
import org .opensearch .ml .common .output .MLOutput ;
54
58
import org .opensearch .ml .common .output .MLPredictionOutput ;
@@ -276,13 +280,12 @@ private String getPredictThreadPool(FunctionName functionName) {
276
280
private void predict (String modelId , MLTask mlTask , MLInput mlInput , ActionListener <MLTaskResponse > listener ) {
277
281
ActionListener <MLTaskResponse > internalListener = wrappedCleanupListener (listener , mlTask .getTaskId ());
278
282
// track ML task count and add ML task into cache
283
+ ActionName actionName = getActionNameFromInput (mlInput );
279
284
mlStats .getStat (MLNodeLevelStat .ML_EXECUTING_TASK_COUNT ).increment ();
280
285
mlStats .getStat (MLNodeLevelStat .ML_REQUEST_COUNT ).increment ();
281
- mlStats
282
- .createCounterStatIfAbsent (mlTask .getFunctionName (), ActionName .PREDICT , MLActionLevelStat .ML_ACTION_REQUEST_COUNT )
283
- .increment ();
286
+ mlStats .createCounterStatIfAbsent (mlTask .getFunctionName (), actionName , MLActionLevelStat .ML_ACTION_REQUEST_COUNT ).increment ();
284
287
if (modelId != null ) {
285
- mlStats .createModelCounterStatIfAbsent (modelId , ActionName . PREDICT , MLActionLevelStat .ML_ACTION_REQUEST_COUNT ).increment ();
288
+ mlStats .createModelCounterStatIfAbsent (modelId , actionName , MLActionLevelStat .ML_ACTION_REQUEST_COUNT ).increment ();
286
289
}
287
290
mlTask .setState (MLTaskState .RUNNING );
288
291
mlTaskManager .add (mlTask );
@@ -305,22 +308,23 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
305
308
.workerNodes (Arrays .asList (clusterService .localNode ().getId ()))
306
309
.build ();
307
310
mlModelManager .deployModel (modelId , null , functionName , false , true , mlDeployTask , ActionListener .wrap (s -> {
308
- runPredict (modelId , mlTask , mlInput , functionName , internalListener );
311
+ runPredict (modelId , mlTask , mlInput , functionName , actionName , internalListener );
309
312
}, e -> {
310
313
log .error ("Failed to auto deploy model " + modelId , e );
311
314
internalListener .onFailure (e );
312
315
}));
313
316
return ;
314
317
}
315
318
316
- runPredict (modelId , mlTask , mlInput , functionName , internalListener );
319
+ runPredict (modelId , mlTask , mlInput , functionName , actionName , internalListener );
317
320
}
318
321
319
322
private void runPredict (
320
323
String modelId ,
321
324
MLTask mlTask ,
322
325
MLInput mlInput ,
323
326
FunctionName algorithm ,
327
+ ActionName actionName ,
324
328
ActionListener <MLTaskResponse > internalListener
325
329
) {
326
330
// run predict
@@ -340,7 +344,7 @@ private void runPredict(
340
344
handleAsyncMLTaskComplete (mlTask );
341
345
mlModelManager .trackPredictDuration (modelId , startTime );
342
346
internalListener .onResponse (output );
343
- }, e -> handlePredictFailure (mlTask , internalListener , e , false , modelId ));
347
+ }, e -> handlePredictFailure (mlTask , internalListener , e , false , modelId , actionName ));
344
348
predictor .asyncPredict (mlInput , trackPredictDurationListener );
345
349
} else {
346
350
MLOutput output = mlModelManager .trackPredictDuration (modelId , () -> predictor .predict (mlInput ));
@@ -357,7 +361,7 @@ private void runPredict(
357
361
return ;
358
362
} catch (Exception e ) {
359
363
log .error ("Failed to predict model " + modelId , e );
360
- handlePredictFailure (mlTask , internalListener , e , false , modelId );
364
+ handlePredictFailure (mlTask , internalListener , e , false , modelId , actionName );
361
365
return ;
362
366
}
363
367
} else if (FunctionName .needDeployFirst (algorithm )) {
@@ -388,7 +392,7 @@ private void runPredict(
388
392
OpenSearchException e = new OpenSearchException (
389
393
"User: " + requestUser .getName () + " does not have permissions to run predict by model: " + modelId
390
394
);
391
- handlePredictFailure (mlTask , internalListener , e , false , modelId );
395
+ handlePredictFailure (mlTask , internalListener , e , false , modelId , actionName );
392
396
return ;
393
397
}
394
398
// run predict
@@ -413,7 +417,7 @@ private void runPredict(
413
417
414
418
}, e -> {
415
419
log .error ("Failed to predict " + mlInput .getAlgorithm () + ", modelId: " + mlTask .getModelId (), e );
416
- handlePredictFailure (mlTask , internalListener , e , true , modelId );
420
+ handlePredictFailure (mlTask , internalListener , e , true , modelId , actionName );
417
421
});
418
422
GetRequest getRequest = new GetRequest (ML_MODEL_INDEX , mlTask .getModelId ());
419
423
client
@@ -426,12 +430,12 @@ private void runPredict(
426
430
);
427
431
} catch (Exception e ) {
428
432
log .error ("Failed to get model " + mlTask .getModelId (), e );
429
- handlePredictFailure (mlTask , internalListener , e , true , modelId );
433
+ handlePredictFailure (mlTask , internalListener , e , true , modelId , actionName );
430
434
}
431
435
} else {
432
436
IllegalArgumentException e = new IllegalArgumentException ("ModelId is invalid" );
433
437
log .error ("ModelId is invalid" , e );
434
- handlePredictFailure (mlTask , internalListener , e , false , modelId );
438
+ handlePredictFailure (mlTask , internalListener , e , false , modelId , actionName );
435
439
}
436
440
}
437
441
@@ -445,19 +449,30 @@ private void handlePredictFailure(
445
449
ActionListener <MLTaskResponse > listener ,
446
450
Exception e ,
447
451
boolean trackFailure ,
448
- String modelId
452
+ String modelId ,
453
+ ActionName actionName
449
454
) {
450
455
if (trackFailure ) {
451
- mlStats
452
- .createCounterStatIfAbsent (mlTask .getFunctionName (), ActionName .PREDICT , MLActionLevelStat .ML_ACTION_FAILURE_COUNT )
453
- .increment ();
454
- mlStats .createModelCounterStatIfAbsent (modelId , ActionName .PREDICT , MLActionLevelStat .ML_ACTION_FAILURE_COUNT );
456
+ mlStats .createCounterStatIfAbsent (mlTask .getFunctionName (), actionName , MLActionLevelStat .ML_ACTION_FAILURE_COUNT ).increment ();
457
+ mlStats .createModelCounterStatIfAbsent (modelId , actionName , MLActionLevelStat .ML_ACTION_FAILURE_COUNT );
455
458
mlStats .getStat (MLNodeLevelStat .ML_FAILURE_COUNT ).increment ();
456
459
}
457
460
handleAsyncMLTaskFailure (mlTask , e );
458
461
listener .onFailure (e );
459
462
}
460
463
464
+ private ActionName getActionNameFromInput (MLInput mlInput ) {
465
+ ConnectorAction .ActionType actionType = null ;
466
+ if (mlInput .getInputDataset () instanceof RemoteInferenceInputDataSet ) {
467
+ actionType = ((RemoteInferenceInputDataSet ) mlInput .getInputDataset ()).getActionType ();
468
+ }
469
+ if (actionType == null ) {
470
+ return ActionName .PREDICT ;
471
+ } else {
472
+ return ActionName .from (actionType .toString ());
473
+ }
474
+ }
475
+
461
476
public void validateOutputSchema (String modelId , ModelTensorOutput output ) {
462
477
if (mlModelManager .getModelInterface (modelId ) != null && mlModelManager .getModelInterface (modelId ).get ("output" ) != null ) {
463
478
String outputSchemaString = mlModelManager .getModelInterface (modelId ).get ("output" );
0 commit comments