From c10ffb3bf652ad7b3eb7b258f7fee84fa09880e4 Mon Sep 17 00:00:00 2001
From: Sicheng Song
Date: Fri, 28 Feb 2025 23:18:38 +0000
Subject: [PATCH] Handle when model not in cache

Signed-off-by: Sicheng Song
---
 .../opensearch/ml/rest/RestMLPredictionAction.java | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
index f16df20215..e0e028d9f0 100644
--- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
@@ -88,12 +88,18 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
         String modelId = getParameterId(request, PARAMETER_MODEL_ID);
         Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
 
-        if (userAlgorithm != null && functionName.isPresent()) {
-            MLPredictionTaskRequest mlPredictionTaskRequest = getRequest(modelId, functionName.get().name(), userAlgorithm, request);
-            return channel -> client
-                .execute(MLPredictionTaskAction.INSTANCE, mlPredictionTaskRequest, new RestToXContentListener<>(channel));
+        // check if the model is in cache
+        if (functionName.isPresent()) {
+            MLPredictionTaskRequest predictionRequest = getRequest(
+                modelId,
+                functionName.get().name(),
+                Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
+                request
+            );
+            return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
         }
 
+        // If the model isn't in cache
         return channel -> {
             ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
                 String modelType = mlModel.getAlgorithm().name();