Skip to content

Commit c5ceb48

Browse files
authored
add edge case for models that are marked as not found in cache (opensearch-project#3523)
A code change makes it necessary to inspect the model undeploy response object to verify that the model has been marked as not found on all nodes. Signed-off-by: Brian Flores <iflorbri@amazon.com>
1 parent 1b8b014 commit c5ceb48

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
package org.opensearch.ml.action.undeploy;
77

88
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
9+
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
910

1011
import java.time.Instant;
1112
import java.util.Arrays;
1213
import java.util.List;
14+
import java.util.Map;
1315
import java.util.stream.Collectors;
1416

1517
import org.opensearch.ExceptionsHelper;
@@ -198,7 +200,19 @@ private void undeployModels(
198200
* Having this change enables a check that this edge case occurs along with having access to the model id
199201
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
200202
*/
201-
if (response.getNodes().isEmpty()) {
203+
boolean modelNotFoundInNodesCache = response.getNodes().stream().allMatch(nodeResponse -> {
204+
Map<String, String> status = nodeResponse.getModelUndeployStatus();
205+
if (status == null)
206+
return false;
207+
// Stream is used to catch all models edge case but only one is ever undeployed
208+
boolean modelCacheMissForModelIds = Arrays.stream(modelIds).allMatch(modelId -> {
209+
String modelStatus = status.get(modelId);
210+
return modelStatus != null && modelStatus.equalsIgnoreCase(NOT_FOUND);
211+
});
212+
213+
return modelCacheMissForModelIds;
214+
});
215+
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
202216
bulkSetModelIndexToUndeploy(modelIds, listener, response);
203217
return;
204218
}

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

+59
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
import static org.mockito.Mockito.verify;
1919
import static org.mockito.Mockito.when;
2020
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
21+
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
2122
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
2223

2324
import java.io.IOException;
2425
import java.util.ArrayList;
26+
import java.util.HashMap;
2527
import java.util.List;
2628
import java.util.Map;
2729

@@ -348,6 +350,63 @@ public void testHiddenModelSuccess() {
348350
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
349351
}
350352

353+
/**
 * Checks that doExecute issues a bulk request and answers the caller's listener when every
 * node response maps the requested model id to NOT_FOUND.
 *
 * The model manager, the super-admin check, the node-level undeploy call and the bulk call
 * are all stubbed, so only the action's handling of the node responses is exercised.
 */
public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
354+
MLModel mlModel = MLModel
355+
.builder()
356+
.user(User.parse(USER_STRING))
357+
.modelGroupId("111")
358+
.version("111")
359+
.name(this.modelIds[0])
360+
.modelId(this.modelIds[0])
361+
.algorithm(FunctionName.BATCH_RCF)
362+
.content("content")
363+
.totalChunks(2)
364+
.isHidden(true)
365+
.build();
366+
367+
// Mock MLModel manager response
// The listener is the fifth argument of getModel (index 4); it is answered with the hidden model built above.
368+
doAnswer(invocation -> {
369+
ActionListener<MLModel> listener = invocation.getArgument(4);
370+
listener.onResponse(mlModel);
371+
return null;
372+
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
373+
374+
// Stub the super-admin check to return true for this test.
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
375+
376+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
377+
378+
// One mocked node response per node id; each status map reports NOT_FOUND for modelIds[0].
for (String nodeId : this.nodeIds) {
379+
Map<String, String> stats = new HashMap<>();
380+
stats.put(this.modelIds[0], NOT_FOUND);
381+
MLUndeployModelNodeResponse nodeResponse = mock(MLUndeployModelNodeResponse.class);
382+
when(nodeResponse.getModelUndeployStatus()).thenReturn(stats);
383+
responseList.add(nodeResponse);
384+
}
385+
386+
// No node failures: the aggregated response carries only the mocked node responses.
List<FailedNodeException> failuresList = new ArrayList<>();
387+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
388+
389+
// Stub client.execute to hand the aggregated nodes response to its listener (third argument, index 2).
doAnswer(invocation -> {
390+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
391+
listener.onResponse(nodesResponse);
392+
return null;
393+
}).when(client).execute(any(), any(), isA(ActionListener.class));
394+
395+
// Stub client.bulk to answer its listener (second argument, index 1) with a mocked BulkResponse.
doAnswer(invocation -> {
396+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
397+
listener.onResponse(mock(BulkResponse.class));
398+
return null;
399+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
400+
401+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
402+
403+
transportUndeployModelsAction.doExecute(task, request, actionListener);
404+
405+
// Verify that bulk request was fired because all nodes reported "not_found"
406+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
407+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
408+
}
409+
351410
public void testHiddenModelPermissionError() {
352411
MLModel mlModel = MLModel
353412
.builder()

0 commit comments

Comments
 (0)