9
9
import static org .opensearch .ml .common .CommonValue .MASTER_KEY ;
10
10
import static org .opensearch .ml .common .CommonValue .ML_CONFIG_INDEX ;
11
11
import static org .opensearch .ml .common .CommonValue .ML_MODEL_INDEX ;
12
+ import static org .opensearch .ml .utils .RestActionUtils .getAllNodes ;
12
13
13
14
import java .time .Instant ;
14
15
import java .util .ArrayList ;
41
42
import org .opensearch .ml .common .transport .sync .MLSyncUpInput ;
42
43
import org .opensearch .ml .common .transport .sync .MLSyncUpNodeResponse ;
43
44
import org .opensearch .ml .common .transport .sync .MLSyncUpNodesRequest ;
44
- import org .opensearch .ml .common .transport .undeploy .MLUndeployModelAction ;
45
- import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesRequest ;
45
+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesResponse ;
46
+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsAction ;
47
+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsRequest ;
46
48
import org .opensearch .ml .engine .encryptor .Encryptor ;
47
49
import org .opensearch .ml .engine .indices .MLIndicesHandler ;
48
50
import org .opensearch .search .SearchHit ;
@@ -97,6 +99,14 @@ public void run() {
97
99
// gather running model/tasks on nodes
98
100
client .execute (MLSyncUpAction .INSTANCE , gatherInfoRequest , ActionListener .wrap (r -> {
99
101
List <MLSyncUpNodeResponse > responses = r .getNodes ();
102
+ if (r .failures () != null && r .failures ().size () != 0 ) {
103
+ log
104
+ .debug (
105
+ "Received {} failures in the sync up response on nodes. Error messages are {}" ,
106
+ r .failures ().size (),
107
+ r .failures ().stream ().map (Exception ::getMessage ).collect (Collectors .joining (", " ))
108
+ );
109
+ }
100
110
// key is model id, value is set of worker node ids
101
111
Map <String , Set <String >> modelWorkerNodes = new HashMap <>();
102
112
// key is task id, value is set of worker node ids
@@ -143,7 +153,6 @@ public void run() {
143
153
if (modelWorkerNodes .containsKey (modelId )
144
154
&& expiredModelToNodes .get (modelId ).size () == modelWorkerNodes .get (modelId ).size ()) {
145
155
// this model has expired in all the nodes
146
- modelWorkerNodes .remove (modelId );
147
156
modelsToUndeploy .add (modelId );
148
157
}
149
158
}
@@ -168,37 +177,44 @@ public void run() {
168
177
MLSyncUpInput syncUpInput = inputBuilder .build ();
169
178
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest (allNodes , syncUpInput );
170
179
// sync up running model/tasks on nodes
171
- client
172
- .execute (
173
- MLSyncUpAction .INSTANCE ,
174
- syncUpRequest ,
175
- ActionListener .wrap (re -> { log .debug ("sync model routing job finished" ); }, ex -> {
176
- log .error ("Failed to sync model routing" , ex );
177
- })
178
- );
179
- // Undeploy expired models
180
- undeployExpiredModels (modelsToUndeploy , modelWorkerNodes );
180
+ client .execute (MLSyncUpAction .INSTANCE , syncUpRequest , ActionListener .wrap (re -> {
181
+ log .debug ("sync model routing job finished" );
182
+ if (!modelsToUndeploy .isEmpty ()) {
183
+ // Undeploy expired models
184
+ undeployExpiredModels (modelsToUndeploy , modelWorkerNodes , deployingModels );
185
+ return ;
186
+ }
187
+ // refresh model status
188
+ mlIndicesHandler
189
+ .initModelIndexIfAbsent (ActionListener .wrap (res -> { refreshModelState (modelWorkerNodes , deployingModels ); }, e -> {
190
+ log .error ("Failed to init model index" , e );
191
+ }));
192
+ }, ex -> { log .error ("Failed to sync model routing" , ex ); }));
193
+ }, e -> { log .error ("Failed to sync model routing" , e ); }));
194
+ }
195
+
196
+ private void undeployExpiredModels (
197
+ Set <String > expiredModels ,
198
+ Map <String , Set <String >> modelWorkerNodes ,
199
+ Map <String , Set <String >> deployingModels
200
+ ) {
201
+ String [] targetNodeIds = getAllNodes (clusterService );
202
+ MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest (
203
+ expiredModels .toArray (new String [expiredModels .size ()]),
204
+ targetNodeIds
205
+ );
206
+
207
+ client .execute (MLUndeployModelsAction .INSTANCE , mlUndeployModelsRequest , ActionListener .wrap (r -> {
208
+ MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r .getResponse ();
209
+ if (mlUndeployModelNodesResponse .failures () != null && mlUndeployModelNodesResponse .failures ().size () != 0 ) {
210
+ log .debug ("Received failures in undeploying expired models" , mlUndeployModelNodesResponse .failures ());
211
+ }
181
212
182
- // refresh model status
183
213
mlIndicesHandler
184
214
.initModelIndexIfAbsent (ActionListener .wrap (res -> { refreshModelState (modelWorkerNodes , deployingModels ); }, e -> {
185
215
log .error ("Failed to init model index" , e );
186
216
}));
187
- }, e -> { log .error ("Failed to sync model routing" , e ); }));
188
- }
189
-
190
- private void undeployExpiredModels (Set <String > expiredModels , Map <String , Set <String >> modelWorkerNodes ) {
191
- expiredModels .forEach (modelId -> {
192
- String [] targetNodeIds = modelWorkerNodes .keySet ().toArray (new String [0 ]);
193
-
194
- MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest (
195
- targetNodeIds ,
196
- new String [] { modelId }
197
- );
198
- client .execute (MLUndeployModelAction .INSTANCE , mlUndeployModelNodesRequest , ActionListener .wrap (r -> {
199
- log .debug ("model {} is un_deployed" , modelId );
200
- }, e -> { log .error ("Failed to undeploy model {}" , modelId , e ); }));
201
- });
217
+ }, e -> { log .error ("Failed to undeploy models {}" , expiredModels , e ); }));
202
218
}
203
219
204
220
@ VisibleForTesting
0 commit comments