Skip to content

Commit 0920ba7

Browse files
authored
fine tune predict API: read model from index directly (opensearch-project#1557)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 8c3e453 commit 0920ba7

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java

+5-8
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,13 @@
1818
import java.util.Optional;
1919

2020
import org.opensearch.client.node.NodeClient;
21+
import org.opensearch.common.util.concurrent.ThreadContext;
2122
import org.opensearch.core.action.ActionListener;
2223
import org.opensearch.core.rest.RestStatus;
2324
import org.opensearch.core.xcontent.XContentParser;
2425
import org.opensearch.ml.common.FunctionName;
2526
import org.opensearch.ml.common.MLModel;
2627
import org.opensearch.ml.common.input.MLInput;
27-
import org.opensearch.ml.common.transport.model.MLModelGetAction;
28-
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
29-
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
3028
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3129
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
3230
import org.opensearch.ml.model.MLModelManager;
@@ -91,9 +89,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
9189
}
9290

9391
return channel -> {
94-
MLModelGetRequest getModelRequest = new MLModelGetRequest(modelId, false);
95-
ActionListener<MLModelGetResponse> listener = ActionListener.wrap(r -> {
96-
MLModel mlModel = r.getMlModel();
92+
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
9793
String algoName = mlModel.getAlgorithm().name();
9894
client
9995
.execute(
@@ -109,8 +105,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
109105
log.error("Failed to send error response", ex);
110106
}
111107
});
112-
client.execute(MLModelGetAction.INSTANCE, getModelRequest, listener);
113-
108+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
109+
modelManager.getModel(modelId, ActionListener.runBefore(listener, () -> context.restore()));
110+
}
114111
};
115112
}
116113

0 commit comments

Comments
 (0)