Skip to content

Commit 21bf079

Browse files
authored
avoid race condition in syncup model state refresh and handle NP of IsAutoDeployEnabled (opensearch-project#2405)
* avoid race condition in syncup model state refresh and handle NP of IsAutoDeployEnabled
  Signed-off-by: Xun Zhang <xunzh@amazon.com>
* log the error message from syncUp response
  Signed-off-by: Xun Zhang <xunzh@amazon.com>
* include the syncup response error messages as a string to help debug
  Signed-off-by: Xun Zhang <xunzh@amazon.com>
---------
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 950f864 commit 21bf079

File tree

3 files changed

+49
-44
lines changed

3 files changed

+49
-44
lines changed

plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java

+45-29
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1010
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
1111
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
12+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1213

1314
import java.time.Instant;
1415
import java.util.ArrayList;
@@ -41,8 +42,9 @@
4142
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
4243
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
4344
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
44-
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
45-
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
45+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
46+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
47+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
4648
import org.opensearch.ml.engine.encryptor.Encryptor;
4749
import org.opensearch.ml.engine.indices.MLIndicesHandler;
4850
import org.opensearch.search.SearchHit;
@@ -97,6 +99,14 @@ public void run() {
9799
// gather running model/tasks on nodes
98100
client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> {
99101
List<MLSyncUpNodeResponse> responses = r.getNodes();
102+
if (r.failures() != null && r.failures().size() != 0) {
103+
log
104+
.debug(
105+
"Received {} failures in the sync up response on nodes. Error messages are {}",
106+
r.failures().size(),
107+
r.failures().stream().map(Exception::getMessage).collect(Collectors.joining(", "))
108+
);
109+
}
100110
// key is model id, value is set of worker node ids
101111
Map<String, Set<String>> modelWorkerNodes = new HashMap<>();
102112
// key is task id, value is set of worker node ids
@@ -143,7 +153,6 @@ public void run() {
143153
if (modelWorkerNodes.containsKey(modelId)
144154
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
145155
// this model has expired in all the nodes
146-
modelWorkerNodes.remove(modelId);
147156
modelsToUndeploy.add(modelId);
148157
}
149158
}
@@ -168,37 +177,44 @@ public void run() {
168177
MLSyncUpInput syncUpInput = inputBuilder.build();
169178
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
170179
// sync up running model/tasks on nodes
171-
client
172-
.execute(
173-
MLSyncUpAction.INSTANCE,
174-
syncUpRequest,
175-
ActionListener.wrap(re -> { log.debug("sync model routing job finished"); }, ex -> {
176-
log.error("Failed to sync model routing", ex);
177-
})
178-
);
179-
// Undeploy expired models
180-
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);
180+
client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(re -> {
181+
log.debug("sync model routing job finished");
182+
if (!modelsToUndeploy.isEmpty()) {
183+
// Undeploy expired models
184+
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels);
185+
return;
186+
}
187+
// refresh model status
188+
mlIndicesHandler
189+
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
190+
log.error("Failed to init model index", e);
191+
}));
192+
}, ex -> { log.error("Failed to sync model routing", ex); }));
193+
}, e -> { log.error("Failed to sync model routing", e); }));
194+
}
195+
196+
private void undeployExpiredModels(
197+
Set<String> expiredModels,
198+
Map<String, Set<String>> modelWorkerNodes,
199+
Map<String, Set<String>> deployingModels
200+
) {
201+
String[] targetNodeIds = getAllNodes(clusterService);
202+
MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest(
203+
expiredModels.toArray(new String[expiredModels.size()]),
204+
targetNodeIds
205+
);
206+
207+
client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> {
208+
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse();
209+
if (mlUndeployModelNodesResponse.failures() != null && mlUndeployModelNodesResponse.failures().size() != 0) {
210+
log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures());
211+
}
181212

182-
// refresh model status
183213
mlIndicesHandler
184214
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
185215
log.error("Failed to init model index", e);
186216
}));
187-
}, e -> { log.error("Failed to sync model routing", e); }));
188-
}
189-
190-
private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
191-
expiredModels.forEach(modelId -> {
192-
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);
193-
194-
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
195-
targetNodeIds,
196-
new String[] { modelId }
197-
);
198-
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
199-
log.debug("model {} is un_deployed", modelId);
200-
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
201-
});
217+
}, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
202218
}
203219

204220
@VisibleForTesting

plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java

+3-14
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1010
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
1111
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
12+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1213

1314
import java.io.IOException;
14-
import java.util.ArrayList;
15-
import java.util.Iterator;
1615
import java.util.List;
1716
import java.util.Locale;
1817

1918
import org.apache.commons.lang3.ArrayUtils;
2019
import org.opensearch.client.node.NodeClient;
21-
import org.opensearch.cluster.node.DiscoveryNode;
2220
import org.opensearch.cluster.service.ClusterService;
2321
import org.opensearch.common.settings.Settings;
2422
import org.opensearch.core.xcontent.XContentParser;
@@ -102,24 +100,15 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException {
102100
}
103101
targetNodeIds = nodeIds;
104102
} else {
105-
targetNodeIds = getAllNodes();
103+
targetNodeIds = getAllNodes(clusterService);
106104
}
107105
if (ArrayUtils.isNotEmpty(modelIds)) {
108106
targetModelIds = modelIds;
109107
}
110108
} else {
111-
targetNodeIds = getAllNodes();
109+
targetNodeIds = getAllNodes(clusterService);
112110
}
113111

114112
return new MLUndeployModelsRequest(targetModelIds, targetNodeIds);
115113
}
116-
117-
private String[] getAllNodes() {
118-
Iterator<DiscoveryNode> iterator = clusterService.state().nodes().iterator();
119-
List<String> nodeIds = new ArrayList<>();
120-
while (iterator.hasNext()) {
121-
nodeIds.add(iterator.next().getId());
122-
}
123-
return nodeIds.toArray(new String[0]);
124-
}
125114
}

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
263263
}
264264

265265
private boolean checkModelAutoDeployEnabled(MLModel mlModel) {
266-
if (mlModel.getDeploySetting() == null) {
266+
if (mlModel.getDeploySetting() == null || mlModel.getDeploySetting().getIsAutoDeployEnabled() == null) {
267267
return true;
268268
}
269269
return mlModel.getDeploySetting().getIsAutoDeployEnabled();

0 commit comments

Comments
 (0)