Skip to content

Commit e0fb651

Browse files
committed
update the field mapping for batch ingest
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 822d62b commit e0fb651

File tree

3 files changed

+61
-64
lines changed

3 files changed

+61
-64
lines changed

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+24-2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2828

2929
public static final String INDEX_NAME_FIELD = "index_name";
3030
public static final String FIELD_MAP_FIELD = "field_map";
31-
public static final String DATA_SOURCE_FIELD = "data_source";
31+
public static final String INGEST_FIELDS = "ingest_fields";
3232
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
33+
public static final String DATA_SOURCE_FIELD = "data_source";
34+
3335
@Getter
3436
private String indexName;
3537
@Getter
3638
private Map<String, Object> fieldMapping;
3739
@Getter
40+
private String[] ingestFields;
41+
@Getter
3842
private Map<String, Object> dataSources;
3943
@Getter
4044
private Map<String, String> credential;
@@ -43,6 +47,7 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
4347
public MLBatchIngestionInput(
4448
String indexName,
4549
Map<String, Object> fieldMapping,
50+
String[] ingestFields,
4651
Map<String, Object> dataSources,
4752
Map<String, String> credential
4853
) {
@@ -58,13 +63,15 @@ public MLBatchIngestionInput(
5863
}
5964
this.indexName = indexName;
6065
this.fieldMapping = fieldMapping;
66+
this.ingestFields = ingestFields;
6167
this.dataSources = dataSources;
6268
this.credential = credential;
6369
}
6470

6571
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
6672
String indexName = null;
6773
Map<String, Object> fieldMapping = null;
74+
String[] ingestFields = null;
6875
Map<String, Object> dataSources = null;
6976
Map<String, String> credential = new HashMap<>();
7077

@@ -80,6 +87,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
8087
case FIELD_MAP_FIELD:
8188
fieldMapping = parser.map();
8289
break;
90+
case INGEST_FIELDS:
91+
ingestFields = parser.list().toArray(new String[0]);
92+
break;
8393
case CONNECTOR_CREDENTIAL_FIELD:
8494
credential = parser.mapStrings();
8595
break;
@@ -91,7 +101,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
91101
break;
92102
}
93103
}
94-
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential);
104+
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
95105
}
96106

97107
@Override
@@ -103,6 +113,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
103113
if (fieldMapping != null) {
104114
builder.field(FIELD_MAP_FIELD, fieldMapping);
105115
}
116+
if (ingestFields != null) {
117+
builder.field(INGEST_FIELDS, ingestFields);
118+
}
106119
if (credential != null) {
107120
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
108121
}
@@ -122,6 +135,12 @@ public void writeTo(StreamOutput output) throws IOException {
122135
} else {
123136
output.writeBoolean(false);
124137
}
138+
if (ingestFields != null) {
139+
output.writeBoolean(true);
140+
output.writeStringArray(ingestFields);
141+
} else {
142+
output.writeBoolean(false);
143+
}
125144
if (credential != null) {
126145
output.writeBoolean(true);
127146
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
@@ -141,6 +160,9 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
141160
if (input.readBoolean()) {
142161
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
143162
}
163+
if (input.readBoolean()) {
164+
ingestFields = input.readStringArray();;
165+
}
144166
if (input.readBoolean()) {
145167
credential = input.readMap(s -> s.readString(), s -> s.readString());
146168
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java

+37-56
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
package org.opensearch.ml.engine.ingest;
77

8+
import static org.opensearch.ml.common.transport.batch.MLBatchIngestionInput.INGEST_FIELDS;
89
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
910
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
1011

11-
import java.util.Collection;
12+
import java.util.Arrays;
1213
import java.util.HashMap;
1314
import java.util.List;
1415
import java.util.Map;
15-
import java.util.Optional;
16+
import java.util.Objects;
1617
import java.util.concurrent.CompletableFuture;
1718
import java.util.concurrent.atomic.AtomicInteger;
1819
import java.util.stream.Collectors;
@@ -34,12 +35,6 @@
3435

3536
@Log4j2
3637
public class AbstractIngestion implements Ingestable {
37-
public static final String OUTPUT = "output";
38-
public static final String INPUT = "input";
39-
public static final String OUTPUT_FIELD_NAMES = "output_names";
40-
public static final String INPUT_FIELD_NAMES = "input_names";
41-
public static final String INGEST_FIELDS = "ingest_fields";
42-
public static final String ID_FIELD = "id_field";
4338

4439
private final Client client;
4540

@@ -85,12 +80,11 @@ protected double calculateSuccessRate(List<Double> successRates) {
8580
* Filters fields in the map where the value contains the specified source index as a prefix.
8681
*
8782
* @param mlBatchIngestionInput The MLBatchIngestionInput.
88-
* @param index The source index to filter by.
89-
* @return A new map with only the entries that match the specified source index.
83+
* @param indexInFieldMap The source index to filter by.
84+
* @return A new map containing only the entries that match the specified source index, with their values converted to JsonPath expressions.
9085
*/
91-
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int index) {
86+
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int indexInFieldMap) {
9287
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
93-
int indexInFieldMap = index + 1;
9488
String prefix = "source[" + indexInFieldMap + "]";
9589

9690
Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
@@ -104,19 +98,28 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
10498
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
10599
Object value = entry.getValue();
106100
if (value instanceof String) {
107-
return value;
101+
return getJsonPath((String) value);
108102
} else if (value instanceof List) {
109-
return ((List<String>) value).stream().filter(val -> val.contains(prefix)).collect(Collectors.toList());
103+
return ((List<String>) value)
104+
.stream()
105+
.filter(val -> val.contains(prefix))
106+
.map(StringUtils::getJsonPath)
107+
.collect(Collectors.toList());
110108
}
111109
return null;
112110
}));
113111

114-
if (filteredFieldMap.containsKey(OUTPUT)) {
115-
filteredFieldMap.put(OUTPUT_FIELD_NAMES, fieldMap.get(OUTPUT_FIELD_NAMES));
116-
}
117-
if (filteredFieldMap.containsKey(INPUT)) {
118-
filteredFieldMap.put(INPUT_FIELD_NAMES, fieldMap.get(INPUT_FIELD_NAMES));
112+
String[] ingestFields = mlBatchIngestionInput.getIngestFields();
113+
if (ingestFields != null) {
114+
Arrays.stream(ingestFields)
115+
.filter(Objects::nonNull)
116+
.filter(val -> val.contains(prefix))
117+
.map(StringUtils::getJsonPath)
118+
.forEach(jsonPath -> {
119+
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
120+
});
119121
}
122+
120123
return filteredFieldMap;
121124
}
122125

@@ -128,42 +131,21 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
128131
* @return A new map that contains all the fields and data for ingestion.
129132
*/
130133
protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
131-
String inputJsonPath = fieldMapping.containsKey(INPUT) ? getJsonPath((String) fieldMapping.get(INPUT)) : null;
132-
List<String> remoteModelInput = inputJsonPath != null ? (List<String>) JsonPath.read(jsonStr, inputJsonPath) : null;
133-
List<String> inputFieldNames = inputJsonPath != null ? (List<String>) fieldMapping.get(INPUT_FIELD_NAMES) : null;
134-
135-
String outputJsonPath = fieldMapping.containsKey(OUTPUT) ? getJsonPath((String) fieldMapping.get(OUTPUT)) : null;
136-
List<List> remoteModelOutput = outputJsonPath != null ? (List<List>) JsonPath.read(jsonStr, outputJsonPath) : null;
137-
List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUT_FIELD_NAMES) : null;
138-
139-
List<String> ingestFieldsJsonPath = Optional
140-
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
141-
.stream()
142-
.flatMap(Collection::stream)
143-
.map(StringUtils::getJsonPath)
144-
.collect(Collectors.toList());
145-
146134
Map<String, Object> jsonMap = new HashMap<>();
147-
148-
populateJsonMap(jsonMap, inputFieldNames, remoteModelInput);
149-
populateJsonMap(jsonMap, outputFieldNames, remoteModelOutput);
150-
151-
for (String fieldPath : ingestFieldsJsonPath) {
152-
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
135+
if (fieldMapping == null || fieldMapping.isEmpty()) {
136+
return jsonMap;
153137
}
154138

155-
if (fieldMapping.containsKey(ID_FIELD)) {
156-
List<String> docIdJsonPath = Optional
157-
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
158-
.stream()
159-
.flatMap(Collection::stream)
160-
.map(StringUtils::getJsonPath)
161-
.collect(Collectors.toList());
162-
if (docIdJsonPath.size() != 1) {
163-
throw new IllegalArgumentException("The Id field must contains only 1 jsonPath for each source");
139+
fieldMapping.entrySet().stream().forEach(entry -> {
140+
Object value = entry.getValue();
141+
if (value instanceof String) {
142+
String jsonPath = (String) value;
143+
jsonMap.put(entry.getKey(), JsonPath.read(jsonStr, jsonPath));
144+
} else if (value instanceof List) {
145+
((List<String>) value).stream().forEach(jsonPath -> { jsonMap.put(entry.getKey(), JsonPath.read(jsonStr, jsonPath)); });
164146
}
165-
jsonMap.put("_id", JsonPath.read(jsonStr, docIdJsonPath.get(0)));
166-
}
147+
});
148+
167149
return jsonMap;
168150
}
169151

@@ -180,12 +162,11 @@ protected void batchIngest(
180162
? mlBatchIngestionInput.getFieldMapping()
181163
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
182164
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
183-
if (isSoleSource || sourceIndex == 0) {
165+
if (jsonMap.isEmpty()) {
166+
return;
167+
}
168+
if (isSoleSource && !jsonMap.containsKey("_id")) {
184169
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName());
185-
if (jsonMap.containsKey("_id")) {
186-
String id = (String) jsonMap.remove("_id");
187-
indexRequest.id(id);
188-
}
189170
indexRequest.source(jsonMap);
190171
bulkRequest.add(indexRequest);
191172
} else {

ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java

-6
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@
1515
import static org.mockito.Mockito.mock;
1616
import static org.mockito.Mockito.verify;
1717
import static org.mockito.Mockito.when;
18-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.ID_FIELD;
19-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.INGEST_FIELDS;
20-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT;
21-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT_FIELD_NAMES;
22-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT;
23-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT_FIELD_NAMES;
2418

2519
import java.util.Arrays;
2620
import java.util.Collections;

0 commit comments

Comments
 (0)