|
19 | 19 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
20 | 20 | import org.opensearch.ml.common.FunctionName;
|
21 | 21 | import org.opensearch.ml.common.MLModel;
|
| 22 | +import org.opensearch.ml.common.MLTaskState; |
22 | 23 | import org.opensearch.ml.common.exception.MLValidationException;
|
23 | 24 | import org.opensearch.ml.common.transport.MLTaskResponse;
|
| 25 | +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; |
| 26 | +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; |
| 27 | +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; |
24 | 28 | import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
|
25 | 29 | import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
|
26 | 30 | import org.opensearch.ml.helper.ModelAccessControlHelper;
|
@@ -132,6 +136,8 @@ public void onResponse(MLModel mlModel) {
|
132 | 136 | } else {
|
133 | 137 | executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
|
134 | 138 | }
|
| 139 | + } else if (functionName == FunctionName.REMOTE) { |
| 140 | + deployRemoteModel(modelId, functionName, wrappedListener, mlPredictionTaskRequest); |
135 | 141 | } else {
|
136 | 142 | executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
|
137 | 143 | }
|
@@ -159,6 +165,44 @@ public void onFailure(Exception e) {
|
159 | 165 | }
|
160 | 166 | }
|
161 | 167 |
|
| 168 | + private void deployRemoteModel( |
| 169 | + String modelId, |
| 170 | + FunctionName functionName, |
| 171 | + ActionListener<MLTaskResponse> wrappedListener, |
| 172 | + MLPredictionTaskRequest mlPredictionTaskRequest |
| 173 | + ) { |
| 174 | + String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); |
| 175 | + if (workerNodes != null && workerNodes.length != 0) { |
| 176 | + return; |
| 177 | + } |
| 178 | + |
| 179 | + MLDeployModelRequest deployModelRequest = MLDeployModelRequest.builder().modelId(modelId).async(false).dispatchTask(true).build(); |
| 180 | + ActionListener<MLDeployModelResponse> deployModelActionListener = ActionListener.wrap(deployModelResponse -> { |
| 181 | + // Deployment failed, existing |
| 182 | + if (!deployModelResponse.getStatus().equals(MLTaskState.COMPLETED.name())) { |
| 183 | + wrappedListener |
| 184 | + .onFailure( |
| 185 | + new IllegalArgumentException( |
| 186 | + "Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + modelId + "/_deploy" |
| 187 | + ) |
| 188 | + ); |
| 189 | + return; |
| 190 | + } |
| 191 | + // The DeployModel is async, set this maximum wait time for deployment to finish |
| 192 | + long startTime = System.currentTimeMillis(); |
| 193 | + long maxDuration = 100; // 1 second in milliseconds |
| 194 | + while (workerNodes == null || workerNodes.length == 0) { |
| 195 | + long currentTime = System.currentTimeMillis(); |
| 196 | + if (currentTime - startTime >= maxDuration) { |
| 197 | + log.info("Wait Time limit reached. Exiting loop."); |
| 198 | + break; |
| 199 | + } |
| 200 | + } |
| 201 | + executePredict(mlPredictionTaskRequest, wrappedListener, modelId); |
| 202 | + }, wrappedListener::onFailure); |
| 203 | + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, deployModelActionListener); |
| 204 | + } |
| 205 | + |
162 | 206 | private void executePredict(
|
163 | 207 | MLPredictionTaskRequest mlPredictionTaskRequest,
|
164 | 208 | ActionListener<MLTaskResponse> wrappedListener,
|
|
0 commit comments