|
18 | 18 | import static org.mockito.Mockito.verify;
|
19 | 19 | import static org.mockito.Mockito.when;
|
20 | 20 | import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
|
| 21 | +import static org.opensearch.ml.common.CommonValue.NOT_FOUND; |
21 | 22 | import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
|
22 | 23 |
|
23 | 24 | import java.io.IOException;
|
24 | 25 | import java.util.ArrayList;
|
| 26 | +import java.util.HashMap; |
25 | 27 | import java.util.List;
|
26 | 28 | import java.util.Map;
|
27 | 29 |
|
@@ -348,6 +350,63 @@ public void testHiddenModelSuccess() {
|
348 | 350 | verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
|
349 | 351 | }
|
350 | 352 |
|
| 353 | + public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() { |
| 354 | + MLModel mlModel = MLModel |
| 355 | + .builder() |
| 356 | + .user(User.parse(USER_STRING)) |
| 357 | + .modelGroupId("111") |
| 358 | + .version("111") |
| 359 | + .name(this.modelIds[0]) |
| 360 | + .modelId(this.modelIds[0]) |
| 361 | + .algorithm(FunctionName.BATCH_RCF) |
| 362 | + .content("content") |
| 363 | + .totalChunks(2) |
| 364 | + .isHidden(true) |
| 365 | + .build(); |
| 366 | + |
| 367 | + // Mock MLModel manager response |
| 368 | + doAnswer(invocation -> { |
| 369 | + ActionListener<MLModel> listener = invocation.getArgument(4); |
| 370 | + listener.onResponse(mlModel); |
| 371 | + return null; |
| 372 | + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); |
| 373 | + |
| 374 | + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); |
| 375 | + |
| 376 | + List<MLUndeployModelNodeResponse> responseList = new ArrayList<>(); |
| 377 | + |
| 378 | + for (String nodeId : this.nodeIds) { |
| 379 | + Map<String, String> stats = new HashMap<>(); |
| 380 | + stats.put(this.modelIds[0], NOT_FOUND); |
| 381 | + MLUndeployModelNodeResponse nodeResponse = mock(MLUndeployModelNodeResponse.class); |
| 382 | + when(nodeResponse.getModelUndeployStatus()).thenReturn(stats); |
| 383 | + responseList.add(nodeResponse); |
| 384 | + } |
| 385 | + |
| 386 | + List<FailedNodeException> failuresList = new ArrayList<>(); |
| 387 | + MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); |
| 388 | + |
| 389 | + doAnswer(invocation -> { |
| 390 | + ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2); |
| 391 | + listener.onResponse(nodesResponse); |
| 392 | + return null; |
| 393 | + }).when(client).execute(any(), any(), isA(ActionListener.class)); |
| 394 | + |
| 395 | + doAnswer(invocation -> { |
| 396 | + ActionListener<BulkResponse> listener = invocation.getArgument(1); |
| 397 | + listener.onResponse(mock(BulkResponse.class)); |
| 398 | + return null; |
| 399 | + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
| 400 | + |
| 401 | + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); |
| 402 | + |
| 403 | + transportUndeployModelsAction.doExecute(task, request, actionListener); |
| 404 | + |
| 405 | + // Verify that bulk request was fired because all nodes reported "not_found" |
| 406 | + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
| 407 | + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); |
| 408 | + } |
| 409 | + |
351 | 410 | public void testHiddenModelPermissionError() {
|
352 | 411 | MLModel mlModel = MLModel
|
353 | 412 | .builder()
|
|
0 commit comments