Skip to content

Commit 9e4844f

Browse files
add feature flag for offline batch ingestion (opensearch-project#2982) (opensearch-project#2990)
* add feature flag for offline batch ingestion Signed-off-by: Xun Zhang <xunzh@amazon.com> * add feature flag for offline batch inference Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com> (cherry picked from commit 107b916) Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent 3865ea8 commit 9e4844f

13 files changed

+171
-13
lines changed

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.ml.common.MLTaskState.FAILED;
1212
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
1313
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
14+
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;
1415

1516
import java.time.Instant;
1617
import java.util.List;
@@ -35,6 +36,7 @@
3536
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
3637
import org.opensearch.ml.engine.MLEngineClassLoader;
3738
import org.opensearch.ml.engine.ingest.Ingestable;
39+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
3840
import org.opensearch.ml.task.MLTaskManager;
3941
import org.opensearch.ml.utils.MLExceptionUtils;
4042
import org.opensearch.tasks.Task;
@@ -55,27 +57,33 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
5557
MLTaskManager mlTaskManager;
5658
private final Client client;
5759
private ThreadPool threadPool;
60+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
5861

5962
@Inject
6063
public TransportBatchIngestionAction(
6164
TransportService transportService,
6265
ActionFilters actionFilters,
6366
Client client,
6467
MLTaskManager mlTaskManager,
65-
ThreadPool threadPool
68+
ThreadPool threadPool,
69+
MLFeatureEnabledSetting mlFeatureEnabledSetting
6670
) {
6771
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
6872
this.transportService = transportService;
6973
this.client = client;
7074
this.mlTaskManager = mlTaskManager;
7175
this.threadPool = threadPool;
76+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
7277
}
7378

7479
@Override
7580
protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatchIngestionResponse> listener) {
7681
MLBatchIngestionRequest mlBatchIngestionRequest = MLBatchIngestionRequest.fromActionRequest(request);
7782
MLBatchIngestionInput mlBatchIngestionInput = mlBatchIngestionRequest.getMlBatchIngestionInput();
7883
try {
84+
if (!mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()) {
85+
throw new IllegalStateException(OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
86+
}
7987
validateBatchIngestInput(mlBatchIngestionInput);
8088
MLTask mlTask = MLTask
8189
.builder()

plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java

+10-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
1010
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1111
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
12+
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
1213
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
1314

1415
import java.util.HashMap;
@@ -51,8 +52,8 @@
5152
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
5253
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
5354
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
54-
import org.opensearch.ml.model.MLModelCacheHelper;
5555
import org.opensearch.ml.model.MLModelManager;
56+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
5657
import org.opensearch.ml.task.MLTaskManager;
5758
import org.opensearch.script.ScriptService;
5859
import org.opensearch.tasks.Task;
@@ -74,7 +75,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
7475
MLModelManager mlModelManager;
7576

7677
MLTaskManager mlTaskManager;
77-
MLModelCacheHelper modelCacheHelper;
78+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
7879

7980
@Inject
8081
public CancelBatchJobTransportAction(
@@ -87,7 +88,8 @@ public CancelBatchJobTransportAction(
8788
ConnectorAccessControlHelper connectorAccessControlHelper,
8889
EncryptorImpl encryptor,
8990
MLTaskManager mlTaskManager,
90-
MLModelManager mlModelManager
91+
MLModelManager mlModelManager,
92+
MLFeatureEnabledSetting mlFeatureEnabledSetting
9193
) {
9294
super(MLCancelBatchJobAction.NAME, transportService, actionFilters, MLCancelBatchJobRequest::new);
9395
this.client = client;
@@ -98,6 +100,7 @@ public CancelBatchJobTransportAction(
98100
this.encryptor = encryptor;
99101
this.mlTaskManager = mlTaskManager;
100102
this.mlModelManager = mlModelManager;
103+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
101104
}
102105

103106
@Override
@@ -116,6 +119,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCanc
116119
MLTask mlTask = MLTask.parse(parser);
117120

118121
// check if function is remote and task is of type batch prediction
122+
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION
123+
&& !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
124+
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
125+
}
119126
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) {
120127
processRemoteBatchPrediction(mlTask, actionListener);
121128
} else {

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
2121
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
2222
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
23+
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
2324
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
2425
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
2526

@@ -68,8 +69,8 @@
6869
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
6970
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
7071
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
71-
import org.opensearch.ml.model.MLModelCacheHelper;
7272
import org.opensearch.ml.model.MLModelManager;
73+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
7374
import org.opensearch.ml.task.MLTaskManager;
7475
import org.opensearch.script.ScriptService;
7576
import org.opensearch.tasks.Task;
@@ -91,7 +92,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
9192
MLModelManager mlModelManager;
9293

9394
MLTaskManager mlTaskManager;
94-
MLModelCacheHelper modelCacheHelper;
95+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
9596

9697
volatile List<String> remoteJobStatusFields;
9798
volatile Pattern remoteJobCompletedStatusRegexPattern;
@@ -111,6 +112,7 @@ public GetTaskTransportAction(
111112
EncryptorImpl encryptor,
112113
MLTaskManager mlTaskManager,
113114
MLModelManager mlModelManager,
115+
MLFeatureEnabledSetting mlFeatureEnabledSetting,
114116
Settings settings
115117
) {
116118
super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new);
@@ -122,6 +124,7 @@ public GetTaskTransportAction(
122124
this.encryptor = encryptor;
123125
this.mlTaskManager = mlTaskManager;
124126
this.mlModelManager = mlModelManager;
127+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
125128

126129
remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings);
127130
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it);
@@ -178,6 +181,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
178181
MLTask mlTask = MLTask.parse(parser);
179182

180183
// check if function is remote and task is of type batch prediction
184+
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION
185+
&& !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
186+
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
187+
}
181188
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) {
182189
processRemoteBatchPrediction(mlTask, taskId, actionListener);
183190
} else {

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,9 @@ public List<Setting<?>> getSettings() {
971971
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
972972
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
973973
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
974-
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED
974+
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
975+
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
976+
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
975977
);
976978
return settings;
977979
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import com.google.common.collect.ImmutableList;
2525

2626
public class RestMLGetTaskAction extends BaseRestHandler {
27-
private static final String ML_GET_Task_ACTION = "ml_get_task_action";
27+
private static final String ML_GET_TASK_ACTION = "ml_get_task_action";
2828

2929
/**
3030
* Constructor
@@ -33,7 +33,7 @@ public RestMLGetTaskAction() {}
3333

3434
@Override
3535
public String getName() {
36-
return ML_GET_Task_ACTION;
36+
return ML_GET_TASK_ACTION;
3737
}
3838

3939
@Override

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
10+
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
1011
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
1112
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1213
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
@@ -131,6 +132,8 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
131132
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
132133
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
133134
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
135+
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
136+
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
134137
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
135138
throw new IllegalArgumentException("Wrong action type in the rest request path!");
136139
}

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+6
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ private MLCommonsSettings() {}
136136
public static final Setting<Boolean> ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting
137137
.boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
138138

139+
public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED = Setting
140+
.boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
141+
142+
public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED = Setting
143+
.boolSetting("plugins.ml_commons.offline_batch_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
144+
139145
public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
140146
.listSetting(
141147
"plugins.ml_commons.trusted_connector_endpoints_regex",

plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java

+27
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
1212
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
1313
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
14+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED;
15+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
1416
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
1517

1618
import java.util.concurrent.atomic.AtomicBoolean;
@@ -27,13 +29,17 @@ public class MLFeatureEnabledSetting {
2729
private volatile AtomicBoolean isConnectorPrivateIpEnabled;
2830

2931
private volatile Boolean isControllerEnabled;
32+
private volatile Boolean isBatchIngestionEnabled;
33+
private volatile Boolean isBatchInferenceEnabled;
3034

3135
public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
3236
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
3337
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
3438
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
3539
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
3640
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
41+
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
42+
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);
3743

3844
clusterService
3945
.getClusterSettings()
@@ -46,6 +52,12 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
4652
.getClusterSettings()
4753
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
4854
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
55+
clusterService
56+
.getClusterSettings()
57+
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, it -> isBatchIngestionEnabled = it);
58+
clusterService
59+
.getClusterSettings()
60+
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it);
4961
}
5062

5163
/**
@@ -84,4 +96,19 @@ public Boolean isControllerEnabled() {
8496
return isControllerEnabled;
8597
}
8698

99+
/**
100+
* Whether the offline batch ingestion is enabled. If disabled, APIs in ml-commons will block offline batch ingestion.
101+
* @return whether the feature is enabled.
102+
*/
103+
public Boolean isOfflineBatchIngestionEnabled() {
104+
return isBatchIngestionEnabled;
105+
}
106+
107+
/**
108+
* Whether the offline batch inference is enabled. If disabled, APIs in ml-commons will block offline batch inference.
109+
* @return whether the feature is enabled.
110+
*/
111+
public Boolean isOfflineBatchInferenceEnabled() {
112+
return isBatchInferenceEnabled;
113+
}
87114
}

plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java

+4
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@ public class MLExceptionUtils {
2222
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
2323
public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
2424
"Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
25+
public static final String BATCH_INFERENCE_DISABLED_ERR_MSG =
26+
"Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.";
2527
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
2628
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";
2729
public static final String CONTROLLER_DISABLED_ERR_MSG =
2830
"Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.";
31+
public static final String OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG =
32+
"Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.";
2933

3034
public static String getRootCauseMessage(final Throwable throwable) {
3135
String message = ExceptionUtils.getRootCauseMessage(throwable);

plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java

+25-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
4747
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
4848
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
49+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
4950
import org.opensearch.ml.task.MLTaskManager;
5051
import org.opensearch.tasks.Task;
5152
import org.opensearch.test.OpenSearchTestCase;
@@ -73,6 +74,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
7374
ThreadPool threadPool;
7475
@Mock
7576
ExecutorService executorService;
77+
@Mock
78+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
7679

7780
private TransportBatchIngestionAction batchAction;
7881
private MLBatchIngestionInput batchInput;
@@ -81,7 +84,14 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
8184
@Before
8285
public void setup() {
8386
MockitoAnnotations.openMocks(this);
84-
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);
87+
batchAction = new TransportBatchIngestionAction(
88+
transportService,
89+
actionFilters,
90+
client,
91+
mlTaskManager,
92+
threadPool,
93+
mlFeatureEnabledSetting
94+
);
8595

8696
Map<String, Object> fieldMap = new HashMap<>();
8797
fieldMap.put("chapter", "$.content[0]");
@@ -106,6 +116,8 @@ public void setup() {
106116
.dataSources(dataSource)
107117
.build();
108118
when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput);
119+
120+
when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(true);
109121
}
110122

111123
public void test_doExecute_success() {
@@ -181,6 +193,18 @@ public void test_doExecute_handleSuccessRate0() {
181193
);
182194
}
183195

196+
public void test_doExecute_batchIngestionDisabled() {
197+
when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(false);
198+
batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);
199+
200+
ArgumentCaptor<IllegalStateException> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
201+
verify(actionListener).onFailure(argumentCaptor.capture());
202+
assertEquals(
203+
"Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.",
204+
argumentCaptor.getValue().getMessage()
205+
);
206+
}
207+
184208
public void test_doExecute_noDataSource() {
185209
MLBatchIngestionInput batchInput = MLBatchIngestionInput
186210
.builder()

0 commit comments

Comments
 (0)