Skip to content

Commit 45c920f

Browse files
committed
remove nodes that left the cluster from worker nodes
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 2f54de1 commit 45c920f

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java

+23-1
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1111
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
1212
import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;
13+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1314

1415
import java.time.Instant;
1516
import java.util.Arrays;
1617
import java.util.HashMap;
18+
import java.util.HashSet;
19+
import java.util.List;
1720
import java.util.Map;
1821
import java.util.Set;
22+
import java.util.stream.Collectors;
1923

2024
import org.opensearch.action.ActionRequest;
2125
import org.opensearch.action.support.ActionFilters;
@@ -131,7 +135,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
131135
syncModelWorkerNodes(modelId, functionName);
132136
}
133137

134-
if (workNodes == null || workNodes.size() == 0) {
138+
Set<String> workNodesRemovedFromCluster = new HashSet<>();
139+
140+
if (workNodes != null && !workNodes.isEmpty()) {
141+
Set<String> allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService)));
142+
143+
workNodesRemovedFromCluster = workNodes.stream()
144+
.filter(node -> !allNodesInCluster.contains(node))
145+
.collect(Collectors.toSet());
146+
147+
if (!workNodesRemovedFromCluster.isEmpty()) {
148+
workNodes.removeAll(workNodesRemovedFromCluster);
149+
}
150+
}
151+
152+
if (workNodes == null || workNodes.isEmpty()) {
153+
if (!workNodesRemovedFromCluster.isEmpty()) {
154+
mlTaskCache.updateWorkerNodeCount(workNodesRemovedFromCluster);
155+
mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0]));
156+
}
135157
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
136158
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
137159
if (mlTaskCache.allNodeFailed()) {

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+14
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL;
1616
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
1717
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
18+
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
1819

1920
import java.time.Instant;
2021
import java.util.Arrays;
2122
import java.util.HashMap;
23+
import java.util.HashSet;
24+
import java.util.List;
2225
import java.util.Map;
26+
import java.util.Set;
2327
import java.util.UUID;
28+
import java.util.stream.Collectors;
2429

2530
import org.opensearch.OpenSearchException;
2631
import org.opensearch.OpenSearchStatusException;
@@ -158,6 +163,15 @@ public void dispatchTask(
158163
}
159164
}, listener::onFailure);
160165
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
166+
167+
if (workerNodes != null && workerNodes.length > 0) {
168+
String[] allNodesInCluster = getAllNodes(clusterService);
169+
170+
workerNodes = Arrays.stream(workerNodes)
171+
.filter(node -> Arrays.asList(allNodesInCluster).contains(node))
172+
.toArray(String[]::new);
173+
}
174+
161175
if (workerNodes == null || workerNodes.length == 0) {
162176
if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
163177
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {

plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java

+5
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,9 @@ public int errorNodesCount() {
6262
public boolean allNodeFailed() {
6363
return workerNodeSize != null && errors.size() == workerNodeSize;
6464
}
65+
66+
public void updateWorkerNodeCount(Set<String> nodesRemovedFromCluster) {
67+
this.workerNodes.removeAll(nodesRemovedFromCluster);
68+
this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size();
69+
}
6570
}

0 commit comments

Comments
 (0)