Skip to content

Commit ba587fd

Browse files
committed
move logic after user access check
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent e6c6bfc commit ba587fd

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java

+6-15
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,12 @@ public DeleteModelTransportAction(
115115
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
116116
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request);
117117
String modelId = mlModelDeleteRequest.getModelId();
118-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
119-
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
120-
checkDownstreamTaskBeforeDeleteModel(modelId, wrappedListener);
121-
} catch (Exception e) {
122-
log.error(e.getMessage(), e);
123-
actionListener.onFailure(e);
124-
}
125-
126-
}
127-
128-
private void doDeleteModel(String modelId, ActionListener<DeleteResponse> actionListener) {
129118
MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false, false);
130119
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent());
131120
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext);
132121
User user = RestActionUtils.getUserContext(client);
133122
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
123+
134124
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
135125
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
136126
client.get(getRequest, ActionListener.wrap(r -> {
@@ -156,7 +146,7 @@ private void doDeleteModel(String modelId, ActionListener<DeleteResponse> action
156146
);
157147
} else {
158148
if (isModelNotDeployed(mlModelState)) {
159-
deleteModel(modelId, isHidden, actionListener);
149+
checkDownstreamTaskBeforeDeleteModel(modelId, isHidden, actionListener);
160150
} else {
161151
wrappedListener
162152
.onFailure(
@@ -179,7 +169,8 @@ private void doDeleteModel(String modelId, ActionListener<DeleteResponse> action
179169
)
180170
);
181171
} else if (isModelNotDeployed(mlModelState)) {
182-
deleteModel(modelId, isHidden, actionListener);
172+
checkDownstreamTaskBeforeDeleteModel(modelId, isHidden, actionListener);
173+
;
183174
} else {
184175
wrappedListener
185176
.onFailure(
@@ -364,7 +355,7 @@ private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener
364355

365356
}
366357

367-
private void checkDownstreamTaskBeforeDeleteModel(String modelId, ActionListener<DeleteResponse> actionListener) {
358+
private void checkDownstreamTaskBeforeDeleteModel(String modelId, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
368359
CountDownLatch countDownLatch = new CountDownLatch(3);
369360
AtomicBoolean noneBlocked = new AtomicBoolean(true);
370361
List<String> errorMessages = new ArrayList<>();
@@ -373,7 +364,7 @@ private void checkDownstreamTaskBeforeDeleteModel(String modelId, ActionListener
373364
noneBlocked.compareAndSet(true, b);
374365
if (countDownLatch.getCount() == 0) {
375366
if (noneBlocked.get()) {
376-
doDeleteModel(modelId, actionListener);
367+
deleteModel(modelId, isHidden, actionListener);
377368
} else {
378369
actionListener.onFailure(new OpenSearchStatusException(String.join(",", errorMessages), RestStatus.CONFLICT));
379370
}

plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ private GetResponse buildResponse(MLModel mlModel) throws IOException {
675675
return getResponse;
676676
}
677677

678-
private void prepare() {
678+
private void prepare() throws IOException {
679679
emptyBulkByScrollResponse = new BulkByScrollResponse(new ArrayList<>(), null);
680680
SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f);
681681
when(searchResponse.getHits()).thenReturn(hits);
@@ -708,5 +708,12 @@ private void prepare() {
708708
configDataMap = Map
709709
.of("model_id", "test_id", "list_model_id", List.of("test_list_id"), "test_map_id", Map.of("test_key", "test_map_id"));
710710
doAnswer(invocation -> new SearchRequest()).when(agentModelsSearcher).constructQueryRequest(any());
711+
712+
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);
713+
doAnswer(invocation -> {
714+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
715+
actionListener.onResponse(getResponse);
716+
return null;
717+
}).when(client).get(any(), any());
711718
}
712719
}

0 commit comments

Comments
 (0)