Commit 2bb8c75

add rate limiting for offline batch jobs, set default bulk size to 500

Signed-off-by: Xun Zhang <xunzh@amazon.com>
Parent: 5b982c4

File tree

11 files changed: +226 −49 lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/Ingestable.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public interface Ingestable {
1313
* @param mlBatchIngestionInput batch ingestion input data
1414
* @return successRate (0 - 100)
1515
*/
16-
default double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
16+
default double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
1717
throw new IllegalStateException("Ingest is not implemented");
1818
}
1919
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public OpenAIDataIngestion(Client client) {
3939
}
4040

4141
@Override
42-
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
42+
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
4343
List<String> sources = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
4444
if (Objects.isNull(sources) || sources.isEmpty()) {
4545
return 100;
@@ -48,13 +48,19 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
4848
boolean isSoleSource = sources.size() == 1;
4949
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
5050
for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
51-
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
51+
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
5252
}
5353

5454
return calculateSuccessRate(successRates);
5555
}
5656

57-
private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource) {
57+
private double ingestSingleSource(
58+
String fileId,
59+
MLBatchIngestionInput mlBatchIngestionInput,
60+
int sourceIndex,
61+
boolean isSoleSource,
62+
int bulkSize
63+
) {
5864
double successRate = 0;
5965
try {
6066
String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
@@ -82,8 +88,8 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
8288
linesBuffer.add(line);
8389
lineCount++;
8490

85-
// Process every 100 lines
86-
if (lineCount % 100 == 0) {
91+
// Process every bulkSize lines
92+
if (lineCount % bulkSize == 0) {
8793
// Create a CompletableFuture that will be completed by the bulkResponseListener
8894
CompletableFuture<Void> future = new CompletableFuture<>();
8995
batchIngest(

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public S3DataIngestion(Client client) {
5353
}
5454

5555
@Override
56-
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
56+
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
5757
S3Client s3 = initS3Client(mlBatchIngestionInput);
5858

5959
List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
@@ -63,7 +63,7 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
6363
boolean isSoleSource = s3Uris.size() == 1;
6464
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
6565
for (int sourceIndex = 0; sourceIndex < s3Uris.size(); sourceIndex++) {
66-
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
66+
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
6767
}
6868

6969
return calculateSuccessRate(successRates);
@@ -74,7 +74,8 @@ public double ingestSingleSource(
7474
String s3Uri,
7575
MLBatchIngestionInput mlBatchIngestionInput,
7676
int sourceIndex,
77-
boolean isSoleSource
77+
boolean isSoleSource,
78+
int bulkSize
7879
) {
7980
String bucketName = getS3BucketName(s3Uri);
8081
String keyName = getS3KeyName(s3Uri);
@@ -99,8 +100,8 @@ public double ingestSingleSource(
99100
linesBuffer.add(line);
100101
lineCount++;
101102

102-
// Process every 100 lines
103-
if (lineCount % 100 == 0) {
103+
// Process every bulkSize lines
104+
if (lineCount % bulkSize == 0) {
104105
// Create a CompletableFuture that will be completed by the bulkResponseListener
105106
CompletableFuture<Void> future = new CompletableFuture<>();
106107
batchIngest(

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+50-26
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
1111
import static org.opensearch.ml.common.MLTaskState.FAILED;
1212
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
13+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE;
1314
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1415
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;
1516

@@ -24,12 +25,15 @@
2425
import org.opensearch.action.support.ActionFilters;
2526
import org.opensearch.action.support.HandledTransportAction;
2627
import org.opensearch.client.Client;
28+
import org.opensearch.cluster.service.ClusterService;
2729
import org.opensearch.common.inject.Inject;
30+
import org.opensearch.common.settings.Settings;
2831
import org.opensearch.core.action.ActionListener;
2932
import org.opensearch.core.rest.RestStatus;
3033
import org.opensearch.ml.common.MLTask;
3134
import org.opensearch.ml.common.MLTaskState;
3235
import org.opensearch.ml.common.MLTaskType;
36+
import org.opensearch.ml.common.exception.MLLimitExceededException;
3337
import org.opensearch.ml.common.transport.batch.MLBatchIngestionAction;
3438
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
3539
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
@@ -60,16 +64,19 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
6064
private final Client client;
6165
private ThreadPool threadPool;
6266
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
67+
private volatile Integer batchIngestionBulkSize;
6368

6469
@Inject
6570
public TransportBatchIngestionAction(
71+
ClusterService clusterService,
6672
TransportService transportService,
6773
ActionFilters actionFilters,
6874
Client client,
6975
MLTaskManager mlTaskManager,
7076
ThreadPool threadPool,
7177
MLModelManager mlModelManager,
72-
MLFeatureEnabledSetting mlFeatureEnabledSetting
78+
MLFeatureEnabledSetting mlFeatureEnabledSetting,
79+
Settings settings
7380
) {
7481
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
7582
this.transportService = transportService;
@@ -78,6 +85,12 @@ public TransportBatchIngestionAction(
7885
this.threadPool = threadPool;
7986
this.mlModelManager = mlModelManager;
8087
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
88+
89+
batchIngestionBulkSize = ML_COMMONS_BATCH_INGESTION_BULK_SIZE.get(settings);
90+
clusterService
91+
.getClusterSettings()
92+
.addSettingsUpdateConsumer(ML_COMMONS_BATCH_INGESTION_BULK_SIZE, it -> batchIngestionBulkSize = it);
93+
8194
}
8295

8396
@Override
@@ -131,33 +144,44 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu
131144
.state(MLTaskState.CREATED)
132145
.build();
133146

134-
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
135-
String taskId = response.getId();
136-
try {
137-
mlTask.setTaskId(taskId);
138-
mlTaskManager.add(mlTask);
139-
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
140-
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
141-
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
142-
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
143-
executeWithErrorHandling(() -> {
144-
double successRate = ingestable.ingest(mlBatchIngestionInput);
145-
handleSuccessRate(successRate, taskId);
146-
}, taskId);
147-
});
148-
} catch (Exception ex) {
149-
log.error("Failed in batch ingestion", ex);
150-
mlTaskManager
151-
.updateMLTask(
152-
taskId,
153-
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
154-
TASK_SEMAPHORE_TIMEOUT,
155-
true
156-
);
157-
listener.onFailure(ex);
147+
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
148+
if (exceedLimits) {
149+
String error = "exceed maximum BATCH_INGEST Task limits";
150+
log.warn(error + " in task " + mlTask.getTaskId());
151+
listener.onFailure(new MLLimitExceededException(error));
152+
} else {
153+
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
154+
String taskId = response.getId();
155+
try {
156+
mlTask.setTaskId(taskId);
157+
mlTaskManager.add(mlTask);
158+
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
159+
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
160+
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
161+
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
162+
executeWithErrorHandling(() -> {
163+
double successRate = ingestable.ingest(mlBatchIngestionInput, batchIngestionBulkSize);
164+
handleSuccessRate(successRate, taskId);
165+
}, taskId);
166+
});
167+
} catch (Exception ex) {
168+
log.error("Failed in batch ingestion", ex);
169+
mlTaskManager
170+
.updateMLTask(
171+
taskId,
172+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
173+
TASK_SEMAPHORE_TIMEOUT,
174+
true
175+
);
176+
listener.onFailure(ex);
177+
}
178+
}, exception -> {
179+
log.error("Failed to create batch ingestion task", exception);
180+
listener.onFailure(exception);
181+
}));
158182
}
159183
}, exception -> {
160-
log.error("Failed to create batch ingestion task", exception);
184+
log.error("Failed to check the maximum BATCH_INGEST Task limits", exception);
161185
listener.onFailure(exception);
162186
}));
163187
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+27
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
4141
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
4242
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
43+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS;
44+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS;
4345
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
4446
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE;
4547
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
@@ -107,6 +109,7 @@
107109
import org.opensearch.ml.common.MLModelGroup;
108110
import org.opensearch.ml.common.MLTask;
109111
import org.opensearch.ml.common.MLTaskState;
112+
import org.opensearch.ml.common.MLTaskType;
110113
import org.opensearch.ml.common.connector.Connector;
111114
import org.opensearch.ml.common.controller.MLController;
112115
import org.opensearch.ml.common.controller.MLRateLimiter;
@@ -177,6 +180,8 @@ public class MLModelManager {
177180
private volatile Integer maxModelPerNode;
178181
private volatile Integer maxRegisterTasksPerNode;
179182
private volatile Integer maxDeployTasksPerNode;
183+
private volatile Integer maxBatchInferenceTasks;
184+
private volatile Integer maxBatchIngestionTasks;
180185

181186
public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet
182187
.of(
@@ -232,6 +237,16 @@ public MLModelManager(
232237
clusterService
233238
.getClusterSettings()
234239
.addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it);
240+
241+
maxBatchInferenceTasks = ML_COMMONS_MAX_BATCH_INFERENCE_TASKS.get(settings);
242+
clusterService
243+
.getClusterSettings()
244+
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, it -> maxBatchInferenceTasks = it);
245+
246+
maxBatchIngestionTasks = ML_COMMONS_MAX_BATCH_INGESTION_TASKS.get(settings);
247+
clusterService
248+
.getClusterSettings()
249+
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INGESTION_TASKS, it -> maxBatchIngestionTasks = it);
235250
}
236251

237252
public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
@@ -867,6 +882,18 @@ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
867882
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
868883
}
869884

885+
/**
886+
* Check if exceed batch job task limit
887+
*
888+
* @param mlTask ML task
889+
* @param listener ActionListener if the limit is exceeded
890+
*/
891+
public void checkMaxBatchJobTask(MLTask mlTask, ActionListener<Boolean> listener) {
892+
MLTaskType taskType = mlTask.getTaskType();
893+
int maxLimit = taskType.equals(MLTaskType.BATCH_PREDICTION) ? maxBatchInferenceTasks : maxBatchIngestionTasks;
894+
mlTaskManager.checkMaxBatchJobTask(taskType, maxLimit, listener);
895+
}
896+
870897
private void updateModelRegisterStateAsDone(
871898
MLRegisterModelInput registerModelInput,
872899
String taskId,

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,10 @@ public List<Setting<?>> getSettings() {
972972
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
973973
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
974974
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
975-
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
975+
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED,
976+
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS,
977+
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS,
978+
MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE
976979
);
977980
return settings;
978981
}

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ private MLCommonsSettings() {}
3434
Setting.Property.NodeScope,
3535
Setting.Property.Dynamic
3636
);
37+
38+
public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INFERENCE_TASKS = Setting
39+
.intSetting("plugins.ml_commons.max_batch_inference_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);
40+
41+
public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INGESTION_TASKS = Setting
42+
.intSetting("plugins.ml_commons.max_batch_ingestion_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);
43+
44+
public static final Setting<Integer> ML_COMMONS_BATCH_INGESTION_BULK_SIZE = Setting
45+
.intSetting("plugins.ml_commons.batch_ingestion_bulk_size", 500, 100, 100000, Setting.Property.NodeScope, Setting.Property.Dynamic);
3746
public static final Setting<Integer> ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting
3847
.intSetting("plugins.ml_commons.max_deploy_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
3948
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+26
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.opensearch.ml.common.dataset.MLInputDataType;
5656
import org.opensearch.ml.common.dataset.MLInputDataset;
5757
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
58+
import org.opensearch.ml.common.exception.MLLimitExceededException;
5859
import org.opensearch.ml.common.input.MLInput;
5960
import org.opensearch.ml.common.output.MLOutput;
6061
import org.opensearch.ml.common.output.MLPredictionOutput;
@@ -253,6 +254,31 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
253254
.lastUpdateTime(now)
254255
.async(false)
255256
.build();
257+
if (actionType.equals(ActionType.BATCH_PREDICT)) {
258+
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
259+
if (exceedLimits) {
260+
String error = "exceed maximum BATCH_PREDICTION Task limits";
261+
log.warn(error + " in task " + mlTask.getTaskId());
262+
listener.onFailure(new MLLimitExceededException(error));
263+
} else {
264+
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
265+
}
266+
}, exception -> {
267+
log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception);
268+
listener.onFailure(exception);
269+
}));
270+
}
271+
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
272+
}
273+
274+
private void executePredictionByInputDataType(
275+
MLInputDataType inputDataType,
276+
String modelId,
277+
MLInput mlInput,
278+
MLTask mlTask,
279+
FunctionName functionName,
280+
ActionListener<MLTaskResponse> listener
281+
) {
256282
switch (inputDataType) {
257283
case SEARCH_QUERY:
258284
ActionListener<MLInputDataset> dataFrameActionListener = ActionListener.wrap(dataSet -> {

Comments (0)