|
10 | 10 | import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
|
11 | 11 | import static org.opensearch.ml.utils.MLExceptionUtils.logException;
|
12 | 12 | import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;
|
| 13 | +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; |
13 | 14 |
|
14 | 15 | import java.time.Instant;
|
15 | 16 | import java.util.Arrays;
|
16 | 17 | import java.util.HashMap;
|
| 18 | +import java.util.HashSet; |
| 19 | +import java.util.List; |
17 | 20 | import java.util.Map;
|
18 | 21 | import java.util.Set;
|
| 22 | +import java.util.stream.Collectors; |
19 | 23 |
|
20 | 24 | import org.opensearch.action.ActionRequest;
|
21 | 25 | import org.opensearch.action.support.ActionFilters;
|
@@ -131,7 +135,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
|
131 | 135 | syncModelWorkerNodes(modelId, functionName);
|
132 | 136 | }
|
133 | 137 |
|
134 |
| - if (workNodes == null || workNodes.size() == 0) { |
| 138 | + Set<String> workNodesRemovedFromCluster = new HashSet<>(); |
| 139 | + |
| 140 | + if (workNodes != null && !workNodes.isEmpty()) { |
| 141 | + Set<String> allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService))); |
| 142 | + |
| 143 | + workNodesRemovedFromCluster = workNodes.stream() |
| 144 | + .filter(node -> !allNodesInCluster.contains(node)) |
| 145 | + .collect(Collectors.toSet()); |
| 146 | + |
| 147 | + if (!workNodesRemovedFromCluster.isEmpty()) { |
| 148 | + workNodes.removeAll(workNodesRemovedFromCluster); |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + if (workNodes == null || workNodes.isEmpty()) { |
| 153 | + if (!workNodesRemovedFromCluster.isEmpty()) { |
| 154 | + mlTaskCache.updateWorkerNodeCount(workNodesRemovedFromCluster); |
| 155 | + mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0])); |
| 156 | + } |
135 | 157 | int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
|
136 | 158 | MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
|
137 | 159 | if (mlTaskCache.allNodeFailed()) {
|
|
0 commit comments