Skip to content

Commit e4bd41d

Browse files
opensearch-trigger-bot[bot]rbhavna
authored andcommitted
enable auto redeploy for hidden model (opensearch-project#2102) (opensearch-project#2136)
* enable auto redeploy for hidden model Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> (cherry picked from commit 9567ca5) Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 4f87254 commit e4bd41d

File tree

5 files changed

+44
-6
lines changed

5 files changed

+44
-6
lines changed

common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,32 @@ public class MLDeployModelRequest extends MLTaskRequest {
3838
private String modelId;
3939
private String[] modelNodeIds;
4040
boolean async;
41+
// This is to identify if the deploy request is initiated by user or not. During auto redeploy also, we perform deploy operation.
42+
// This field is mainly to distinguish between these two situations.
43+
private final boolean isUserInitiatedDeployRequest;
4144

4245
@Builder
43-
public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask) {
46+
public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask, boolean isUserInitiatedDeployRequest) {
4447
super(dispatchTask);
4548
this.modelId = modelId;
4649
this.modelNodeIds = modelNodeIds;
4750
this.async = async;
51+
this.isUserInitiatedDeployRequest = isUserInitiatedDeployRequest;
4852
}
4953

54+
// In this constructor, isUserInitiatedDeployRequest to always set to true. So, it can be used only when
55+
// deploy request is coming directly from the user. DO NOT use this when the
56+
// deploy call is from the code or system initiated.
5057
public MLDeployModelRequest(String modelId, boolean async) {
51-
this(modelId, null, async, true);
58+
this(modelId, null, async, true, true);
5259
}
5360

5461
public MLDeployModelRequest(StreamInput in) throws IOException {
5562
super(in);
5663
this.modelId = in.readString();
5764
this.modelNodeIds = in.readOptionalStringArray();
5865
this.async = in.readBoolean();
66+
this.isUserInitiatedDeployRequest = in.readBoolean();
5967
}
6068

6169
@Override
@@ -74,6 +82,7 @@ public void writeTo(StreamOutput out) throws IOException {
7482
out.writeString(modelId);
7583
out.writeOptionalStringArray(modelNodeIds);
7684
out.writeBoolean(async);
85+
out.writeBoolean(isUserInitiatedDeployRequest);
7786
}
7887

7988
public static MLDeployModelRequest parse(XContentParser parser, String modelId) throws IOException {
@@ -96,7 +105,7 @@ public static MLDeployModelRequest parse(XContentParser parser, String modelId)
96105
}
97106
}
98107
String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]);
99-
return new MLDeployModelRequest(modelId, nodeIds, false, true);
108+
return new MLDeployModelRequest(modelId, nodeIds, false, true, true);
100109
}
101110

102111
public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest) {

plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ public TransportDeployModelAction(
131131
protected void doExecute(Task task, ActionRequest request, ActionListener<MLDeployModelResponse> listener) {
132132
MLDeployModelRequest deployModelRequest = MLDeployModelRequest.fromActionRequest(request);
133133
String modelId = deployModelRequest.getModelId();
134+
Boolean isUserInitiatedDeployRequest = deployModelRequest.isUserInitiatedDeployRequest();
134135
User user = RestActionUtils.getUserContext(client);
135136
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
136137
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
@@ -143,7 +144,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
143144
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
144145
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
145146
}
146-
if (isHidden != null && isHidden) {
147+
if (!isUserInitiatedDeployRequest) {
148+
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
149+
} else if (isHidden != null && isHidden) {
147150
if (isSuperAdmin) {
148151
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
149152
} else {

plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeploy
308308
ImmutableMap.of(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, Optional.ofNullable(autoRedeployRetryTimes).orElse(0) + 1)
309309
);
310310

311-
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, nodeIds, false, true);
311+
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, nodeIds, false, true, false);
312312
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, listener);
313313
}
314314

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ private void updateModelRegisterStateAsDone(
860860
void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) {
861861
String[] modelNodeIds = registerModelInput.getModelNodeIds();
862862
log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds));
863-
MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true);
863+
MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true, true);
864864
ActionListener<MLDeployModelResponse> listener = ActionListener
865865
.wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e));
866866
client.execute(MLDeployModelAction.INSTANCE, request, listener);

plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java

+26
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ public void setup() {
173173
return null;
174174
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
175175

176+
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);
177+
176178
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
177179

178180
MLStat mlStat = mock(MLStat.class);
@@ -218,6 +220,30 @@ public void testDoExecute_success() {
218220
verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class));
219221
}
220222

223+
public void testDoExecute_success_not_userInitiatedRequest() {
224+
MLModel mlModel = mock(MLModel.class);
225+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
226+
doAnswer(invocation -> {
227+
ActionListener<MLModel> listener = invocation.getArgument(3);
228+
listener.onResponse(mlModel);
229+
return null;
230+
}).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));
231+
232+
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(false);
233+
234+
IndexResponse indexResponse = mock(IndexResponse.class);
235+
when(indexResponse.getId()).thenReturn("mockIndexId");
236+
doAnswer(invocation -> {
237+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
238+
listener.onResponse(indexResponse);
239+
return null;
240+
}).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class));
241+
242+
ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class);
243+
transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener);
244+
verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class));
245+
}
246+
221247
public void testDoExecute_success_hidden_model() {
222248
transportDeployModelAction = spy(
223249
new TransportDeployModelAction(

0 commit comments

Comments
 (0)