Skip to content

Commit 1af1246

Browse files
ylwu-amznZhangxunmt
authored andcommitted
add feature flags for remote inference (opensearch-project#1223)
Signed-off-by: Xun Zhang <xunzh@amazon.com> Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent 07aabd8 commit 1af1246

12 files changed

+189
-23
lines changed

plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
1212
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
1313
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
14+
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1415

1516
import java.time.Instant;
1617
import java.util.ArrayList;
@@ -52,6 +53,7 @@
5253
import org.opensearch.ml.engine.ModelHelper;
5354
import org.opensearch.ml.helper.ModelAccessControlHelper;
5455
import org.opensearch.ml.model.MLModelManager;
56+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
5557
import org.opensearch.ml.stats.MLNodeLevelStat;
5658
import org.opensearch.ml.stats.MLStats;
5759
import org.opensearch.ml.task.MLTaskDispatcher;
@@ -83,6 +85,7 @@ public class TransportDeployModelAction extends HandledTransportAction<ActionReq
8385

8486
private volatile boolean allowCustomDeploymentPlan;
8587
private ModelAccessControlHelper modelAccessControlHelper;
88+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
8689

8790
@Inject
8891
public TransportDeployModelAction(
@@ -99,7 +102,8 @@ public TransportDeployModelAction(
99102
MLModelManager mlModelManager,
100103
MLStats mlStats,
101104
Settings settings,
102-
ModelAccessControlHelper modelAccessControlHelper
105+
ModelAccessControlHelper modelAccessControlHelper,
106+
MLFeatureEnabledSetting mlFeatureEnabledSetting
103107
) {
104108
super(MLDeployModelAction.NAME, transportService, actionFilters, MLDeployModelRequest::new);
105109
this.transportService = transportService;
@@ -114,6 +118,7 @@ public TransportDeployModelAction(
114118
this.mlModelManager = mlModelManager;
115119
this.mlStats = mlStats;
116120
this.modelAccessControlHelper = modelAccessControlHelper;
121+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
117122
allowCustomDeploymentPlan = ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
118123
clusterService
119124
.getClusterSettings()
@@ -130,6 +135,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
130135
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
131136
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
132137
FunctionName functionName = mlModel.getAlgorithm();
138+
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
139+
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
140+
}
133141
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
134142
if (!access) {
135143
listener

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+15-4
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
import org.opensearch.ml.rest.RestMLUpdateModelGroupAction;
153153
import org.opensearch.ml.rest.RestMLUploadModelChunkAction;
154154
import org.opensearch.ml.settings.MLCommonsSettings;
155+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
155156
import org.opensearch.ml.stats.MLClusterLevelStat;
156157
import org.opensearch.ml.stats.MLNodeLevelStat;
157158
import org.opensearch.ml.stats.MLStat;
@@ -221,6 +222,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
221222

222223
private ConnectorAccessControlHelper connectorAccessControlHelper;
223224

225+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
226+
224227
@Override
225228
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
226229
return ImmutableList
@@ -330,6 +333,8 @@ public Collection<Object> createComponents(
330333
mlInputDatasetHandler = new MLInputDatasetHandler(client);
331334
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
332335
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
336+
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
337+
333338
mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper);
334339

335340
MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper);
@@ -436,6 +441,7 @@ public Collection<Object> createComponents(
436441
mlExecuteTaskRunner,
437442
modelAccessControlHelper,
438443
connectorAccessControlHelper,
444+
mlFeatureEnabledSetting,
439445
mlSearchHandler,
440446
mlTaskDispatcher,
441447
mlModelChunkUploader,
@@ -460,7 +466,7 @@ public List<RestHandler> getRestHandlers(
460466
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils, xContentRegistry);
461467
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
462468
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
463-
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager);
469+
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting);
464470
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction();
465471
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction();
466472
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction();
@@ -469,7 +475,11 @@ public List<RestHandler> getRestHandlers(
469475
RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction();
470476
RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction();
471477
RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService);
472-
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings);
478+
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(
479+
clusterService,
480+
settings,
481+
mlFeatureEnabledSetting
482+
);
473483
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
474484
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
475485
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
@@ -478,7 +488,7 @@ public List<RestHandler> getRestHandlers(
478488
RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction();
479489
RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
480490
RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction();
481-
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction();
491+
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting);
482492
RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction();
483493
RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction();
484494
RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction();
@@ -613,7 +623,8 @@ public List<Setting<?>> getSettings() {
613623
MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED,
614624
MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
615625
MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES,
616-
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES
626+
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
627+
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED
617628
);
618629
return settings;
619630
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
10+
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1011

1112
import java.io.IOException;
1213
import java.util.List;
@@ -17,6 +18,7 @@
1718
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
1819
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
1920
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
21+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
2022
import org.opensearch.rest.BaseRestHandler;
2123
import org.opensearch.rest.RestRequest;
2224
import org.opensearch.rest.action.RestToXContentListener;
@@ -26,11 +28,15 @@
2628

2729
public class RestMLCreateConnectorAction extends BaseRestHandler {
2830
private static final String ML_CREATE_CONNECTOR_ACTION = "ml_create_connector_action";
31+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
2932

3033
/**
31-
* Constructor *
34+
* Constructor
35+
* @param mlFeatureEnabledSetting
3236
*/
33-
public RestMLCreateConnectorAction() {}
37+
public RestMLCreateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
38+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
39+
}
3440

3541
@Override
3642
public String getName() {
@@ -56,6 +62,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
5662
*/
5763
@VisibleForTesting
5864
MLCreateConnectorRequest getRequest(RestRequest request) throws IOException {
65+
if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
66+
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
67+
}
68+
if (!request.hasContent()) {
69+
throw new IOException("Create Connector request has empty body");
70+
}
5971
XContentParser parser = request.contentParser();
6072
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
6173
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.parse(parser);

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
10+
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1011
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
1112
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
1213
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
@@ -29,6 +30,7 @@
2930
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3031
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
3132
import org.opensearch.ml.model.MLModelManager;
33+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
3234
import org.opensearch.rest.BaseRestHandler;
3335
import org.opensearch.rest.BytesRestResponse;
3436
import org.opensearch.rest.RestRequest;
@@ -45,11 +47,14 @@ public class RestMLPredictionAction extends BaseRestHandler {
4547

4648
private MLModelManager modelManager;
4749

50+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
51+
4852
/**
4953
* Constructor
5054
*/
51-
public RestMLPredictionAction(MLModelManager modelManager) {
55+
public RestMLPredictionAction(MLModelManager modelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
5256
this.modelManager = modelManager;
57+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
5358
}
5459

5560
@Override
@@ -117,6 +122,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
117122
*/
118123
@VisibleForTesting
119124
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
125+
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
126+
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
127+
}
120128
XContentParser parser = request.contentParser();
121129
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
122130
MLInput mlInput = MLInput.parse(parser, algorithm);

plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java

+12-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1010
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
11+
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1112
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
1213
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
1314
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_VERSION;
@@ -20,9 +21,11 @@
2021
import org.opensearch.cluster.service.ClusterService;
2122
import org.opensearch.common.settings.Settings;
2223
import org.opensearch.core.xcontent.XContentParser;
24+
import org.opensearch.ml.common.FunctionName;
2325
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
2426
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
2527
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
28+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
2629
import org.opensearch.rest.BaseRestHandler;
2730
import org.opensearch.rest.RestRequest;
2831
import org.opensearch.rest.action.RestToXContentListener;
@@ -33,20 +36,24 @@
3336
public class RestMLRegisterModelAction extends BaseRestHandler {
3437
private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action";
3538
private volatile boolean isModelUrlAllowed;
39+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
3640

3741
/**
3842
* Constructor
3943
*/
40-
public RestMLRegisterModelAction() {}
44+
public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
45+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
46+
}
4147

4248
/**
4349
* Constructor
4450
* @param clusterService cluster service
4551
* @param settings settings
4652
*/
47-
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings) {
53+
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
4854
isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
4955
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);
56+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
5057
}
5158

5259
@Override
@@ -93,6 +100,9 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
93100
XContentParser parser = request.contentParser();
94101
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
95102
MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel);
103+
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
104+
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
105+
}
96106
if (mlInput.getUrl() != null && !isModelUrlAllowed) {
97107
throw new IllegalArgumentException(
98108
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+4
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ private MLCommonsSettings() {}
108108
Setting.Property.Dynamic
109109
);
110110

111+
// This setting is to enable/disable Create Connector API and Register/Deploy/Predict Model APIs for remote models
112+
public static final Setting<Boolean> ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting
113+
.boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
114+
111115
public static final Setting<Boolean> ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting
112116
.boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
113117

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
8+
package org.opensearch.ml.settings;
9+
10+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
11+
12+
import org.opensearch.cluster.service.ClusterService;
13+
import org.opensearch.common.settings.Settings;
14+
15+
public class MLFeatureEnabledSetting {
16+
17+
private volatile Boolean isRemoteInferenceEnabled;
18+
19+
public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
20+
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
21+
clusterService
22+
.getClusterSettings()
23+
.addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it);
24+
}
25+
26+
/**
27+
* Whether the remote inference feature is enabled. If disabled, APIs in ml-commons will block remote inference.
28+
* @return whether Remote Inference is enabled.
29+
*/
30+
public boolean isRemoteInferenceEnabled() {
31+
return isRemoteInferenceEnabled;
32+
}
33+
34+
}

plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
public class MLExceptionUtils {
1919

2020
public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: ";
21+
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
22+
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
2123

2224
public static String getRootCauseMessage(final Throwable throwable) {
2325
String message = ExceptionUtils.getRootCauseMessage(throwable);

0 commit comments

Comments
 (0)