Skip to content

Commit 8090c3d

Browse files
authored
Merge branch 'opensearch-project:main' into main_interfacefix
2 parents dfb375a + c5ceb48 commit 8090c3d

File tree

6 files changed

+78
-5
lines changed

6 files changed

+78
-5
lines changed

client/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies {
1818
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
1919
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
2020
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
21-
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
21+
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
2222

2323
}
2424

common/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies {
2020
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
2121
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
2222
compileOnly "org.opensearch:common-utils:${common_utils_version}"
23-
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
23+
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
2424
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2525

2626
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'

memory/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ dependencies {
3535
exclude module : 'hamcrest'
3636
exclude module : 'hamcrest-core'
3737
}
38-
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
38+
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
3939
testImplementation "org.opensearch.test:framework:${opensearch_version}"
4040
testImplementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
4141
testImplementation group: 'com.google.code.gson', name: 'gson', version: '2.11.0'

ml-algorithms/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ dependencies {
4040
implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0'
4141
implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
4242
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
43-
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
43+
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
4444
implementation (group: 'com.google.guava', name: 'guava', version: '32.1.3-jre') {
4545
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
4646
}

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
package org.opensearch.ml.action.undeploy;
77

88
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
9+
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
910

1011
import java.time.Instant;
1112
import java.util.Arrays;
1213
import java.util.List;
14+
import java.util.Map;
1315
import java.util.stream.Collectors;
1416

1517
import org.opensearch.ExceptionsHelper;
@@ -198,7 +200,19 @@ private void undeployModels(
198200
* Having this change enables a check that this edge case occurs along with having access to the model id
199201
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
200202
*/
201-
if (response.getNodes().isEmpty()) {
203+
boolean modelNotFoundInNodesCache = response.getNodes().stream().allMatch(nodeResponse -> {
204+
Map<String, String> status = nodeResponse.getModelUndeployStatus();
205+
if (status == null)
206+
return false;
207+
// Stream is used to catch all models edge case but only one is ever undeployed
208+
boolean modelCacheMissForModelIds = Arrays.stream(modelIds).allMatch(modelId -> {
209+
String modelStatus = status.get(modelId);
210+
return modelStatus != null && modelStatus.equalsIgnoreCase(NOT_FOUND);
211+
});
212+
213+
return modelCacheMissForModelIds;
214+
});
215+
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
202216
bulkSetModelIndexToUndeploy(modelIds, listener, response);
203217
return;
204218
}

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

+59
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
import static org.mockito.Mockito.verify;
1919
import static org.mockito.Mockito.when;
2020
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
21+
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
2122
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
2223

2324
import java.io.IOException;
2425
import java.util.ArrayList;
26+
import java.util.HashMap;
2527
import java.util.List;
2628
import java.util.Map;
2729

@@ -348,6 +350,63 @@ public void testHiddenModelSuccess() {
348350
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
349351
}
350352

353+
public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
354+
MLModel mlModel = MLModel
355+
.builder()
356+
.user(User.parse(USER_STRING))
357+
.modelGroupId("111")
358+
.version("111")
359+
.name(this.modelIds[0])
360+
.modelId(this.modelIds[0])
361+
.algorithm(FunctionName.BATCH_RCF)
362+
.content("content")
363+
.totalChunks(2)
364+
.isHidden(true)
365+
.build();
366+
367+
// Mock MLModel manager response
368+
doAnswer(invocation -> {
369+
ActionListener<MLModel> listener = invocation.getArgument(4);
370+
listener.onResponse(mlModel);
371+
return null;
372+
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
373+
374+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
375+
376+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
377+
378+
for (String nodeId : this.nodeIds) {
379+
Map<String, String> stats = new HashMap<>();
380+
stats.put(this.modelIds[0], NOT_FOUND);
381+
MLUndeployModelNodeResponse nodeResponse = mock(MLUndeployModelNodeResponse.class);
382+
when(nodeResponse.getModelUndeployStatus()).thenReturn(stats);
383+
responseList.add(nodeResponse);
384+
}
385+
386+
List<FailedNodeException> failuresList = new ArrayList<>();
387+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
388+
389+
doAnswer(invocation -> {
390+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
391+
listener.onResponse(nodesResponse);
392+
return null;
393+
}).when(client).execute(any(), any(), isA(ActionListener.class));
394+
395+
doAnswer(invocation -> {
396+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
397+
listener.onResponse(mock(BulkResponse.class));
398+
return null;
399+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
400+
401+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
402+
403+
transportUndeployModelsAction.doExecute(task, request, actionListener);
404+
405+
// Verify that bulk request was fired because all nodes reported "not_found"
406+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
407+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
408+
}
409+
351410
public void testHiddenModelPermissionError() {
352411
MLModel mlModel = MLModel
353412
.builder()

0 commit comments

Comments
 (0)