10
10
import static org .opensearch .ml .task .MLTaskManager .TASK_SEMAPHORE_TIMEOUT ;
11
11
import static org .opensearch .ml .utils .MLExceptionUtils .logException ;
12
12
import static org .opensearch .ml .utils .MLExceptionUtils .toJsonString ;
13
+ import static org .opensearch .ml .utils .RestActionUtils .getAllNodes ;
13
14
14
15
import java .time .Instant ;
15
16
import java .util .Arrays ;
16
17
import java .util .HashMap ;
18
+ import java .util .HashSet ;
19
+ import java .util .List ;
17
20
import java .util .Map ;
18
21
import java .util .Set ;
22
+ import java .util .stream .Collectors ;
19
23
20
24
import org .opensearch .action .ActionRequest ;
21
25
import org .opensearch .action .support .ActionFilters ;
@@ -131,10 +135,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
131
135
syncModelWorkerNodes (modelId , functionName );
132
136
}
133
137
134
- if (workNodes == null || workNodes .size () == 0 ) {
138
+ Set <String > workNodesRemovedFromCluster = new HashSet <>();
139
+
140
+ if (workNodes != null && !workNodes .isEmpty ()) {
141
+ Set <String > allNodesInCluster = new HashSet <>(List .of (getAllNodes (clusterService )));
142
+
143
+ workNodesRemovedFromCluster = workNodes
144
+ .stream ()
145
+ .filter (node -> !allNodesInCluster .contains (node ))
146
+ .collect (Collectors .toSet ());
147
+
148
+ if (!workNodesRemovedFromCluster .isEmpty ()) {
149
+ workNodes .removeAll (workNodesRemovedFromCluster );
150
+ }
151
+ }
152
+
153
+ if (workNodes == null || workNodes .isEmpty ()) {
154
+ if (!workNodesRemovedFromCluster .isEmpty ()) {
155
+ mlTaskCache .updateWorkerNode (workNodesRemovedFromCluster );
156
+ mlModelManager .removeModelWorkerNode (modelId , false , workNodesRemovedFromCluster .toArray (new String [0 ]));
157
+ }
135
158
int currentWorkerNodeCount = mlTaskCache .getWorkerNodeSize ();
136
159
MLTaskState taskState = mlTaskCache .hasError () ? MLTaskState .COMPLETED_WITH_ERROR : MLTaskState .COMPLETED ;
137
- if (mlTaskCache .allNodeFailed ()) {
160
+ if (mlTaskCache .allNodeFailed () || mlTaskCache . getWorkerNodeSize () == 0 ) {
138
161
taskState = MLTaskState .FAILED ;
139
162
currentWorkerNodeCount = 0 ;
140
163
} else {
@@ -150,11 +173,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
150
173
mlTaskManager .updateMLTask (taskId , builder .build (), TASK_SEMAPHORE_TIMEOUT , true );
151
174
152
175
MLModelState modelState ;
153
- if (!mlTaskCache .allNodeFailed ()) {
154
- modelState = mlTaskCache .hasError () ? MLModelState .PARTIALLY_DEPLOYED : MLModelState .DEPLOYED ;
155
- } else {
176
+ if (mlTaskCache .allNodeFailed () || mlTaskCache .getWorkerNodeSize () == 0 ) {
156
177
modelState = MLModelState .DEPLOY_FAILED ;
157
178
log .error ("deploy model failed on all nodes, model id: {}" , modelId );
179
+ } else {
180
+ modelState = mlTaskCache .hasError () ? MLModelState .PARTIALLY_DEPLOYED : MLModelState .DEPLOYED ;
158
181
}
159
182
Map <String , Object > updateFields = new HashMap <>();
160
183
updateFields .put (MLModel .MODEL_STATE_FIELD , modelState );
0 commit comments