Skip to content

Commit 9876730

Browse files
committed
debugging
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 8162359 commit 9876730

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,12 @@ private void executePredict(
236236
);
237237
}
238238

239-
private void validateInputSchema(String modelId, MLInput mlInput) throws IOException {
239+
private void validateInputSchema(String modelId, MLInput mlInput) {
240240
if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) {
241241
String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input");
242-
log.info(inputSchemaString);
243-
log.info(mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
244242
try {
243+
log.info(inputSchemaString);
244+
log.info(mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
245245
MLNodeUtils
246246
.validateSchema(
247247
inputSchemaString,

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL;
1515
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
1616

17+
import java.io.IOException;
1718
import java.time.Instant;
1819
import java.util.Arrays;
1920
import java.util.UUID;
@@ -330,7 +331,20 @@ private void runPredict(
330331
if (output instanceof MLPredictionOutput) {
331332
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
332333
}
333-
334+
if (output instanceof ModelTensorOutput) {
335+
log.info("ModelId " + modelId + " is completed as ModelTensorOutput.");
336+
// pick the first output tensor to validate the schema
337+
if (((ModelTensorOutput) output).getMlModelOutputs() != null
338+
&& !((ModelTensorOutput) output).getMlModelOutputs().isEmpty()) {
339+
ModelTensors modelTensors = ((ModelTensorOutput) output).getMlModelOutputs().get(0);
340+
log.info(modelTensors.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
341+
if (modelTensors.getMlModelTensors() != null && !modelTensors.getMlModelTensors().isEmpty()) {
342+
ModelTensor modelTensorOutput = modelTensors.getMlModelTensors().get(0);
343+
log.info(modelTensorOutput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
344+
validateOutputSchema(modelId, modelTensorOutput);
345+
}
346+
}
347+
}
334348
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
335349
handleAsyncMLTaskComplete(mlTask);
336350
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
@@ -377,7 +391,7 @@ private void runPredict(
377391
}
378392
MLOutput output = mlEngine.predict(mlInput, mlModel);
379393
if (output instanceof MLPredictionOutput) {
380-
log.info("ModelId " + modelId + " is completed as MLPredictionOutput.");
394+
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
381395
}
382396
if (output instanceof ModelTensorOutput) {
383397
log.info("ModelId " + modelId + " is completed as ModelTensorOutput.");
@@ -453,8 +467,9 @@ private void validateOutputSchema(String modelId, ModelTensor modelTensor) {
453467
log.info("Went to validate output schema");
454468
if (mlModelManager.getModelInterface(modelId) != null && mlModelManager.getModelInterface(modelId).get("output") != null) {
455469
String outputSchemaString = mlModelManager.getModelInterface(modelId).get("output");
456-
log.info(outputSchemaString);
457470
try {
471+
log.info(outputSchemaString);
472+
log.info(modelTensor.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
458473
MLNodeUtils
459474
.validateSchema(
460475
outputSchemaString,

0 commit comments

Comments
 (0)