Skip to content

Commit 945b5e5

Browse files
committed
use dedicated thread pool for ingestion
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 03bbdf1 commit 945b5e5

File tree

5 files changed

+88
-69
lines changed

5 files changed

+88
-69
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java

+25-16
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.HashMap;
1313
import java.util.List;
1414
import java.util.Map;
15+
import java.util.Optional;
1516
import java.util.concurrent.CompletableFuture;
1617
import java.util.concurrent.atomic.AtomicInteger;
1718
import java.util.stream.Collectors;
@@ -136,32 +137,28 @@ protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Ob
136137
List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUT_FIELD_NAMES) : null;
137138

138139
List<String> ingestFieldsJsonPath = Optional
139-
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
140-
.stream()
141-
.map(StringUtils::getJsonPath)
142-
.collect(Collectors.toList());
140+
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
141+
.stream()
142+
.flatMap(java.util.Collection::stream)
143+
.map(StringUtils::getJsonPath)
144+
.collect(Collectors.toList());
143145

144-
if (remoteModelInput.size() != inputFieldNames.size() || remoteModelOutput.size() != outputFieldNames.size()) {
145-
throw new IllegalArgumentException("the fieldMapping and source data do not match");
146-
}
147146
Map<String, Object> jsonMap = new HashMap<>();
148147

149-
for (int index = 0; index < remoteModelInput.size(); index++) {
150-
jsonMap.put(inputFieldNames.get(index), remoteModelInput.get(index));
151-
jsonMap.put(outputFieldNames.get(index), remoteModelOutput.get(index));
152-
}
148+
populateJsonMap(jsonMap, inputFieldNames, remoteModelInput);
149+
populateJsonMap(jsonMap, outputFieldNames, remoteModelOutput);
153150

154151
for (String fieldPath : ingestFieldsJsonPath) {
155152
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
156153
}
157154

158155
if (fieldMapping.containsKey(ID_FIELD)) {
159156
List<String> docIdJsonPath = Optional
160-
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
161-
.stream()
162-
.flatMap(Collection::stream)
163-
.map(StringUtils::getJsonPath)
164-
.collect(Collectors.toList());
157+
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
158+
.stream()
159+
.flatMap(java.util.Collection::stream)
160+
.map(StringUtils::getJsonPath)
161+
.collect(Collectors.toList());
165162
if (docIdJsonPath.size() != 1) {
166163
throw new IllegalArgumentException("The Id field must contains only 1 jsonPath for each source");
167164
}
@@ -201,4 +198,16 @@ protected void batchIngest(
201198
});
202199
client.bulk(bulkRequest, bulkResponseListener);
203200
}
201+
202+
private void populateJsonMap(Map<String, Object> jsonMap, List<String> fieldNames, List<?> modelData) {
203+
if (modelData != null) {
204+
if (modelData.size() != fieldNames.size()) {
205+
throw new IllegalArgumentException("The fieldMapping and source data do not match");
206+
}
207+
208+
for (int index = 0; index < modelData.size(); index++) {
209+
jsonMap.put(fieldNames.get(index), modelData.get(index));
210+
}
211+
}
212+
}
204213
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java

+44-43
Original file line numberDiff line numberDiff line change
@@ -64,26 +64,42 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
6464
connection.setRequestMethod("GET");
6565
connection.setRequestProperty("Authorization", "Bearer " + apiKey);
6666

67-
InputStreamReader inputStreamReader = AccessController
68-
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
69-
BufferedReader reader = new BufferedReader(inputStreamReader);
70-
71-
List<String> linesBuffer = new ArrayList<>();
72-
String line;
73-
int lineCount = 0;
74-
// Atomic counters for tracking success and failure
75-
AtomicInteger successfulBatches = new AtomicInteger(0);
76-
AtomicInteger failedBatches = new AtomicInteger(0);
77-
// List of CompletableFutures to track batch ingestion operations
78-
List<CompletableFuture<Void>> futures = new ArrayList<>();
79-
80-
while ((line = reader.readLine()) != null) {
81-
linesBuffer.add(line);
82-
lineCount++;
83-
84-
// Process every 100 lines
85-
if (lineCount == 100) {
86-
// Create a CompletableFuture that will be completed by the bulkResponseListener
67+
try (
68+
InputStreamReader inputStreamReader = AccessController
69+
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
70+
BufferedReader reader = new BufferedReader(inputStreamReader)
71+
) {
72+
List<String> linesBuffer = new ArrayList<>();
73+
String line;
74+
int lineCount = 0;
75+
// Atomic counters for tracking success and failure
76+
AtomicInteger successfulBatches = new AtomicInteger(0);
77+
AtomicInteger failedBatches = new AtomicInteger(0);
78+
// List of CompletableFutures to track batch ingestion operations
79+
List<CompletableFuture<Void>> futures = new ArrayList<>();
80+
81+
while ((line = reader.readLine()) != null) {
82+
linesBuffer.add(line);
83+
lineCount++;
84+
85+
// Process every 100 lines
86+
if (lineCount % 100 == 0) {
87+
// Create a CompletableFuture that will be completed by the bulkResponseListener
88+
CompletableFuture<Void> future = new CompletableFuture<>();
89+
batchIngest(
90+
linesBuffer,
91+
mlBatchIngestionInput,
92+
getBulkResponseListener(successfulBatches, failedBatches, future),
93+
sourceIndex,
94+
isSoleSource
95+
);
96+
97+
futures.add(future);
98+
linesBuffer.clear();
99+
}
100+
}
101+
// Process any remaining lines in the buffer
102+
if (!linesBuffer.isEmpty()) {
87103
CompletableFuture<Void> future = new CompletableFuture<>();
88104
batchIngest(
89105
linesBuffer,
@@ -92,32 +108,17 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
92108
sourceIndex,
93109
isSoleSource
94110
);
95-
96111
futures.add(future);
97-
linesBuffer.clear();
98-
lineCount = 0;
99112
}
100-
}
101-
// Process any remaining lines in the buffer
102-
if (!linesBuffer.isEmpty()) {
103-
CompletableFuture<Void> future = new CompletableFuture<>();
104-
batchIngest(
105-
linesBuffer,
106-
mlBatchIngestionInput,
107-
getBulkResponseListener(successfulBatches, failedBatches, future),
108-
sourceIndex,
109-
isSoleSource
110-
);
111-
futures.add(future);
112-
}
113113

114-
reader.close();
115-
// Combine all futures and wait for completion
116-
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
117-
// Wait for all tasks to complete
118-
allFutures.join();
119-
int totalBatches = successfulBatches.get() + failedBatches.get();
120-
successRate = (double) successfulBatches.get() / totalBatches * 100;
114+
reader.close();
115+
// Combine all futures and wait for completion
116+
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
117+
// Wait for all tasks to complete
118+
allFutures.join();
119+
int totalBatches = successfulBatches.get() + failedBatches.get();
120+
successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
121+
}
121122
} catch (PrivilegedActionException e) {
122123
throw new RuntimeException("Failed to read from OpenAI file API: ", e);
123124
} catch (Exception e) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ public double ingestSingleSource(
8181
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
8282
double successRate = 0;
8383

84-
try {
84+
try (
8585
ResponseInputStream<GetObjectResponse> s3is = AccessController
8686
.doPrivileged((PrivilegedExceptionAction<ResponseInputStream<GetObjectResponse>>) () -> s3.getObject(getObjectRequest));
87-
BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8));
88-
87+
BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8))
88+
) {
8989
List<String> linesBuffer = new ArrayList<>();
9090
String line;
9191
int lineCount = 0;
@@ -100,7 +100,7 @@ public double ingestSingleSource(
100100
lineCount++;
101101

102102
// Process every 100 lines
103-
if (lineCount == 100) {
103+
if (lineCount % 100 == 0) {
104104
// Create a CompletableFuture that will be completed by the bulkResponseListener
105105
CompletableFuture<Void> future = new CompletableFuture<>();
106106
batchIngest(
@@ -113,7 +113,6 @@ public double ingestSingleSource(
113113

114114
futures.add(future);
115115
linesBuffer.clear();
116-
lineCount = 0;
117116
}
118117
}
119118
// Process any remaining lines in the buffer
@@ -138,7 +137,7 @@ public double ingestSingleSource(
138137
allFutures.join();
139138

140139
int totalBatches = successfulBatches.get() + failedBatches.get();
141-
successRate = (double) successfulBatches.get() / totalBatches * 100;
140+
successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
142141
} catch (S3Exception e) {
143142
log.error("Error reading from S3: " + e.awsErrorDetails().errorMessage());
144143
throw e;

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+10-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1010
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
1111
import static org.opensearch.ml.common.MLTaskState.FAILED;
12+
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
1213
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1314

1415
import java.time.Instant;
@@ -37,6 +38,7 @@
3738
import org.opensearch.ml.task.MLTaskManager;
3839
import org.opensearch.ml.utils.MLExceptionUtils;
3940
import org.opensearch.tasks.Task;
41+
import org.opensearch.threadpool.ThreadPool;
4042
import org.opensearch.transport.TransportService;
4143

4244
import lombok.extern.log4j.Log4j2;
@@ -50,18 +52,21 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
5052
TransportService transportService;
5153
MLTaskManager mlTaskManager;
5254
private final Client client;
55+
private ThreadPool threadPool;
5356

5457
/**
 * Guice-injected constructor. Registers this handler for
 * {@link MLBatchIngestionAction#NAME} and retains the collaborators needed to
 * create the ML task, run the ingestion, and dispatch the (potentially
 * long-running) ingest work onto a dedicated thread pool instead of the
 * transport thread.
 *
 * @param transportService transport layer used to register the action handler
 * @param actionFilters    standard action filter chain
 * @param client           node client used for bulk indexing during ingestion
 * @param mlTaskManager    tracks the lifecycle/state of the batch-ingest ML task
 * @param threadPool       supplies the executor the ingestion runs on
 */
@Inject
public TransportBatchIngestionAction(
    TransportService transportService,
    ActionFilters actionFilters,
    Client client,
    MLTaskManager mlTaskManager,
    ThreadPool threadPool
) {
    super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
    this.transportService = transportService;
    this.client = client;
    this.mlTaskManager = mlTaskManager;
    this.threadPool = threadPool;
}
6671

6772
@Override
@@ -87,8 +92,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
8792
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
8893
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
8994
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
90-
double successRate = ingestable.ingest(mlBatchIngestionInput);
91-
handleSuccessRate(successRate, taskId);
95+
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
96+
double successRate = ingestable.ingest(mlBatchIngestionInput);
97+
handleSuccessRate(successRate, taskId);
98+
});
9299
} catch (Exception ex) {
93100
log.error("Failed in batch ingestion", ex);
94101
mlTaskManager

plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.opensearch.ml.task.MLTaskManager;
4646
import org.opensearch.tasks.Task;
4747
import org.opensearch.test.OpenSearchTestCase;
48+
import org.opensearch.threadpool.ThreadPool;
4849
import org.opensearch.transport.TransportService;
4950

5051
public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@@ -62,14 +63,16 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
6263
private Task task;
6364
@Mock
6465
ActionListener<MLBatchIngestionResponse> actionListener;
66+
@Mock
67+
ThreadPool threadPool;
6568

6669
private TransportBatchIngestionAction batchAction;
6770
private MLBatchIngestionInput batchInput;
6871

6972
@Before
7073
public void setup() {
7174
MockitoAnnotations.openMocks(this);
72-
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager);
75+
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);
7376

7477
Map<String, Object> fieldMap = new HashMap<>();
7578
fieldMap.put("input", "$.content");

0 commit comments

Comments
 (0)