|
14 | 14 | import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL;
|
15 | 15 | import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
|
16 | 16 |
|
| 17 | +import java.io.IOException; |
17 | 18 | import java.time.Instant;
|
18 | 19 | import java.util.Arrays;
|
19 | 20 | import java.util.UUID;
|
@@ -330,7 +331,20 @@ private void runPredict(
|
330 | 331 | if (output instanceof MLPredictionOutput) {
|
331 | 332 | ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
|
332 | 333 | }
|
333 |
| - |
| 334 | + if (output instanceof ModelTensorOutput) { |
| 335 | + log.info("ModelId " + modelId + " is completed as ModelTensorOutput."); |
| 336 | + // pick the first output tensor to validate the schema |
| 337 | + if (((ModelTensorOutput) output).getMlModelOutputs() != null |
| 338 | + && !((ModelTensorOutput) output).getMlModelOutputs().isEmpty()) { |
| 339 | + ModelTensors modelTensors = ((ModelTensorOutput) output).getMlModelOutputs().get(0); |
| 340 | + log.info(modelTensors.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); |
| 341 | + if (modelTensors.getMlModelTensors() != null && !modelTensors.getMlModelTensors().isEmpty()) { |
| 342 | + ModelTensor modelTensorOutput = modelTensors.getMlModelTensors().get(0); |
| 343 | + log.info(modelTensorOutput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); |
| 344 | + validateOutputSchema(modelId, modelTensorOutput); |
| 345 | + } |
| 346 | + } |
| 347 | + } |
334 | 348 | // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
|
335 | 349 | handleAsyncMLTaskComplete(mlTask);
|
336 | 350 | MLTaskResponse response = MLTaskResponse.builder().output(output).build();
|
@@ -377,7 +391,7 @@ private void runPredict(
|
377 | 391 | }
|
378 | 392 | MLOutput output = mlEngine.predict(mlInput, mlModel);
|
379 | 393 | if (output instanceof MLPredictionOutput) {
|
380 |
| - log.info("ModelId " + modelId + " is completed as MLPredictionOutput."); |
| 394 | + ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); |
381 | 395 | }
|
382 | 396 | if (output instanceof ModelTensorOutput) {
|
383 | 397 | log.info("ModelId " + modelId + " is completed as ModelTensorOutput.");
|
@@ -453,8 +467,9 @@ private void validateOutputSchema(String modelId, ModelTensor modelTensor) {
|
453 | 467 | log.info("Went to validate output schema");
|
454 | 468 | if (mlModelManager.getModelInterface(modelId) != null && mlModelManager.getModelInterface(modelId).get("output") != null) {
|
455 | 469 | String outputSchemaString = mlModelManager.getModelInterface(modelId).get("output");
|
456 |
| - log.info(outputSchemaString); |
457 | 470 | try {
|
| 471 | + log.info(outputSchemaString); |
| 472 | + log.info(modelTensor.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); |
458 | 473 | MLNodeUtils
|
459 | 474 | .validateSchema(
|
460 | 475 | outputSchemaString,
|
|
0 commit comments