Skip to content

Commit b19c415

Browse files
committed
Avoid race condition in sync-up model state refresh and handle null pointer exception (NPE) from IsAutoDeployEnabled
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 950f864 commit b19c415

File tree

3 files changed

+44
-44
lines changed

3 files changed

+44
-44
lines changed

plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java

+40-29
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1010
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
1111
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
12+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1213

1314
import java.time.Instant;
1415
import java.util.ArrayList;
@@ -41,8 +42,9 @@
4142
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
4243
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
4344
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
44-
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
45-
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
45+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
46+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
47+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
4648
import org.opensearch.ml.engine.encryptor.Encryptor;
4749
import org.opensearch.ml.engine.indices.MLIndicesHandler;
4850
import org.opensearch.search.SearchHit;
@@ -97,6 +99,9 @@ public void run() {
9799
// gather running model/tasks on nodes
98100
client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> {
99101
List<MLSyncUpNodeResponse> responses = r.getNodes();
102+
if (r.failures() != null && r.failures().size() != 0) {
103+
log.debug("Received {} failures in the sync up response on nodes", r.failures().size());
104+
}
100105
// key is model id, value is set of worker node ids
101106
Map<String, Set<String>> modelWorkerNodes = new HashMap<>();
102107
// key is task id, value is set of worker node ids
@@ -143,7 +148,6 @@ public void run() {
143148
if (modelWorkerNodes.containsKey(modelId)
144149
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
145150
// this model has expired in all the nodes
146-
modelWorkerNodes.remove(modelId);
147151
modelsToUndeploy.add(modelId);
148152
}
149153
}
@@ -168,37 +172,44 @@ public void run() {
168172
MLSyncUpInput syncUpInput = inputBuilder.build();
169173
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
170174
// sync up running model/tasks on nodes
171-
client
172-
.execute(
173-
MLSyncUpAction.INSTANCE,
174-
syncUpRequest,
175-
ActionListener.wrap(re -> { log.debug("sync model routing job finished"); }, ex -> {
176-
log.error("Failed to sync model routing", ex);
177-
})
178-
);
179-
// Undeploy expired models
180-
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);
175+
client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(re -> {
176+
log.debug("sync model routing job finished");
177+
if (!modelsToUndeploy.isEmpty()) {
178+
// Undeploy expired models
179+
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels);
180+
return;
181+
}
182+
// refresh model status
183+
mlIndicesHandler
184+
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
185+
log.error("Failed to init model index", e);
186+
}));
187+
}, ex -> { log.error("Failed to sync model routing", ex); }));
188+
}, e -> { log.error("Failed to sync model routing", e); }));
189+
}
190+
191+
private void undeployExpiredModels(
192+
Set<String> expiredModels,
193+
Map<String, Set<String>> modelWorkerNodes,
194+
Map<String, Set<String>> deployingModels
195+
) {
196+
String[] targetNodeIds = getAllNodes(clusterService);
197+
MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest(
198+
expiredModels.toArray(new String[expiredModels.size()]),
199+
targetNodeIds
200+
);
201+
202+
client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> {
203+
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse();
204+
if (mlUndeployModelNodesResponse.failures() != null && mlUndeployModelNodesResponse.failures().size() != 0) {
205+
log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures());
206+
}
181207

182-
// refresh model status
183208
mlIndicesHandler
184209
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
185210
log.error("Failed to init model index", e);
186211
}));
187-
}, e -> { log.error("Failed to sync model routing", e); }));
188-
}
189-
190-
private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
191-
expiredModels.forEach(modelId -> {
192-
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);
193-
194-
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
195-
targetNodeIds,
196-
new String[] { modelId }
197-
);
198-
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
199-
log.debug("model {} is un_deployed", modelId);
200-
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
201-
});
212+
}, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
202213
}
203214

204215
@VisibleForTesting

plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java

+3-14
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1010
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
1111
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
12+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1213

1314
import java.io.IOException;
14-
import java.util.ArrayList;
15-
import java.util.Iterator;
1615
import java.util.List;
1716
import java.util.Locale;
1817

1918
import org.apache.commons.lang3.ArrayUtils;
2019
import org.opensearch.client.node.NodeClient;
21-
import org.opensearch.cluster.node.DiscoveryNode;
2220
import org.opensearch.cluster.service.ClusterService;
2321
import org.opensearch.common.settings.Settings;
2422
import org.opensearch.core.xcontent.XContentParser;
@@ -102,24 +100,15 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException {
102100
}
103101
targetNodeIds = nodeIds;
104102
} else {
105-
targetNodeIds = getAllNodes();
103+
targetNodeIds = getAllNodes(clusterService);
106104
}
107105
if (ArrayUtils.isNotEmpty(modelIds)) {
108106
targetModelIds = modelIds;
109107
}
110108
} else {
111-
targetNodeIds = getAllNodes();
109+
targetNodeIds = getAllNodes(clusterService);
112110
}
113111

114112
return new MLUndeployModelsRequest(targetModelIds, targetNodeIds);
115113
}
116-
117-
private String[] getAllNodes() {
118-
Iterator<DiscoveryNode> iterator = clusterService.state().nodes().iterator();
119-
List<String> nodeIds = new ArrayList<>();
120-
while (iterator.hasNext()) {
121-
nodeIds.add(iterator.next().getId());
122-
}
123-
return nodeIds.toArray(new String[0]);
124-
}
125114
}

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
263263
}
264264

265265
private boolean checkModelAutoDeployEnabled(MLModel mlModel) {
266-
if (mlModel.getDeploySetting() == null) {
266+
if (mlModel.getDeploySetting() == null || mlModel.getDeploySetting().getIsAutoDeployEnabled() == null) {
267267
return true;
268268
}
269269
return mlModel.getDeploySetting().getIsAutoDeployEnabled();

0 commit comments

Comments
 (0)