Commit 800dd13

deploy remote model if predicting an undeployed remote model

Signed-off-by: Xun Zhang <xunzh@amazon.com>

1 parent 0755e50 commit 800dd13
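
For context, "undeployed" here means the remote model has no worker nodes hosting it yet. A minimal sketch of that check, assuming it sits in the same class and reuses the mlModelManager dependency visible in the diff below (the helper name isRemoteModelDeployed is hypothetical, not part of this commit):

    // Hypothetical helper: a remote model counts as deployed once at least one worker node hosts it.
    private boolean isRemoteModelDeployed(String modelId) {
        String[] workerNodes = mlModelManager.getWorkerNodes(modelId, FunctionName.REMOTE, true);
        return workerNodes != null && workerNodes.length != 0;
    }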

File tree

1 file changed, +44 -0 lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java (+44)
@@ -19,8 +19,12 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLModel;
+import org.opensearch.ml.common.MLTaskState;
 import org.opensearch.ml.common.exception.MLValidationException;
 import org.opensearch.ml.common.transport.MLTaskResponse;
+import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
+import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
+import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.helper.ModelAccessControlHelper;
@@ -132,6 +136,8 @@ public void onResponse(MLModel mlModel) {
             } else {
                 executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
             }
+        } else if (functionName == FunctionName.REMOTE) {
+            deployRemoteModel(modelId, functionName, wrappedListener, mlPredictionTaskRequest);
         } else {
             executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
         }
@@ -159,6 +165,44 @@ public void onFailure(Exception e) {
         }
     }
 
+    private void deployRemoteModel(
+        String modelId,
+        FunctionName functionName,
+        ActionListener<MLTaskResponse> wrappedListener,
+        MLPredictionTaskRequest mlPredictionTaskRequest
+    ) {
+        String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
+        if (workerNodes != null && workerNodes.length != 0) {
+            return;
+        }
+
+        MLDeployModelRequest deployModelRequest = MLDeployModelRequest.builder().modelId(modelId).async(false).dispatchTask(true).build();
+        ActionListener<MLDeployModelResponse> deployModelActionListener = ActionListener.wrap(deployModelResponse -> {
+            // Deployment failed, exiting
+            if (!deployModelResponse.getStatus().equals(MLTaskState.COMPLETED.name())) {
+                wrappedListener
+                    .onFailure(
+                        new IllegalArgumentException(
+                            "Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + modelId + "/_deploy"
+                        )
+                    );
+                return;
+            }
+            // The deploy call is async; wait at most maxDuration for deployment to finish
+            long startTime = System.currentTimeMillis();
+            long maxDuration = 100; // maximum wait of 100 milliseconds
+            while (workerNodes == null || workerNodes.length == 0) {
+                long currentTime = System.currentTimeMillis();
+                if (currentTime - startTime >= maxDuration) {
+                    log.info("Wait time limit reached. Exiting loop.");
+                    break;
+                }
+            }
+            executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
+        }, wrappedListener::onFailure);
+        client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, deployModelActionListener);
+    }
+
     private void executePredict(
         MLPredictionTaskRequest mlPredictionTaskRequest,
         ActionListener<MLTaskResponse> wrappedListener,
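
The loop in deployRemoteModel above bounds the wait at maxDuration, but it re-tests the same workerNodes array captured before the deploy call and never sleeps, so it spins until the cap expires. For comparison, a minimal sketch of a bounded wait that re-queries worker nodes and backs off between checks, under the same assumptions as the diff (same mlModelManager and log fields; the method name waitForWorkerNodes is hypothetical, and this is not the commit's implementation):

    // Illustrative sketch only: poll for worker nodes with a time cap, sleeping between checks.
    private boolean waitForWorkerNodes(String modelId, FunctionName functionName, long maxDurationMillis) {
        long startTime = System.currentTimeMillis();
        while (System.currentTimeMillis() - startTime < maxDurationMillis) {
            String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
            if (workerNodes != null && workerNodes.length != 0) {
                return true; // the model is now hosted on at least one node
            }
            try {
                Thread.sleep(10); // back off briefly instead of spinning
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
        log.info("Wait time limit reached without worker nodes for model " + modelId);
        return false;
    }

Whether blocking with Thread.sleep is acceptable on the calling thread depends on the plugin's threading model, so treat this purely as a shape for the timeout logic.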
