Skip to content

Commit 74ed99b

Browse files
fix model stuck in deploying state during node crash/cluster restart (opensearch-project#3137) (opensearch-project#3146)
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> (cherry picked from commit bb6339f) Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent caeacf7 commit 74ed99b

File tree

8 files changed

+46
-11
lines changed

8 files changed

+46
-11
lines changed

plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java

+28-5
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1111
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
1212
import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;
13+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1314

1415
import java.time.Instant;
1516
import java.util.Arrays;
1617
import java.util.HashMap;
18+
import java.util.HashSet;
19+
import java.util.List;
1720
import java.util.Map;
1821
import java.util.Set;
22+
import java.util.stream.Collectors;
1923

2024
import org.opensearch.action.ActionRequest;
2125
import org.opensearch.action.support.ActionFilters;
@@ -131,10 +135,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
131135
syncModelWorkerNodes(modelId, functionName);
132136
}
133137

134-
if (workNodes == null || workNodes.size() == 0) {
138+
Set<String> workNodesRemovedFromCluster = new HashSet<>();
139+
140+
if (workNodes != null && !workNodes.isEmpty()) {
141+
Set<String> allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService)));
142+
143+
workNodesRemovedFromCluster = workNodes
144+
.stream()
145+
.filter(node -> !allNodesInCluster.contains(node))
146+
.collect(Collectors.toSet());
147+
148+
if (!workNodesRemovedFromCluster.isEmpty()) {
149+
workNodes.removeAll(workNodesRemovedFromCluster);
150+
}
151+
}
152+
153+
if (workNodes == null || workNodes.isEmpty()) {
154+
if (!workNodesRemovedFromCluster.isEmpty()) {
155+
mlTaskCache.updateWorkerNode(workNodesRemovedFromCluster);
156+
mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0]));
157+
}
135158
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
136159
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
137-
if (mlTaskCache.allNodeFailed()) {
160+
if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) {
138161
taskState = MLTaskState.FAILED;
139162
currentWorkerNodeCount = 0;
140163
} else {
@@ -150,11 +173,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
150173
mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);
151174

152175
MLModelState modelState;
153-
if (!mlTaskCache.allNodeFailed()) {
154-
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
155-
} else {
176+
if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) {
156177
modelState = MLModelState.DEPLOY_FAILED;
157178
log.error("deploy model failed on all nodes, model id: {}", modelId);
179+
} else {
180+
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
158181
}
159182
Map<String, Object> updateFields = new HashMap<>();
160183
updateFields.put(MLModel.MODEL_STATE_FIELD, modelState);

plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java

+5
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,9 @@ public int errorNodesCount() {
6262
public boolean allNodeFailed() {
6363
return workerNodeSize != null && errors.size() == workerNodeSize;
6464
}
65+
66+
public void updateWorkerNode(Set<String> nodesRemovedFromCluster) {
67+
this.workerNodes.removeAll(nodesRemovedFromCluster);
68+
this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size();
69+
}
6570
}

plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java

+7
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
3030
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;
3131
import static org.opensearch.ml.utils.TestHelper.clusterSetting;
32+
import static org.opensearch.ml.utils.TestHelper.setupTestClusterState;
3233

3334
import java.util.Arrays;
3435
import java.util.HashSet;
@@ -43,6 +44,7 @@
4344
import org.opensearch.Version;
4445
import org.opensearch.action.support.ActionFilters;
4546
import org.opensearch.client.Client;
47+
import org.opensearch.cluster.ClusterState;
4648
import org.opensearch.cluster.node.DiscoveryNode;
4749
import org.opensearch.cluster.service.ClusterService;
4850
import org.opensearch.common.settings.ClusterSettings;
@@ -94,6 +96,8 @@ public class TransportForwardActionTests extends OpenSearchTestCase {
9496

9597
private TransportForwardAction forwardAction;
9698

99+
private ClusterState testState;
100+
97101
Settings settings = Settings
98102
.builder()
99103
.put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true)
@@ -137,6 +141,9 @@ public void setup() {
137141
)
138142
);
139143

144+
testState = setupTestClusterState("test_node_id2");
145+
when(clusterService.state()).thenReturn(testState);
146+
140147
node1 = new DiscoveryNode(nodeId1, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
141148
node2 = new DiscoveryNode(nodeId2, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
142149

plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public void setup() throws IOException {
118118
encryptor = spy(new EncryptorImpl(null));
119119
syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor);
120120

121-
testState = setupTestClusterState();
121+
testState = setupTestClusterState("node");
122122
when(clusterService.state()).thenReturn(testState);
123123

124124
doAnswer(invocation -> {

plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ public void setup() throws IOException {
151151
.build();
152152

153153
clusterName = new ClusterName("test cluster");
154-
testState = setupTestClusterState();
154+
testState = setupTestClusterState("node");
155155
when(clusterService.state()).thenReturn(testState);
156156

157157
doAnswer(invocation -> {

plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public void setup() throws IOException {
147147
roleSet,
148148
Version.CURRENT
149149
);
150-
testState = setupTestClusterState();
150+
testState = setupTestClusterState("node");
151151
when(clusterService.state()).thenReturn(testState);
152152

153153
clusterName = new ClusterName(clusterNameStr);

plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public class RestMLUndeployModelActionTests extends OpenSearchTestCase {
6767
@Before
6868
public void setup() {
6969
MockitoAnnotations.openMocks(this);
70-
testState = setupTestClusterState();
70+
testState = setupTestClusterState("node");
7171
when(clusterService.state()).thenReturn(testState);
7272
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
7373
restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);

plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -461,11 +461,11 @@ public static ClusterState state(int numDataNodes, String indexName, String mapp
461461
return state(new ClusterName("test"), indexName, mapping, clusterManagerNode, clusterManagerNode, allNodes);
462462
}
463463

464-
public static ClusterState setupTestClusterState() {
464+
public static ClusterState setupTestClusterState(String nodeId) {
465465
Set<DiscoveryNodeRole> roleSet = new HashSet<>();
466466
roleSet.add(DiscoveryNodeRole.DATA_ROLE);
467467
DiscoveryNode node = new DiscoveryNode(
468-
"node",
468+
nodeId,
469469
new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()),
470470
new HashMap<>(),
471471
roleSet,

0 commit comments

Comments
 (0)