13
13
import static org .mockito .Mockito .doReturn ;
14
14
import static org .mockito .Mockito .doThrow ;
15
15
import static org .mockito .Mockito .mock ;
16
+ import static org .mockito .Mockito .never ;
16
17
import static org .mockito .Mockito .spy ;
17
18
import static org .mockito .Mockito .verify ;
18
19
import static org .mockito .Mockito .when ;
20
+ import static org .opensearch .ml .common .CommonValue .ML_MODEL_INDEX ;
19
21
import static org .opensearch .ml .task .MLPredictTaskRunnerTests .USER_STRING ;
20
22
21
23
import java .io .IOException ;
22
24
import java .util .ArrayList ;
23
25
import java .util .List ;
26
+ import java .util .Map ;
24
27
25
28
import org .junit .Before ;
26
29
import org .junit .Rule ;
29
32
import org .mockito .Mock ;
30
33
import org .mockito .MockitoAnnotations ;
31
34
import org .opensearch .action .FailedNodeException ;
35
+ import org .opensearch .action .bulk .BulkRequest ;
36
+ import org .opensearch .action .bulk .BulkResponse ;
32
37
import org .opensearch .action .support .ActionFilters ;
38
+ import org .opensearch .action .update .UpdateRequest ;
33
39
import org .opensearch .client .Client ;
34
40
import org .opensearch .cluster .ClusterName ;
35
41
import org .opensearch .cluster .service .ClusterService ;
42
48
import org .opensearch .ml .cluster .DiscoveryNodeHelper ;
43
49
import org .opensearch .ml .common .FunctionName ;
44
50
import org .opensearch .ml .common .MLModel ;
51
+ import org .opensearch .ml .common .model .MLModelState ;
45
52
import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodeResponse ;
46
53
import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesResponse ;
47
54
import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsRequest ;
@@ -172,6 +179,129 @@ public void setup() throws IOException {
172
179
}).when (mlModelManager ).getModel (any (), any (), any (), any (), isA (ActionListener .class ));
173
180
}
174
181
182
+ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel () {
183
+ String modelId = "someModelId" ;
184
+ MLModel mlModel = MLModel
185
+ .builder ()
186
+ .user (User .parse (USER_STRING ))
187
+ .modelGroupId ("111" )
188
+ .version ("111" )
189
+ .name ("Test Model" )
190
+ .modelId (modelId )
191
+ .algorithm (FunctionName .BATCH_RCF )
192
+ .content ("content" )
193
+ .totalChunks (2 )
194
+ .isHidden (true )
195
+ .build ();
196
+
197
+ doAnswer (invocation -> {
198
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
199
+ listener .onResponse (mlModel );
200
+ return null ;
201
+ }).when (mlModelManager ).getModel (any (), any (), any (), any (), isA (ActionListener .class ));
202
+
203
+ doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
204
+
205
+ List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
206
+ List <FailedNodeException > failuresList = new ArrayList <>();
207
+ MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
208
+
209
+ // Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
210
+ doAnswer (invocation -> {
211
+ ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
212
+ listener .onResponse (nodesResponse );
213
+ return null ;
214
+ }).when (client ).execute (any (), any (), isA (ActionListener .class ));
215
+
216
+ ArgumentCaptor <BulkRequest > bulkRequestCaptor = ArgumentCaptor .forClass (BulkRequest .class );
217
+
218
+ // mock the bulk response that can be captured for inspecting the contents of the write to index
219
+ doAnswer (invocation -> {
220
+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
221
+ BulkResponse bulkResponse = mock (BulkResponse .class );
222
+ when (bulkResponse .hasFailures ()).thenReturn (false );
223
+ listener .onResponse (bulkResponse );
224
+ return null ;
225
+ }).when (client ).bulk (bulkRequestCaptor .capture (), any (ActionListener .class ));
226
+
227
+ String [] modelIds = new String [] { modelId };
228
+ String [] nodeIds = new String [] { "test_node_id1" , "test_node_id2" };
229
+ MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
230
+
231
+ transportUndeployModelsAction .doExecute (task , request , actionListener );
232
+
233
+ BulkRequest capturedBulkRequest = bulkRequestCaptor .getValue ();
234
+ assertEquals (1 , capturedBulkRequest .numberOfActions ());
235
+ UpdateRequest updateRequest = (UpdateRequest ) capturedBulkRequest .requests ().get (0 );
236
+
237
+ @ SuppressWarnings ("unchecked" )
238
+ Map <String , Object > updateDoc = updateRequest .doc ().sourceAsMap ();
239
+ String modelIdFromBulkRequest = updateRequest .id ();
240
+ String indexNameFromBulkRequest = updateRequest .index ();
241
+
242
+ assertEquals ("Check that the write happened at the model index" , ML_MODEL_INDEX , indexNameFromBulkRequest );
243
+ assertEquals ("Check that the result bulk write hit this specific modelId" , modelId , modelIdFromBulkRequest );
244
+
245
+ assertEquals (MLModelState .UNDEPLOYED .name (), updateDoc .get (MLModel .MODEL_STATE_FIELD ));
246
+ assertEquals (0 , updateDoc .get (MLModel .CURRENT_WORKER_NODE_COUNT_FIELD ));
247
+ assertEquals (0 , updateDoc .get (MLModel .PLANNING_WORKER_NODE_COUNT_FIELD ));
248
+ assertEquals (List .of (), updateDoc .get (MLModel .PLANNING_WORKER_NODES_FIELD ));
249
+ assertTrue (updateDoc .containsKey (MLModel .LAST_UPDATED_TIME_FIELD ));
250
+
251
+ verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
252
+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
253
+ }
254
+
255
+ public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel () {
256
+ String modelId = "someModelId" ;
257
+ MLModel mlModel = MLModel
258
+ .builder ()
259
+ .user (User .parse (USER_STRING ))
260
+ .modelGroupId ("111" )
261
+ .version ("111" )
262
+ .name ("Test Model" )
263
+ .modelId (modelId )
264
+ .algorithm (FunctionName .BATCH_RCF )
265
+ .content ("content" )
266
+ .totalChunks (2 )
267
+ .isHidden (true )
268
+ .build ();
269
+
270
+ doAnswer (invocation -> {
271
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
272
+ listener .onResponse (mlModel );
273
+ return null ;
274
+ }).when (mlModelManager ).getModel (any (), any (), any (), any (), isA (ActionListener .class ));
275
+
276
+ doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
277
+
278
+ List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
279
+ responseList .add (mock (MLUndeployModelNodeResponse .class ));
280
+ responseList .add (mock (MLUndeployModelNodeResponse .class ));
281
+ List <FailedNodeException > failuresList = new ArrayList <>();
282
+ failuresList .add (mock (FailedNodeException .class ));
283
+ failuresList .add (mock (FailedNodeException .class ));
284
+
285
+ MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
286
+
287
+ // Send back a response with nodes associated to the model
288
+ doAnswer (invocation -> {
289
+ ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
290
+ listener .onResponse (nodesResponse );
291
+ return null ;
292
+ }).when (client ).execute (any (), any (), isA (ActionListener .class ));
293
+
294
+ String [] modelIds = new String [] { modelId };
295
+ String [] nodeIds = new String [] { "test_node_id1" , "test_node_id2" };
296
+ MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
297
+
298
+ transportUndeployModelsAction .doExecute (task , request , actionListener );
299
+
300
+ verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
301
+ // Check that no bulk write occurred Since there were nodes servicing the model
302
+ verify (client , never ()).bulk (any (BulkRequest .class ), any (ActionListener .class ));
303
+ }
304
+
175
305
public void testHiddenModelSuccess () {
176
306
MLModel mlModel = MLModel
177
307
.builder ()
@@ -194,16 +324,28 @@ public void testHiddenModelSuccess() {
194
324
List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
195
325
List <FailedNodeException > failuresList = new ArrayList <>();
196
326
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
327
+
197
328
doAnswer (invocation -> {
198
329
ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
199
330
listener .onResponse (response );
200
331
return null ;
201
332
}).when (client ).execute (any (), any (), isA (ActionListener .class ));
202
333
334
+ // Mock the client.bulk call
335
+ doAnswer (invocation -> {
336
+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
337
+ BulkResponse bulkResponse = mock (BulkResponse .class );
338
+ when (bulkResponse .hasFailures ()).thenReturn (false );
339
+ listener .onResponse (bulkResponse );
340
+ return null ;
341
+ }).when (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
342
+
203
343
doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
204
344
MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
205
345
transportUndeployModelsAction .doExecute (task , request , actionListener );
346
+
206
347
verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
348
+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
207
349
}
208
350
209
351
public void testHiddenModelPermissionError () {
@@ -257,9 +399,19 @@ public void testDoExecute() {
257
399
listener .onResponse (response );
258
400
return null ;
259
401
}).when (client ).execute (any (), any (), isA (ActionListener .class ));
402
+ // Mock the client.bulk call
403
+ doAnswer (invocation -> {
404
+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
405
+ BulkResponse bulkResponse = mock (BulkResponse .class );
406
+ when (bulkResponse .hasFailures ()).thenReturn (false );
407
+ listener .onResponse (bulkResponse );
408
+ return null ;
409
+ }).when (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
410
+
260
411
MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
261
412
transportUndeployModelsAction .doExecute (task , request , actionListener );
262
413
verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
414
+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
263
415
}
264
416
265
417
public void testDoExecute_modelAccessControl_notEnabled () {
0 commit comments