Skip to content

Commit 702e8ba

Browse files
committed
use dedicated thread pool for ingestion
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 03bbdf1 commit 702e8ba

File tree

5 files changed

+105
-79
lines changed

5 files changed

+105
-79
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java

+36-26
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
99
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
1010

11-
import java.util.Arrays;
1211
import java.util.HashMap;
1312
import java.util.List;
1413
import java.util.Map;
14+
import java.util.Optional;
1515
import java.util.concurrent.CompletableFuture;
1616
import java.util.concurrent.atomic.AtomicInteger;
1717
import java.util.stream.Collectors;
@@ -96,16 +96,16 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
9696
Object value = entry.getValue();
9797
if (value instanceof String) {
9898
return ((String) value).contains(prefix);
99-
} else if (value instanceof String[]) {
100-
return Arrays.stream((String[]) value).anyMatch(val -> val.contains(prefix));
99+
} else if (value instanceof List) {
100+
return ((List<String>) value).stream().anyMatch(val -> val.contains(prefix));
101101
}
102102
return false;
103103
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
104104
Object value = entry.getValue();
105105
if (value instanceof String) {
106106
return value;
107-
} else if (value instanceof String[]) {
108-
return Arrays.stream((String[]) value).filter(val -> val.contains(prefix)).toArray(String[]::new);
107+
} else if (value instanceof List) {
108+
return ((List<String>) value).stream().filter(val -> val.contains(prefix)).collect(Collectors.toList());
109109
}
110110
return null;
111111
}));
@@ -136,32 +136,28 @@ protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Ob
136136
List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUT_FIELD_NAMES) : null;
137137

138138
List<String> ingestFieldsJsonPath = Optional
139-
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
140-
.stream()
141-
.map(StringUtils::getJsonPath)
142-
.collect(Collectors.toList());
139+
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
140+
.stream()
141+
.flatMap(java.util.Collection::stream)
142+
.map(StringUtils::getJsonPath)
143+
.collect(Collectors.toList());
143144

144-
if (remoteModelInput.size() != inputFieldNames.size() || remoteModelOutput.size() != outputFieldNames.size()) {
145-
throw new IllegalArgumentException("the fieldMapping and source data do not match");
146-
}
147145
Map<String, Object> jsonMap = new HashMap<>();
148146

149-
for (int index = 0; index < remoteModelInput.size(); index++) {
150-
jsonMap.put(inputFieldNames.get(index), remoteModelInput.get(index));
151-
jsonMap.put(outputFieldNames.get(index), remoteModelOutput.get(index));
152-
}
147+
populateJsonMap(jsonMap, inputFieldNames, remoteModelInput);
148+
populateJsonMap(jsonMap, outputFieldNames, remoteModelOutput);
153149

154150
for (String fieldPath : ingestFieldsJsonPath) {
155151
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
156152
}
157153

158154
if (fieldMapping.containsKey(ID_FIELD)) {
159155
List<String> docIdJsonPath = Optional
160-
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
161-
.stream()
162-
.flatMap(Collection::stream)
163-
.map(StringUtils::getJsonPath)
164-
.collect(Collectors.toList());
156+
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
157+
.stream()
158+
.flatMap(java.util.Collection::stream)
159+
.map(StringUtils::getJsonPath)
160+
.collect(Collectors.toList());
165161
if (docIdJsonPath.size() != 1) {
166162
throw new IllegalArgumentException("The Id field must contain only 1 jsonPath for each source");
167163
}
@@ -180,25 +176,39 @@ protected void batchIngest(
180176
BulkRequest bulkRequest = new BulkRequest();
181177
sourceLines.stream().forEach(jsonStr -> {
182178
Map<String, Object> filteredMapping = isSoleSource
183-
? mlBatchIngestionInput.getFieldMapping()
184-
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
179+
? mlBatchIngestionInput.getFieldMapping()
180+
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
185181
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
186182
if (isSoleSource || sourceIndex == 0) {
187-
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
183+
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName());
188184
if (jsonMap.containsKey("_id")) {
189-
indexRequest.id((String) jsonMap.get("_id"));
185+
String id = (String) jsonMap.remove("_id");
186+
indexRequest.id(id);
190187
}
188+
indexRequest.source(jsonMap);
191189
bulkRequest.add(indexRequest);
192190
} else {
193191
// bulk update docs as they were partially ingested
194192
if (!jsonMap.containsKey("_id")) {
195193
throw new IllegalArgumentException("The id field must be provided to match documents for multiple sources");
196194
}
197-
String id = (String) jsonMap.get("_id");
195+
String id = (String) jsonMap.remove("_id");
198196
UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap);
199197
bulkRequest.add(updateRequest);
200198
}
201199
});
202200
client.bulk(bulkRequest, bulkResponseListener);
203201
}
202+
203+
private void populateJsonMap(Map<String, Object> jsonMap, List<String> fieldNames, List<?> modelData) {
204+
if (modelData != null) {
205+
if (modelData.size() != fieldNames.size()) {
206+
throw new IllegalArgumentException("The fieldMapping and source data do not match");
207+
}
208+
209+
for (int index = 0; index < modelData.size(); index++) {
210+
jsonMap.put(fieldNames.get(index), modelData.get(index));
211+
}
212+
}
213+
}
204214
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java

+44-43
Original file line numberDiff line numberDiff line change
@@ -64,26 +64,42 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
6464
connection.setRequestMethod("GET");
6565
connection.setRequestProperty("Authorization", "Bearer " + apiKey);
6666

67-
InputStreamReader inputStreamReader = AccessController
68-
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
69-
BufferedReader reader = new BufferedReader(inputStreamReader);
70-
71-
List<String> linesBuffer = new ArrayList<>();
72-
String line;
73-
int lineCount = 0;
74-
// Atomic counters for tracking success and failure
75-
AtomicInteger successfulBatches = new AtomicInteger(0);
76-
AtomicInteger failedBatches = new AtomicInteger(0);
77-
// List of CompletableFutures to track batch ingestion operations
78-
List<CompletableFuture<Void>> futures = new ArrayList<>();
79-
80-
while ((line = reader.readLine()) != null) {
81-
linesBuffer.add(line);
82-
lineCount++;
83-
84-
// Process every 100 lines
85-
if (lineCount == 100) {
86-
// Create a CompletableFuture that will be completed by the bulkResponseListener
67+
try (
68+
InputStreamReader inputStreamReader = AccessController
69+
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
70+
BufferedReader reader = new BufferedReader(inputStreamReader)
71+
) {
72+
List<String> linesBuffer = new ArrayList<>();
73+
String line;
74+
int lineCount = 0;
75+
// Atomic counters for tracking success and failure
76+
AtomicInteger successfulBatches = new AtomicInteger(0);
77+
AtomicInteger failedBatches = new AtomicInteger(0);
78+
// List of CompletableFutures to track batch ingestion operations
79+
List<CompletableFuture<Void>> futures = new ArrayList<>();
80+
81+
while ((line = reader.readLine()) != null) {
82+
linesBuffer.add(line);
83+
lineCount++;
84+
85+
// Process every 100 lines
86+
if (lineCount % 100 == 0) {
87+
// Create a CompletableFuture that will be completed by the bulkResponseListener
88+
CompletableFuture<Void> future = new CompletableFuture<>();
89+
batchIngest(
90+
linesBuffer,
91+
mlBatchIngestionInput,
92+
getBulkResponseListener(successfulBatches, failedBatches, future),
93+
sourceIndex,
94+
isSoleSource
95+
);
96+
97+
futures.add(future);
98+
linesBuffer.clear();
99+
}
100+
}
101+
// Process any remaining lines in the buffer
102+
if (!linesBuffer.isEmpty()) {
87103
CompletableFuture<Void> future = new CompletableFuture<>();
88104
batchIngest(
89105
linesBuffer,
@@ -92,32 +108,17 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
92108
sourceIndex,
93109
isSoleSource
94110
);
95-
96111
futures.add(future);
97-
linesBuffer.clear();
98-
lineCount = 0;
99112
}
100-
}
101-
// Process any remaining lines in the buffer
102-
if (!linesBuffer.isEmpty()) {
103-
CompletableFuture<Void> future = new CompletableFuture<>();
104-
batchIngest(
105-
linesBuffer,
106-
mlBatchIngestionInput,
107-
getBulkResponseListener(successfulBatches, failedBatches, future),
108-
sourceIndex,
109-
isSoleSource
110-
);
111-
futures.add(future);
112-
}
113113

114-
reader.close();
115-
// Combine all futures and wait for completion
116-
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
117-
// Wait for all tasks to complete
118-
allFutures.join();
119-
int totalBatches = successfulBatches.get() + failedBatches.get();
120-
successRate = (double) successfulBatches.get() / totalBatches * 100;
114+
reader.close();
115+
// Combine all futures and wait for completion
116+
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
117+
// Wait for all tasks to complete
118+
allFutures.join();
119+
int totalBatches = successfulBatches.get() + failedBatches.get();
120+
successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
121+
}
121122
} catch (PrivilegedActionException e) {
122123
throw new RuntimeException("Failed to read from OpenAI file API: ", e);
123124
} catch (Exception e) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ public double ingestSingleSource(
8181
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
8282
double successRate = 0;
8383

84-
try {
84+
try (
8585
ResponseInputStream<GetObjectResponse> s3is = AccessController
8686
.doPrivileged((PrivilegedExceptionAction<ResponseInputStream<GetObjectResponse>>) () -> s3.getObject(getObjectRequest));
87-
BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8));
88-
87+
BufferedReader reader = new BufferedReader(new InputStreamReader(s3is, StandardCharsets.UTF_8))
88+
) {
8989
List<String> linesBuffer = new ArrayList<>();
9090
String line;
9191
int lineCount = 0;
@@ -100,7 +100,7 @@ public double ingestSingleSource(
100100
lineCount++;
101101

102102
// Process every 100 lines
103-
if (lineCount == 100) {
103+
if (lineCount % 100 == 0) {
104104
// Create a CompletableFuture that will be completed by the bulkResponseListener
105105
CompletableFuture<Void> future = new CompletableFuture<>();
106106
batchIngest(
@@ -113,7 +113,6 @@ public double ingestSingleSource(
113113

114114
futures.add(future);
115115
linesBuffer.clear();
116-
lineCount = 0;
117116
}
118117
}
119118
// Process any remaining lines in the buffer
@@ -138,7 +137,7 @@ public double ingestSingleSource(
138137
allFutures.join();
139138

140139
int totalBatches = successfulBatches.get() + failedBatches.get();
141-
successRate = (double) successfulBatches.get() / totalBatches * 100;
140+
successRate = (totalBatches == 0) ? 100 : (double) successfulBatches.get() / totalBatches * 100;
142141
} catch (S3Exception e) {
143142
log.error("Error reading from S3: " + e.awsErrorDetails().errorMessage());
144143
throw e;

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+16-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1010
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
1111
import static org.opensearch.ml.common.MLTaskState.FAILED;
12+
import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL;
13+
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
1214
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1315

1416
import java.time.Instant;
@@ -37,6 +39,7 @@
3739
import org.opensearch.ml.task.MLTaskManager;
3840
import org.opensearch.ml.utils.MLExceptionUtils;
3941
import org.opensearch.tasks.Task;
42+
import org.opensearch.threadpool.ThreadPool;
4043
import org.opensearch.transport.TransportService;
4144

4245
import lombok.extern.log4j.Log4j2;
@@ -50,18 +53,21 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
5053
TransportService transportService;
5154
MLTaskManager mlTaskManager;
5255
private final Client client;
56+
private ThreadPool threadPool;
5357

5458
@Inject
5559
public TransportBatchIngestionAction(
5660
TransportService transportService,
5761
ActionFilters actionFilters,
5862
Client client,
59-
MLTaskManager mlTaskManager
63+
MLTaskManager mlTaskManager,
64+
ThreadPool threadPool
6065
) {
6166
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
6267
this.transportService = transportService;
6368
this.client = client;
6469
this.mlTaskManager = mlTaskManager;
70+
this.threadPool = threadPool;
6571
}
6672

6773
@Override
@@ -87,8 +93,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
8793
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
8894
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
8995
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
90-
double successRate = ingestable.ingest(mlBatchIngestionInput);
91-
handleSuccessRate(successRate, taskId);
96+
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
97+
log.info(TRAIN_THREAD_POOL + " thread is executed!");
98+
threadPool.executor(REMOTE_PREDICT_THREAD_POOL).execute(() -> {});
99+
log.info(REMOTE_PREDICT_THREAD_POOL + " thread is executed!");
100+
threadPool.executor("unknown thread").execute(() -> {});
101+
log.info("unknown thread" + " thread is executed!");
102+
double successRate = ingestable.ingest(mlBatchIngestionInput);
103+
handleSuccessRate(successRate, taskId);
104+
});
92105
} catch (Exception ex) {
93106
log.error("Failed in batch ingestion", ex);
94107
mlTaskManager

plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.opensearch.ml.task.MLTaskManager;
4646
import org.opensearch.tasks.Task;
4747
import org.opensearch.test.OpenSearchTestCase;
48+
import org.opensearch.threadpool.ThreadPool;
4849
import org.opensearch.transport.TransportService;
4950

5051
public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@@ -62,14 +63,16 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
6263
private Task task;
6364
@Mock
6465
ActionListener<MLBatchIngestionResponse> actionListener;
66+
@Mock
67+
ThreadPool threadPool;
6568

6669
private TransportBatchIngestionAction batchAction;
6770
private MLBatchIngestionInput batchInput;
6871

6972
@Before
7073
public void setup() {
7174
MockitoAnnotations.openMocks(this);
72-
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager);
75+
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);
7376

7477
Map<String, Object> fieldMap = new HashMap<>();
7578
fieldMap.put("input", "$.content");

0 commit comments

Comments (0)