Skip to content

Commit caeacf7

Browse files
authored
add rate limiting for offline batch jobs, set default bulk size to 500 (opensearch-project#3116) (opensearch-project#3122)
* Add rate limiting for offline batch jobs; set the default bulk size to 500.
* Update the error code to 429 (Too Many Requests) for rate limiting, and update the logs.

Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent d1ff1bc commit caeacf7

File tree

12 files changed

+231
-49
lines changed

12 files changed

+231
-49
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public interface Ingestable {
1313
* @param mlBatchIngestionInput batch ingestion input data
1414
* @return successRate (0 - 100)
1515
*/
16-
default double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
16+
default double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
1717
throw new IllegalStateException("Ingest is not implemented");
1818
}
1919
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public OpenAIDataIngestion(Client client) {
3939
}
4040

4141
@Override
42-
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
42+
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
4343
List<String> sources = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
4444
if (Objects.isNull(sources) || sources.isEmpty()) {
4545
return 100;
@@ -48,13 +48,19 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
4848
boolean isSoleSource = sources.size() == 1;
4949
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
5050
for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
51-
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
51+
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
5252
}
5353

5454
return calculateSuccessRate(successRates);
5555
}
5656

57-
private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource) {
57+
private double ingestSingleSource(
58+
String fileId,
59+
MLBatchIngestionInput mlBatchIngestionInput,
60+
int sourceIndex,
61+
boolean isSoleSource,
62+
int bulkSize
63+
) {
5864
double successRate = 0;
5965
try {
6066
String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
@@ -82,8 +88,8 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
8288
linesBuffer.add(line);
8389
lineCount++;
8490

85-
// Process every 100 lines
86-
if (lineCount % 100 == 0) {
91+
// Process every bulkSize lines
92+
if (lineCount % bulkSize == 0) {
8793
// Create a CompletableFuture that will be completed by the bulkResponseListener
8894
CompletableFuture<Void> future = new CompletableFuture<>();
8995
batchIngest(

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public S3DataIngestion(Client client) {
5353
}
5454

5555
@Override
56-
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
56+
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
5757
S3Client s3 = initS3Client(mlBatchIngestionInput);
5858

5959
List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
@@ -63,7 +63,7 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
6363
boolean isSoleSource = s3Uris.size() == 1;
6464
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
6565
for (int sourceIndex = 0; sourceIndex < s3Uris.size(); sourceIndex++) {
66-
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
66+
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
6767
}
6868

6969
return calculateSuccessRate(successRates);
@@ -74,7 +74,8 @@ public double ingestSingleSource(
7474
String s3Uri,
7575
MLBatchIngestionInput mlBatchIngestionInput,
7676
int sourceIndex,
77-
boolean isSoleSource
77+
boolean isSoleSource,
78+
int bulkSize
7879
) {
7980
String bucketName = getS3BucketName(s3Uri);
8081
String keyName = getS3KeyName(s3Uri);
@@ -99,8 +100,8 @@ public double ingestSingleSource(
99100
linesBuffer.add(line);
100101
lineCount++;
101102

102-
// Process every 100 lines
103-
if (lineCount % 100 == 0) {
103+
// Process every bulkSize lines
104+
if (lineCount % bulkSize == 0) {
104105
// Create a CompletableFuture that will be completed by the bulkResponseListener
105106
CompletableFuture<Void> future = new CompletableFuture<>();
106107
batchIngest(

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+50-26
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
1111
import static org.opensearch.ml.common.MLTaskState.FAILED;
1212
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
13+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE;
1314
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1415
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;
1516

@@ -24,7 +25,9 @@
2425
import org.opensearch.action.support.ActionFilters;
2526
import org.opensearch.action.support.HandledTransportAction;
2627
import org.opensearch.client.Client;
28+
import org.opensearch.cluster.service.ClusterService;
2729
import org.opensearch.common.inject.Inject;
30+
import org.opensearch.common.settings.Settings;
2831
import org.opensearch.core.action.ActionListener;
2932
import org.opensearch.core.rest.RestStatus;
3033
import org.opensearch.ml.common.MLTask;
@@ -60,16 +63,19 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
6063
private final Client client;
6164
private ThreadPool threadPool;
6265
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
66+
private volatile Integer batchIngestionBulkSize;
6367

6468
@Inject
6569
public TransportBatchIngestionAction(
70+
ClusterService clusterService,
6671
TransportService transportService,
6772
ActionFilters actionFilters,
6873
Client client,
6974
MLTaskManager mlTaskManager,
7075
ThreadPool threadPool,
7176
MLModelManager mlModelManager,
72-
MLFeatureEnabledSetting mlFeatureEnabledSetting
77+
MLFeatureEnabledSetting mlFeatureEnabledSetting,
78+
Settings settings
7379
) {
7480
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
7581
this.transportService = transportService;
@@ -78,6 +84,12 @@ public TransportBatchIngestionAction(
7884
this.threadPool = threadPool;
7985
this.mlModelManager = mlModelManager;
8086
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
87+
88+
batchIngestionBulkSize = ML_COMMONS_BATCH_INGESTION_BULK_SIZE.get(settings);
89+
clusterService
90+
.getClusterSettings()
91+
.addSettingsUpdateConsumer(ML_COMMONS_BATCH_INGESTION_BULK_SIZE, it -> batchIngestionBulkSize = it);
92+
8193
}
8294

8395
@Override
@@ -131,33 +143,45 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu
131143
.state(MLTaskState.CREATED)
132144
.build();
133145

134-
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
135-
String taskId = response.getId();
136-
try {
137-
mlTask.setTaskId(taskId);
138-
mlTaskManager.add(mlTask);
139-
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
140-
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
141-
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
142-
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
143-
executeWithErrorHandling(() -> {
144-
double successRate = ingestable.ingest(mlBatchIngestionInput);
145-
handleSuccessRate(successRate, taskId);
146-
}, taskId);
147-
});
148-
} catch (Exception ex) {
149-
log.error("Failed in batch ingestion", ex);
150-
mlTaskManager
151-
.updateMLTask(
152-
taskId,
153-
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
154-
TASK_SEMAPHORE_TIMEOUT,
155-
true
156-
);
157-
listener.onFailure(ex);
146+
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
147+
if (exceedLimits) {
148+
String error =
149+
"Exceeded maximum limit for BATCH_INGEST tasks. To increase the limit, update the plugins.ml_commons.max_batch_ingestion_tasks setting.";
150+
log.warn(error + " in task " + mlTask.getTaskId());
151+
listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
152+
} else {
153+
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
154+
String taskId = response.getId();
155+
try {
156+
mlTask.setTaskId(taskId);
157+
mlTaskManager.add(mlTask);
158+
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
159+
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
160+
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
161+
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
162+
executeWithErrorHandling(() -> {
163+
double successRate = ingestable.ingest(mlBatchIngestionInput, batchIngestionBulkSize);
164+
handleSuccessRate(successRate, taskId);
165+
}, taskId);
166+
});
167+
} catch (Exception ex) {
168+
log.error("Failed in batch ingestion", ex);
169+
mlTaskManager
170+
.updateMLTask(
171+
taskId,
172+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
173+
TASK_SEMAPHORE_TIMEOUT,
174+
true
175+
);
176+
listener.onFailure(ex);
177+
}
178+
}, exception -> {
179+
log.error("Failed to create batch ingestion task", exception);
180+
listener.onFailure(exception);
181+
}));
158182
}
159183
}, exception -> {
160-
log.error("Failed to create batch ingestion task", exception);
184+
log.error("Failed to check the maximum BATCH_INGEST Task limits", exception);
161185
listener.onFailure(exception);
162186
}));
163187
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+27
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
4141
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
4242
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
43+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS;
44+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS;
4345
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
4446
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE;
4547
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
@@ -107,6 +109,7 @@
107109
import org.opensearch.ml.common.MLModelGroup;
108110
import org.opensearch.ml.common.MLTask;
109111
import org.opensearch.ml.common.MLTaskState;
112+
import org.opensearch.ml.common.MLTaskType;
110113
import org.opensearch.ml.common.connector.Connector;
111114
import org.opensearch.ml.common.controller.MLController;
112115
import org.opensearch.ml.common.controller.MLRateLimiter;
@@ -177,6 +180,8 @@ public class MLModelManager {
177180
private volatile Integer maxModelPerNode;
178181
private volatile Integer maxRegisterTasksPerNode;
179182
private volatile Integer maxDeployTasksPerNode;
183+
private volatile Integer maxBatchInferenceTasks;
184+
private volatile Integer maxBatchIngestionTasks;
180185

181186
public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet
182187
.of(
@@ -232,6 +237,16 @@ public MLModelManager(
232237
clusterService
233238
.getClusterSettings()
234239
.addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it);
240+
241+
maxBatchInferenceTasks = ML_COMMONS_MAX_BATCH_INFERENCE_TASKS.get(settings);
242+
clusterService
243+
.getClusterSettings()
244+
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, it -> maxBatchInferenceTasks = it);
245+
246+
maxBatchIngestionTasks = ML_COMMONS_MAX_BATCH_INGESTION_TASKS.get(settings);
247+
clusterService
248+
.getClusterSettings()
249+
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INGESTION_TASKS, it -> maxBatchIngestionTasks = it);
235250
}
236251

237252
public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
@@ -863,6 +878,18 @@ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
863878
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
864879
}
865880

881+
/**
882+
* Check if exceed batch job task limit
883+
*
884+
* @param mlTask ML task
885+
* @param listener ActionListener if the limit is exceeded
886+
*/
887+
public void checkMaxBatchJobTask(MLTask mlTask, ActionListener<Boolean> listener) {
888+
MLTaskType taskType = mlTask.getTaskType();
889+
int maxLimit = taskType.equals(MLTaskType.BATCH_PREDICTION) ? maxBatchInferenceTasks : maxBatchIngestionTasks;
890+
mlTaskManager.checkMaxBatchJobTask(taskType, maxLimit, listener);
891+
}
892+
866893
private void updateModelRegisterStateAsDone(
867894
MLRegisterModelInput registerModelInput,
868895
String taskId,

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,10 @@ public List<Setting<?>> getSettings() {
973973
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
974974
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
975975
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
976-
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
976+
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED,
977+
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS,
978+
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS,
979+
MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE
977980
);
978981
return settings;
979982
}

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ private MLCommonsSettings() {}
3434
Setting.Property.NodeScope,
3535
Setting.Property.Dynamic
3636
);
37+
38+
public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INFERENCE_TASKS = Setting
39+
.intSetting("plugins.ml_commons.max_batch_inference_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);
40+
41+
public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INGESTION_TASKS = Setting
42+
.intSetting("plugins.ml_commons.max_batch_ingestion_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);
43+
44+
public static final Setting<Integer> ML_COMMONS_BATCH_INGESTION_BULK_SIZE = Setting
45+
.intSetting("plugins.ml_commons.batch_ingestion_bulk_size", 500, 100, 100000, Setting.Property.NodeScope, Setting.Property.Dynamic);
3746
public static final Setting<Integer> ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting
3847
.intSetting("plugins.ml_commons.max_deploy_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
3948
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+27
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,33 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
253253
.lastUpdateTime(now)
254254
.async(false)
255255
.build();
256+
if (actionType.equals(ActionType.BATCH_PREDICT)) {
257+
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
258+
if (exceedLimits) {
259+
String error =
260+
"Exceeded maximum limit for BATCH_PREDICTION tasks. To increase the limit, update the plugins.ml_commons.max_batch_inference_tasks setting.";
261+
log.warn(error + " in task " + mlTask.getTaskId());
262+
listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
263+
} else {
264+
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
265+
}
266+
}, exception -> {
267+
log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception);
268+
listener.onFailure(exception);
269+
}));
270+
return;
271+
}
272+
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
273+
}
274+
275+
private void executePredictionByInputDataType(
276+
MLInputDataType inputDataType,
277+
String modelId,
278+
MLInput mlInput,
279+
MLTask mlTask,
280+
FunctionName functionName,
281+
ActionListener<MLTaskResponse> listener
282+
) {
256283
switch (inputDataType) {
257284
case SEARCH_QUERY:
258285
ActionListener<MLInputDataset> dataFrameActionListener = ActionListener.wrap(dataSet -> {

0 commit comments

Comments (0)