Skip to content

Commit 4f87254

Browse files
committed
auto deployment for remote models (opensearch-project#2206)
* auto deployment for remote models Signed-off-by: Xun Zhang <xunzh@amazon.com> * add auto deploy feature flag Signed-off-by: Xun Zhang <xunzh@amazon.com> * add eligible node check and avoid over-deployment Signed-off-by: Xun Zhang <xunzh@amazon.com> * dispatch local deploy Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 045915c commit 4f87254

17 files changed

+276
-54
lines changed

common/src/main/java/org/opensearch/ml/common/FunctionName.java

+8
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,12 @@ public static FunctionName from(String value) {
5252
public static boolean isDLModel(FunctionName functionName) {
5353
return DL_MODELS.contains(functionName);
5454
}
55+
56+
public static boolean needDeployFirst(FunctionName functionName) {
57+
return DL_MODELS.contains(functionName) || functionName == REMOTE;
58+
}
59+
60+
public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) {
61+
return autoDeploymentEnabled && functionName == FunctionName.REMOTE;
62+
}
5563
}

common/src/main/java/org/opensearch/ml/common/MLModel.java

-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ public class MLModel implements ToXContentObject {
116116
private Integer totalChunks; // model chunk doc only
117117
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
118118
private Integer currentWorkerNodeCount; // model is deployed to how many nodes
119-
120119
private String[] planningWorkerNodes; // plan to deploy model to these nodes
121120
private boolean deployToAllNodes;
122121

common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,4 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest
116116

117117
}
118118

119-
}
119+
}

plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,12 @@ private void deployModel(
207207
Set<String> allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());
208208

209209
List<DiscoveryNode> eligibleNodes = new ArrayList<>();
210-
List<String> nodeIds = new ArrayList<>();
210+
List<String> eligibleNodeIds = new ArrayList<>();
211211
if (!deployToAllNodes) {
212212
for (String nodeId : targetNodeIds) {
213213
if (allEligibleNodeIds.contains(nodeId)) {
214214
eligibleNodes.add(nodeMapping.get(nodeId));
215-
nodeIds.add(nodeId);
215+
eligibleNodeIds.add(nodeId);
216216
}
217217
}
218218
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm());
@@ -234,15 +234,15 @@ private void deployModel(
234234
}
235235
}
236236
} else {
237-
nodeIds.addAll(allEligibleNodeIds);
237+
eligibleNodeIds.addAll(allEligibleNodeIds);
238238
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
239239
}
240-
if (nodeIds.size() == 0) {
240+
if (eligibleNodeIds.size() == 0) {
241241
wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
242242
return;
243243
}
244244

245-
log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds));
245+
log.info("Will deploy model on these nodes: {}", String.join(",", eligibleNodeIds));
246246
String localNodeId = clusterService.localNode().getId();
247247

248248
FunctionName algorithm = mlModel.getAlgorithm();
@@ -258,18 +258,18 @@ private void deployModel(
258258
.createTime(Instant.now())
259259
.lastUpdateTime(Instant.now())
260260
.state(MLTaskState.CREATED)
261-
.workerNodes(nodeIds)
261+
.workerNodes(eligibleNodeIds)
262262
.build();
263263
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
264264
String taskId = response.getId();
265265
mlTask.setTaskId(taskId);
266266
if (algorithm == FunctionName.REMOTE) {
267-
mlTaskManager.add(mlTask, nodeIds);
267+
mlTaskManager.add(mlTask, eligibleNodeIds);
268268
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
269269
return;
270270
}
271271
try {
272-
mlTaskManager.add(mlTask, nodeIds);
272+
mlTaskManager.add(mlTask, eligibleNodeIds);
273273
wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
274274
threadPool
275275
.executor(DEPLOY_THREAD_POOL)

plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java

+13-5
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,19 @@ private void deployModel(
222222
try {
223223
log.debug("start deploying model {}", modelId);
224224
mlModelManager
225-
.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(listener, () -> {
226-
if (!coordinatingNodeId.equals(localNodeId)) {
227-
mlTaskManager.remove(mlTask.getTaskId());
228-
}
229-
}));
225+
.deployModel(
226+
modelId,
227+
modelContentHash,
228+
functionName,
229+
deployToAllNodes,
230+
false,
231+
mlTask,
232+
ActionListener.runBefore(listener, () -> {
233+
if (!coordinatingNodeId.equals(localNodeId)) {
234+
mlTaskManager.remove(mlTask.getTaskId());
235+
}
236+
})
237+
);
230238
} catch (Exception e) {
231239
logException("Failed to deploy model " + modelId, e, log);
232240
listener.onFailure(e);

plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
import org.opensearch.action.ActionRequest;
2121
import org.opensearch.action.support.ActionFilters;
2222
import org.opensearch.action.support.HandledTransportAction;
23+
import org.opensearch.action.update.UpdateResponse;
2324
import org.opensearch.client.Client;
2425
import org.opensearch.cluster.node.DiscoveryNode;
2526
import org.opensearch.cluster.service.ClusterService;
2627
import org.opensearch.common.inject.Inject;
2728
import org.opensearch.common.settings.Settings;
2829
import org.opensearch.core.action.ActionListener;
30+
import org.opensearch.core.rest.RestStatus;
2931
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
3032
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
3133
import org.opensearch.ml.common.FunctionName;
@@ -163,7 +165,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
163165
updateFields.put(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, 0);
164166
}
165167
log.info("deploy model done with state: {}, model id: {}", modelState, modelId);
166-
mlModelManager.updateModel(modelId, updateFields);
168+
ActionListener updateModelListener = ActionListener.<UpdateResponse>wrap(response -> {
169+
if (response.status() == RestStatus.OK) {
170+
log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId);
171+
} else {
172+
log.error("Failed to update ML model {}, status: {}", modelId, response.status());
173+
}
174+
}, e -> { log.error("Failed to update ML model: " + modelId, e); });
175+
mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> {
176+
mlModelManager.removeAutoDeployModel(modelId);
177+
}));
167178
}
168179
listener.onResponse(new MLForwardResponse("ok", null));
169180
break;

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+12-2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55

66
package org.opensearch.ml.action.prediction;
77

8+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
9+
810
import org.opensearch.OpenSearchStatusException;
911
import org.opensearch.action.ActionRequest;
1012
import org.opensearch.action.support.ActionFilters;
1113
import org.opensearch.action.support.HandledTransportAction;
1214
import org.opensearch.client.Client;
1315
import org.opensearch.cluster.service.ClusterService;
1416
import org.opensearch.common.inject.Inject;
17+
import org.opensearch.common.settings.Settings;
1518
import org.opensearch.common.util.concurrent.ThreadContext;
1619
import org.opensearch.commons.authuser.User;
1720
import org.opensearch.core.action.ActionListener;
@@ -37,7 +40,7 @@
3740
import lombok.extern.log4j.Log4j2;
3841

3942
@Log4j2
40-
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
43+
@FieldDefaults(level = AccessLevel.PRIVATE)
4144
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
4245
MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
4346
TransportService transportService;
@@ -53,6 +56,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action
5356

5457
ModelAccessControlHelper modelAccessControlHelper;
5558

59+
private volatile boolean enableAutomaticDeployment;
60+
5661
@Inject
5762
public TransportPredictionTaskAction(
5863
TransportService transportService,
@@ -63,7 +68,8 @@ public TransportPredictionTaskAction(
6368
Client client,
6469
NamedXContentRegistry xContentRegistry,
6570
MLModelManager mlModelManager,
66-
ModelAccessControlHelper modelAccessControlHelper
71+
ModelAccessControlHelper modelAccessControlHelper,
72+
Settings settings
6773
) {
6874
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
6975
this.mlPredictTaskRunner = mlPredictTaskRunner;
@@ -74,6 +80,10 @@ public TransportPredictionTaskAction(
7480
this.xContentRegistry = xContentRegistry;
7581
this.mlModelManager = mlModelManager;
7682
this.modelAccessControlHelper = modelAccessControlHelper;
83+
enableAutomaticDeployment = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
84+
clusterService
85+
.getClusterSettings()
86+
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> enableAutomaticDeployment = it);
7787
}
7888

7989
@Override

plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,21 +74,21 @@ public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
7474
continue;
7575
}
7676
if (functionName == FunctionName.REMOTE) {// remote model
77-
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
77+
getEligibleNode(remoteModelEligibleNodeRoles, eligibleNodes, node);
7878
} else { // local model
7979
if (onlyRunOnMLNode) {
8080
if (MLNodeUtils.isMLNode(node)) {
8181
eligibleNodes.add(node);
8282
}
8383
} else {
84-
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
84+
getEligibleNode(localModelEligibleNodeRoles, eligibleNodes, node);
8585
}
8686
}
8787
}
8888
return eligibleNodes.toArray(new DiscoveryNode[0]);
8989
}
9090

91-
private void getEligibleNodes(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
91+
private void getEligibleNode(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
9292
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
9393
eligibleNodes.add(node);
9494
}

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

+37
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@
3232
@Log4j2
3333
public class MLModelCacheHelper {
3434
private final Map<String, MLModelCache> modelCaches;
35+
36+
private final Map<String, MLModel> autoDeployModels;
3537
private volatile Long maxRequestCount;
3638

3739
public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
3840
this.modelCaches = new ConcurrentHashMap<>();
41+
this.autoDeployModels = new ConcurrentHashMap<>();
3942

4043
maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings);
4144
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it);
@@ -67,6 +70,25 @@ public synchronized void initModelState(
6770
modelCaches.put(modelId, modelCache);
6871
}
6972

73+
public synchronized void initModelStateLocal(
74+
String modelId,
75+
MLModelState state,
76+
FunctionName functionName,
77+
List<String> targetWorkerNodes
78+
) {
79+
log.debug("init local model deployment state for model {}, state: {}", modelId, state);
80+
if (isModelRunningOnNode(modelId)) {
81+
// model state initialized
82+
return;
83+
}
84+
MLModelCache modelCache = new MLModelCache();
85+
modelCache.setModelState(state);
86+
modelCache.setFunctionName(functionName);
87+
modelCache.setTargetWorkerNodes(targetWorkerNodes);
88+
modelCache.setDeployToAllNodes(false);
89+
modelCaches.put(modelId, modelCache);
90+
}
91+
7092
/**
7193
* Set model state
7294
*
@@ -358,6 +380,7 @@ public void removeModel(String modelId) {
358380
modelCache.clear();
359381
modelCaches.remove(modelId);
360382
}
383+
autoDeployModels.remove(modelId);
361384
}
362385

363386
/**
@@ -590,4 +613,18 @@ private MLModelCache getOrCreateModelCache(String modelId) {
590613
return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache());
591614
}
592615

616+
public MLModel addModelToAutoDeployCache(String modelId, MLModel model) {
617+
MLModel addedModel = autoDeployModels.computeIfAbsent(modelId, key -> model);
618+
if (addedModel == model) {
619+
log.info("Add model {} to auto deploy cache", modelId);
620+
}
621+
return addedModel;
622+
}
623+
624+
public void removeAutoDeployModel(String modelId) {
625+
MLModel removedModel = autoDeployModels.remove(modelId);
626+
if (removedModel != null) {
627+
log.info("Remove model {} from auto deploy cache", modelId);
628+
}
629+
}
593630
}

0 commit comments

Comments
 (0)