Skip to content

Commit 8162359

Browse files
committed
debugging
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 4e9e129 commit 8162359

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+5-7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
import lombok.experimental.FieldDefaults;
4646
import lombok.extern.log4j.Log4j2;
4747

48+
import java.io.IOException;
49+
4850
@Log4j2
4951
@FieldDefaults(level = AccessLevel.PRIVATE)
5052
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
@@ -160,8 +162,6 @@ public void onResponse(MLModel mlModel) {
160162
executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
161163
}
162164
} else {
163-
log.info("ModelId " + modelId + " is enabled for prediction in remote layer.");
164-
log.info(mlPredictionTaskRequest.getMlInput().toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
165165
validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput());
166166
executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
167167
}
@@ -236,21 +236,19 @@ private void executePredict(
236236
);
237237
}
238238

239-
private void validateInputSchema(String modelId, MLInput mlInput) {
240-
log.info("Went to validate input schema");
241-
log.info(modelCacheHelper.getModelInterface(modelId).containsKey("input"));
242-
log.info(modelCacheHelper.getModelInterface(modelId).get("input"));
239+
private void validateInputSchema(String modelId, MLInput mlInput) throws IOException {
243240
if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) {
244241
String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input");
245242
log.info(inputSchemaString);
243+
log.info(mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
246244
try {
247245
MLNodeUtils
248246
.validateSchema(
249247
inputSchemaString,
250248
mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()
251249
);
252250
} catch (Exception e) {
253-
throw new IllegalArgumentException("Error validating input schema: " + e.getMessage());
251+
throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST);
254252
}
255253
}
256254
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+1
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,7 @@ public Map<String, TokenBucket> getUserRateLimiterMap(String modelId) {
14861486
* Set up model interface with model id.
14871487
*/
14881488
private void setupModelInterface(String modelId, Map<String, String> modelInterface) {
1489+
log.debug("Model interface for model: {} loaded into cache.", modelId);
14891490
if (modelInterface != null) {
14901491
modelCacheHelper.setModelInterface(modelId, modelInterface);
14911492
} else {

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -377,10 +377,10 @@ private void runPredict(
377377
}
378378
MLOutput output = mlEngine.predict(mlInput, mlModel);
379379
if (output instanceof MLPredictionOutput) {
380-
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
380+
log.info("ModelId " + modelId + " is completed as MLPredictionOutput.");
381381
}
382382
if (output instanceof ModelTensorOutput) {
383-
log.info("ModelId " + modelId + " is completed for prediction.");
383+
log.info("ModelId " + modelId + " is completed as ModelTensorOutput.");
384384
// pick the first output tensor to validate the schema
385385
if (((ModelTensorOutput) output).getMlModelOutputs() != null
386386
&& !((ModelTensorOutput) output).getMlModelOutputs().isEmpty()) {

0 commit comments

Comments
 (0)