9
9
import static org .opensearch .ml .common .CommonValue .MASTER_KEY ;
10
10
import static org .opensearch .ml .common .CommonValue .ML_CONFIG_INDEX ;
11
11
import static org .opensearch .ml .common .CommonValue .ML_MODEL_INDEX ;
12
+ import static org .opensearch .ml .utils .RestActionUtils .getAllNodes ;
12
13
13
14
import java .time .Instant ;
14
15
import java .util .ArrayList ;
41
42
import org .opensearch .ml .common .transport .sync .MLSyncUpInput ;
42
43
import org .opensearch .ml .common .transport .sync .MLSyncUpNodeResponse ;
43
44
import org .opensearch .ml .common .transport .sync .MLSyncUpNodesRequest ;
44
- import org .opensearch .ml .common .transport .undeploy .MLUndeployModelAction ;
45
- import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesRequest ;
45
+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesResponse ;
46
+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsAction ;
47
+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsRequest ;
46
48
import org .opensearch .ml .engine .encryptor .Encryptor ;
47
49
import org .opensearch .ml .engine .indices .MLIndicesHandler ;
48
50
import org .opensearch .search .SearchHit ;
@@ -97,6 +99,9 @@ public void run() {
97
99
// gather running model/tasks on nodes
98
100
client .execute (MLSyncUpAction .INSTANCE , gatherInfoRequest , ActionListener .wrap (r -> {
99
101
List <MLSyncUpNodeResponse > responses = r .getNodes ();
102
+ if (r .failures () != null && r .failures ().size () != 0 ) {
103
+ log .debug ("Received {} failures in the sync up response on nodes" , r .failures ().size ());
104
+ }
100
105
// key is model id, value is set of worker node ids
101
106
Map <String , Set <String >> modelWorkerNodes = new HashMap <>();
102
107
// key is task id, value is set of worker node ids
@@ -143,7 +148,6 @@ public void run() {
143
148
if (modelWorkerNodes .containsKey (modelId )
144
149
&& expiredModelToNodes .get (modelId ).size () == modelWorkerNodes .get (modelId ).size ()) {
145
150
// this model has expired in all the nodes
146
- modelWorkerNodes .remove (modelId );
147
151
modelsToUndeploy .add (modelId );
148
152
}
149
153
}
@@ -168,37 +172,44 @@ public void run() {
168
172
MLSyncUpInput syncUpInput = inputBuilder .build ();
169
173
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest (allNodes , syncUpInput );
170
174
// sync up running model/tasks on nodes
171
- client
172
- .execute (
173
- MLSyncUpAction .INSTANCE ,
174
- syncUpRequest ,
175
- ActionListener .wrap (re -> { log .debug ("sync model routing job finished" ); }, ex -> {
176
- log .error ("Failed to sync model routing" , ex );
177
- })
178
- );
179
- // Undeploy expired models
180
- undeployExpiredModels (modelsToUndeploy , modelWorkerNodes );
175
+ client .execute (MLSyncUpAction .INSTANCE , syncUpRequest , ActionListener .wrap (re -> {
176
+ log .debug ("sync model routing job finished" );
177
+ if (!modelsToUndeploy .isEmpty ()) {
178
+ // Undeploy expired models
179
+ undeployExpiredModels (modelsToUndeploy , modelWorkerNodes , deployingModels );
180
+ return ;
181
+ }
182
+ // refresh model status
183
+ mlIndicesHandler
184
+ .initModelIndexIfAbsent (ActionListener .wrap (res -> { refreshModelState (modelWorkerNodes , deployingModels ); }, e -> {
185
+ log .error ("Failed to init model index" , e );
186
+ }));
187
+ }, ex -> { log .error ("Failed to sync model routing" , ex ); }));
188
+ }, e -> { log .error ("Failed to sync model routing" , e ); }));
189
+ }
190
+
191
+ private void undeployExpiredModels (
192
+ Set <String > expiredModels ,
193
+ Map <String , Set <String >> modelWorkerNodes ,
194
+ Map <String , Set <String >> deployingModels
195
+ ) {
196
+ String [] targetNodeIds = getAllNodes (clusterService );
197
+ MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest (
198
+ expiredModels .toArray (new String [expiredModels .size ()]),
199
+ targetNodeIds
200
+ );
201
+
202
+ client .execute (MLUndeployModelsAction .INSTANCE , mlUndeployModelsRequest , ActionListener .wrap (r -> {
203
+ MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r .getResponse ();
204
+ if (mlUndeployModelNodesResponse .failures () != null && mlUndeployModelNodesResponse .failures ().size () != 0 ) {
205
+ log .debug ("Received failures in undeploying expired models" , mlUndeployModelNodesResponse .failures ());
206
+ }
181
207
182
- // refresh model status
183
208
mlIndicesHandler
184
209
.initModelIndexIfAbsent (ActionListener .wrap (res -> { refreshModelState (modelWorkerNodes , deployingModels ); }, e -> {
185
210
log .error ("Failed to init model index" , e );
186
211
}));
187
- }, e -> { log .error ("Failed to sync model routing" , e ); }));
188
- }
189
-
190
- private void undeployExpiredModels (Set <String > expiredModels , Map <String , Set <String >> modelWorkerNodes ) {
191
- expiredModels .forEach (modelId -> {
192
- String [] targetNodeIds = modelWorkerNodes .keySet ().toArray (new String [0 ]);
193
-
194
- MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest (
195
- targetNodeIds ,
196
- new String [] { modelId }
197
- );
198
- client .execute (MLUndeployModelAction .INSTANCE , mlUndeployModelNodesRequest , ActionListener .wrap (r -> {
199
- log .debug ("model {} is un_deployed" , modelId );
200
- }, e -> { log .error ("Failed to undeploy model {}" , modelId , e ); }));
201
- });
212
+ }, e -> { log .error ("Failed to undeploy models {}" , expiredModels , e ); }));
202
213
}
203
214
204
215
@ VisibleForTesting
0 commit comments