Skip to content

Commit 431c31b

Browse files
Undeploy models with no WorkerNodes (opensearch-project#3380) (opensearch-project#3447)
* undeploy models with no WorkerNodes This commit aims to undeploy modelIds that have no nodes associated to them so as to keep the intention of undeploy truthful. Signed-off-by: Brian Flores <iflorbri@amazon.com> # Conflicts: # plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java * Exit early when no nodes service the model Now when entering this method its guaranteed to write to index first before sending back the MLUndeploy response. And will also send back a exception if the write back fails Signed-off-by: Brian Flores <iflorbri@amazon.com> * add UTs for undeploy stale model index fix Added UTs for the 2 scenarios 1. Check that the bulk operation occured when no nodes are returned from the Undeploy response is , 2. Check that the bulk operation did not occur when there are nodes that have found the model within their cache. Signed-off-by: Brian Flores <iflorbri@amazon.com> * update code change with comment explaining the change Signed-off-by: Brian Flores <iflorbri@amazon.com> * add context stash/restore to write operation Signed-off-by: Brian Flores <iflorbri@amazon.com> * Apply spotless Signed-off-by: Brian Flores <iflorbri@amazon.com> * Add better logging to write request Signed-off-by: Brian Flores <iflorbri@amazon.com> * wrap exception into 5xx Signed-off-by: Brian Flores <iflorbri@amazon.com> * adapts undeploy code change to multi-tenancy feature Signed-off-by: Brian Flores <iflorbri@amazon.com> * applies spotless Signed-off-by: Brian Flores <iflorbri@amazon.com> --------- Signed-off-by: Brian Flores <iflorbri@amazon.com> (cherry picked from commit 18bcaae) Co-authored-by: Brian Flores <iflorbri@amazon.com>
1 parent aadc422 commit 431c31b

File tree

2 files changed

+226
-2
lines changed

2 files changed

+226
-2
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

+74-2
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,22 @@
77

88
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
99

10+
import java.time.Instant;
1011
import java.util.Arrays;
1112
import java.util.List;
1213
import java.util.stream.Collectors;
1314

1415
import org.opensearch.ExceptionsHelper;
1516
import org.opensearch.OpenSearchStatusException;
1617
import org.opensearch.action.ActionRequest;
18+
import org.opensearch.action.bulk.BulkRequest;
19+
import org.opensearch.action.bulk.BulkResponse;
1720
import org.opensearch.action.search.SearchRequest;
1821
import org.opensearch.action.search.SearchResponse;
1922
import org.opensearch.action.support.ActionFilters;
2023
import org.opensearch.action.support.HandledTransportAction;
24+
import org.opensearch.action.support.WriteRequest;
25+
import org.opensearch.action.update.UpdateRequest;
2126
import org.opensearch.client.Client;
2227
import org.opensearch.cluster.service.ClusterService;
2328
import org.opensearch.common.inject.Inject;
@@ -33,9 +38,11 @@
3338
import org.opensearch.index.query.TermsQueryBuilder;
3439
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
3540
import org.opensearch.ml.common.MLModel;
41+
import org.opensearch.ml.common.model.MLModelState;
3642
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
3743
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
3844
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
45+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
3946
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
4047
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
4148
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
@@ -57,6 +64,7 @@
5764
import org.opensearch.transport.TransportService;
5865

5966
import com.google.common.annotations.VisibleForTesting;
67+
import com.google.common.collect.ImmutableMap;
6068

6169
import lombok.extern.log4j.Log4j2;
6270

@@ -179,11 +187,75 @@ private void undeployModels(
179187
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);
180188
mlUndeployModelNodesRequest.setTenantId(tenantId);
181189

182-
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
183-
listener.onResponse(new MLUndeployModelsResponse(r));
190+
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(response -> {
191+
/*
192+
* The method TransportUndeployModelsAction.processUndeployModelResponseAndUpdate(...) performs
193+
* undeploy action of models by removing the models from the nodes cache and updating the index when it's able to find it.
194+
*
195+
* The problem becomes when the models index is incorrect and no node(s) are servicing the model. This results in
196+
* `{}` responses (on undeploy action), with no update to the model index thus, causing incorrect model state status.
197+
*
198+
* Having this change enables a check that this edge case occurs along with having access to the model id
199+
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
200+
*/
201+
if (response.getNodes().isEmpty()) {
202+
bulkSetModelIndexToUndeploy(modelIds, listener, response);
203+
return;
204+
}
205+
listener.onResponse(new MLUndeployModelsResponse(response));
184206
}, listener::onFailure));
185207
}
186208

209+
private void bulkSetModelIndexToUndeploy(
210+
String[] modelIds,
211+
ActionListener<MLUndeployModelsResponse> listener,
212+
MLUndeployModelNodesResponse response
213+
) {
214+
BulkRequest bulkUpdateRequest = new BulkRequest();
215+
for (String modelId : modelIds) {
216+
UpdateRequest updateRequest = new UpdateRequest();
217+
218+
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
219+
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
220+
221+
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
222+
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
223+
224+
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
225+
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
226+
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
227+
bulkUpdateRequest.add(updateRequest);
228+
}
229+
230+
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
231+
log.info("No nodes running these models: {}", Arrays.toString(modelIds));
232+
233+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
234+
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
235+
.runBefore(listener, () -> threadContext.restore());
236+
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
237+
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
238+
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
239+
}, e -> {
240+
String modelsNotFoundMessage = String
241+
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
242+
log.error(modelsNotFoundMessage, e);
243+
244+
OpenSearchStatusException exception = new OpenSearchStatusException(
245+
modelsNotFoundMessage + e.getMessage(),
246+
RestStatus.INTERNAL_SERVER_ERROR
247+
);
248+
listenerWithContextRestoration.onFailure(exception);
249+
});
250+
251+
client.bulk(bulkUpdateRequest, bulkResponseListener);
252+
} catch (Exception e) {
253+
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
254+
listener.onFailure(e);
255+
}
256+
257+
}
258+
187259
private void validateAccess(String modelId, String tenantId, ActionListener<Boolean> listener) {
188260
User user = RestActionUtils.getUserContext(client);
189261
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

+152
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
import static org.mockito.Mockito.doReturn;
1414
import static org.mockito.Mockito.doThrow;
1515
import static org.mockito.Mockito.mock;
16+
import static org.mockito.Mockito.never;
1617
import static org.mockito.Mockito.spy;
1718
import static org.mockito.Mockito.verify;
1819
import static org.mockito.Mockito.when;
20+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
1921
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
2022

2123
import java.io.IOException;
2224
import java.util.ArrayList;
2325
import java.util.List;
26+
import java.util.Map;
2427

2528
import org.junit.Before;
2629
import org.junit.Rule;
@@ -29,7 +32,10 @@
2932
import org.mockito.Mock;
3033
import org.mockito.MockitoAnnotations;
3134
import org.opensearch.action.FailedNodeException;
35+
import org.opensearch.action.bulk.BulkRequest;
36+
import org.opensearch.action.bulk.BulkResponse;
3237
import org.opensearch.action.support.ActionFilters;
38+
import org.opensearch.action.update.UpdateRequest;
3339
import org.opensearch.client.Client;
3440
import org.opensearch.cluster.ClusterName;
3541
import org.opensearch.cluster.service.ClusterService;
@@ -42,6 +48,7 @@
4248
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
4349
import org.opensearch.ml.common.FunctionName;
4450
import org.opensearch.ml.common.MLModel;
51+
import org.opensearch.ml.common.model.MLModelState;
4552
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
4653
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
4754
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
@@ -172,6 +179,129 @@ public void setup() throws IOException {
172179
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
173180
}
174181

182+
public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
183+
String modelId = "someModelId";
184+
MLModel mlModel = MLModel
185+
.builder()
186+
.user(User.parse(USER_STRING))
187+
.modelGroupId("111")
188+
.version("111")
189+
.name("Test Model")
190+
.modelId(modelId)
191+
.algorithm(FunctionName.BATCH_RCF)
192+
.content("content")
193+
.totalChunks(2)
194+
.isHidden(true)
195+
.build();
196+
197+
doAnswer(invocation -> {
198+
ActionListener<MLModel> listener = invocation.getArgument(4);
199+
listener.onResponse(mlModel);
200+
return null;
201+
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
202+
203+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
204+
205+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
206+
List<FailedNodeException> failuresList = new ArrayList<>();
207+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
208+
209+
// Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
210+
doAnswer(invocation -> {
211+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
212+
listener.onResponse(nodesResponse);
213+
return null;
214+
}).when(client).execute(any(), any(), isA(ActionListener.class));
215+
216+
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
217+
218+
// mock the bulk response that can be captured for inspecting the contents of the write to index
219+
doAnswer(invocation -> {
220+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
221+
BulkResponse bulkResponse = mock(BulkResponse.class);
222+
when(bulkResponse.hasFailures()).thenReturn(false);
223+
listener.onResponse(bulkResponse);
224+
return null;
225+
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
226+
227+
String[] modelIds = new String[] { modelId };
228+
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
229+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
230+
231+
transportUndeployModelsAction.doExecute(task, request, actionListener);
232+
233+
BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue();
234+
assertEquals(1, capturedBulkRequest.numberOfActions());
235+
UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0);
236+
237+
@SuppressWarnings("unchecked")
238+
Map<String, Object> updateDoc = updateRequest.doc().sourceAsMap();
239+
String modelIdFromBulkRequest = updateRequest.id();
240+
String indexNameFromBulkRequest = updateRequest.index();
241+
242+
assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest);
243+
assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest);
244+
245+
assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD));
246+
assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
247+
assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
248+
assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD));
249+
assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));
250+
251+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
252+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
253+
}
254+
255+
public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
256+
String modelId = "someModelId";
257+
MLModel mlModel = MLModel
258+
.builder()
259+
.user(User.parse(USER_STRING))
260+
.modelGroupId("111")
261+
.version("111")
262+
.name("Test Model")
263+
.modelId(modelId)
264+
.algorithm(FunctionName.BATCH_RCF)
265+
.content("content")
266+
.totalChunks(2)
267+
.isHidden(true)
268+
.build();
269+
270+
doAnswer(invocation -> {
271+
ActionListener<MLModel> listener = invocation.getArgument(4);
272+
listener.onResponse(mlModel);
273+
return null;
274+
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
275+
276+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
277+
278+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
279+
responseList.add(mock(MLUndeployModelNodeResponse.class));
280+
responseList.add(mock(MLUndeployModelNodeResponse.class));
281+
List<FailedNodeException> failuresList = new ArrayList<>();
282+
failuresList.add(mock(FailedNodeException.class));
283+
failuresList.add(mock(FailedNodeException.class));
284+
285+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
286+
287+
// Send back a response with nodes associated to the model
288+
doAnswer(invocation -> {
289+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
290+
listener.onResponse(nodesResponse);
291+
return null;
292+
}).when(client).execute(any(), any(), isA(ActionListener.class));
293+
294+
String[] modelIds = new String[] { modelId };
295+
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
296+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
297+
298+
transportUndeployModelsAction.doExecute(task, request, actionListener);
299+
300+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
301+
// Check that no bulk write occurred Since there were nodes servicing the model
302+
verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
303+
}
304+
175305
public void testHiddenModelSuccess() {
176306
MLModel mlModel = MLModel
177307
.builder()
@@ -194,16 +324,28 @@ public void testHiddenModelSuccess() {
194324
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
195325
List<FailedNodeException> failuresList = new ArrayList<>();
196326
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
327+
197328
doAnswer(invocation -> {
198329
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
199330
listener.onResponse(response);
200331
return null;
201332
}).when(client).execute(any(), any(), isA(ActionListener.class));
202333

334+
// Mock the client.bulk call
335+
doAnswer(invocation -> {
336+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
337+
BulkResponse bulkResponse = mock(BulkResponse.class);
338+
when(bulkResponse.hasFailures()).thenReturn(false);
339+
listener.onResponse(bulkResponse);
340+
return null;
341+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
342+
203343
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
204344
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
205345
transportUndeployModelsAction.doExecute(task, request, actionListener);
346+
206347
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
348+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
207349
}
208350

209351
public void testHiddenModelPermissionError() {
@@ -257,9 +399,19 @@ public void testDoExecute() {
257399
listener.onResponse(response);
258400
return null;
259401
}).when(client).execute(any(), any(), isA(ActionListener.class));
402+
// Mock the client.bulk call
403+
doAnswer(invocation -> {
404+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
405+
BulkResponse bulkResponse = mock(BulkResponse.class);
406+
when(bulkResponse.hasFailures()).thenReturn(false);
407+
listener.onResponse(bulkResponse);
408+
return null;
409+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
410+
260411
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
261412
transportUndeployModelsAction.doExecute(task, request, actionListener);
262413
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
414+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
263415
}
264416

265417
public void testDoExecute_modelAccessControl_notEnabled() {

0 commit comments

Comments
 (0)