Skip to content

Commit 392a307

Browse files
committed
Fix field mapping, add more error handling, and remove checking of the jobId field in the batch job response
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 93d0429 commit 392a307

File tree

4 files changed

+95
-8
lines changed

4 files changed

+95
-8
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java

+50-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,55 @@ protected double calculateSuccessRate(List<Double> successRates) {
7575
);
7676
}
7777

78+
/**
79+
* Filters fields in the map where the value contains the specified source index as a prefix.
80+
* When there is only one source file, users can skip the source[] prefix
81+
*
82+
* @param mlBatchIngestionInput The MLBatchIngestionInput.
83+
* @return A new map of <fieldName: JsonPath> for all fields to be ingested.
84+
*/
85+
protected Map<String, Object> filterFieldMappingSoleSource(MLBatchIngestionInput mlBatchIngestionInput) {
86+
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
87+
String prefix = "source[0]";
88+
89+
Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
90+
Object value = entry.getValue();
91+
if (value instanceof String) {
92+
String jsonPath = ((String) value);
93+
return jsonPath.contains(prefix) || !jsonPath.startsWith("source");
94+
} else if (value instanceof List) {
95+
return ((List<String>) value).stream().anyMatch(val -> (val.contains(prefix) || !val.startsWith("source")));
96+
}
97+
return false;
98+
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
99+
Object value = entry.getValue();
100+
if (value instanceof String) {
101+
return getJsonPath((String) value);
102+
} else if (value instanceof List) {
103+
return ((List<String>) value)
104+
.stream()
105+
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
106+
.map(StringUtils::getJsonPath)
107+
.collect(Collectors.toList());
108+
}
109+
return null;
110+
}));
111+
112+
String[] ingestFields = mlBatchIngestionInput.getIngestFields();
113+
if (ingestFields != null) {
114+
Arrays
115+
.stream(ingestFields)
116+
.filter(Objects::nonNull)
117+
.filter(val -> (val.contains(prefix) || !val.startsWith("source")))
118+
.map(StringUtils::getJsonPath)
119+
.forEach(jsonPath -> {
120+
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
121+
});
122+
}
123+
124+
return filteredFieldMap;
125+
}
126+
78127
/**
79128
* Filters fields in the map where the value contains the specified source index as a prefix.
80129
*
@@ -159,7 +208,7 @@ protected void batchIngest(
159208
BulkRequest bulkRequest = new BulkRequest();
160209
sourceLines.stream().forEach(jsonStr -> {
161210
Map<String, Object> filteredMapping = isSoleSource
162-
? mlBatchIngestionInput.getFieldMapping()
211+
? filterFieldMappingSoleSource(mlBatchIngestionInput)
163212
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
164213
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
165214
if (jsonMap.isEmpty()) {

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+32-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1010
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
1111
import static org.opensearch.ml.common.MLTaskState.FAILED;
12-
import static org.opensearch.ml.plugin.MachineLearningPlugin.TRAIN_THREAD_POOL;
12+
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
1313
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
1414

1515
import java.time.Instant;
@@ -41,6 +41,8 @@
4141
import org.opensearch.threadpool.ThreadPool;
4242
import org.opensearch.transport.TransportService;
4343

44+
import com.jayway.jsonpath.PathNotFoundException;
45+
4446
import lombok.extern.log4j.Log4j2;
4547

4648
@Log4j2
@@ -92,9 +94,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
9294
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
9395
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
9496
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
95-
threadPool.executor(TRAIN_THREAD_POOL).execute(() -> {
96-
double successRate = ingestable.ingest(mlBatchIngestionInput);
97-
handleSuccessRate(successRate, taskId);
97+
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
98+
executeWithErrorHandling(() -> {
99+
double successRate = ingestable.ingest(mlBatchIngestionInput);
100+
handleSuccessRate(successRate, taskId);
101+
}, taskId);
98102
});
99103
} catch (Exception ex) {
100104
log.error("Failed in batch ingestion", ex);
@@ -125,6 +129,30 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
125129
}
126130
}
127131

132+
protected void executeWithErrorHandling(Runnable task, String taskId) {
133+
try {
134+
task.run();
135+
} catch (PathNotFoundException jsonPathNotFoundException) {
136+
log.error("Error in jsonParse fields", jsonPathNotFoundException);
137+
mlTaskManager
138+
.updateMLTask(
139+
taskId,
140+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, jsonPathNotFoundException.getMessage()),
141+
TASK_SEMAPHORE_TIMEOUT,
142+
true
143+
);
144+
} catch (Exception e) {
145+
log.error("Error in ingest, failed to produce a successRate", e);
146+
mlTaskManager
147+
.updateMLTask(
148+
taskId,
149+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e)),
150+
TASK_SEMAPHORE_TIMEOUT,
151+
true
152+
);
153+
}
154+
}
155+
128156
protected void handleSuccessRate(double successRate, String taskId) {
129157
if (successRate == 100) {
130158
mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED), 5000, true);

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+11-1
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ public class MachineLearningPlugin extends Plugin
333333
public static final String TRAIN_THREAD_POOL = "opensearch_ml_train";
334334
public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict";
335335
public static final String REMOTE_PREDICT_THREAD_POOL = "opensearch_ml_predict_remote";
336+
public static final String INGEST_THREAD_POOL = "opensearch_ml_ingest";
336337
public static final String REGISTER_THREAD_POOL = "opensearch_ml_register";
337338
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
338339
public static final String ML_BASE_URI = "/_plugins/_ml";
@@ -885,6 +886,14 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
885886
ML_THREAD_POOL_PREFIX + REMOTE_PREDICT_THREAD_POOL,
886887
false
887888
);
889+
FixedExecutorBuilder batchIngestThreadPool = new FixedExecutorBuilder(
890+
settings,
891+
INGEST_THREAD_POOL,
892+
OpenSearchExecutors.allocatedProcessors(settings) * 4,
893+
30,
894+
ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL,
895+
false
896+
);
888897

889898
return ImmutableList
890899
.of(
@@ -894,7 +903,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
894903
executeThreadPool,
895904
trainThreadPool,
896905
predictThreadPool,
897-
remotePredictThreadPool
906+
remotePredictThreadPool,
907+
batchIngestThreadPool
898908
);
899909
}
900910

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -358,13 +358,13 @@ private void runPredict(
358358
&& tensorOutput.getMlModelOutputs() != null
359359
&& !tensorOutput.getMlModelOutputs().isEmpty()) {
360360
ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0);
361+
Integer statusCode = modelOutput.getStatusCode();
361362
if (modelOutput.getMlModelTensors() != null && !modelOutput.getMlModelTensors().isEmpty()) {
362363
Map<String, Object> dataAsMap = (Map<String, Object>) modelOutput
363364
.getMlModelTensors()
364365
.get(0)
365366
.getDataAsMap();
366-
if (dataAsMap != null
367-
&& (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) {
367+
if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) {
368368
remoteJob.putAll(dataAsMap);
369369
mlTask.setRemoteJob(remoteJob);
370370
mlTask.setTaskId(null);

0 commit comments

Comments
 (0)