Skip to content

Commit 9518939

Browse files
committed
update the field mapping for batch ingest
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 55d28e0 commit 9518939

File tree

3 files changed

+59
-64
lines changed

3 files changed

+59
-64
lines changed

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+24-2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2828

2929
public static final String INDEX_NAME_FIELD = "index_name";
3030
public static final String FIELD_MAP_FIELD = "field_map";
31-
public static final String DATA_SOURCE_FIELD = "data_source";
31+
public static final String INGEST_FIELDS = "ingest_fields";
3232
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
33+
public static final String DATA_SOURCE_FIELD = "data_source";
34+
3335
@Getter
3436
private String indexName;
3537
@Getter
3638
private Map<String, Object> fieldMapping;
3739
@Getter
40+
private Map<String, Object> ingestFields;
41+
@Getter
3842
private Map<String, Object> dataSources;
3943
@Getter
4044
private Map<String, String> credential;
@@ -43,6 +47,7 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
4347
public MLBatchIngestionInput(
4448
String indexName,
4549
Map<String, Object> fieldMapping,
50+
Map<String, Object> ingestFields,
4651
Map<String, Object> dataSources,
4752
Map<String, String> credential
4853
) {
@@ -58,13 +63,15 @@ public MLBatchIngestionInput(
5863
}
5964
this.indexName = indexName;
6065
this.fieldMapping = fieldMapping;
66+
this.ingestFields = ingestFields;
6167
this.dataSources = dataSources;
6268
this.credential = credential;
6369
}
6470

6571
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
6672
String indexName = null;
6773
Map<String, Object> fieldMapping = null;
74+
Map<String, Object> ingestFields = null;
6875
Map<String, Object> dataSources = null;
6976
Map<String, String> credential = new HashMap<>();
7077

@@ -80,6 +87,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
8087
case FIELD_MAP_FIELD:
8188
fieldMapping = parser.map();
8289
break;
90+
case INGEST_FIELDS:
91+
ingestFields = parser.map();
92+
break;
8393
case CONNECTOR_CREDENTIAL_FIELD:
8494
credential = parser.mapStrings();
8595
break;
@@ -91,7 +101,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
91101
break;
92102
}
93103
}
94-
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential);
104+
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
95105
}
96106

97107
@Override
@@ -103,6 +113,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
103113
if (fieldMapping != null) {
104114
builder.field(FIELD_MAP_FIELD, fieldMapping);
105115
}
116+
if (ingestFields != null) {
117+
builder.field(INGEST_FIELDS, ingestFields);
118+
}
106119
if (credential != null) {
107120
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
108121
}
@@ -122,6 +135,12 @@ public void writeTo(StreamOutput output) throws IOException {
122135
} else {
123136
output.writeBoolean(false);
124137
}
138+
if (ingestFields != null) {
139+
output.writeBoolean(true);
140+
output.writeMap(ingestFields, StreamOutput::writeString, StreamOutput::writeGenericValue);
141+
} else {
142+
output.writeBoolean(false);
143+
}
125144
if (credential != null) {
126145
output.writeBoolean(true);
127146
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
@@ -141,6 +160,9 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
141160
if (input.readBoolean()) {
142161
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
143162
}
163+
if (input.readBoolean()) {
164+
ingestFields = input.readMap(s -> s.readString(), s -> s.readGenericValue());
165+
}
144166
if (input.readBoolean()) {
145167
credential = input.readMap(s -> s.readString(), s -> s.readString());
146168
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/AbstractIngestion.java

+35-56
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55

66
package org.opensearch.ml.engine.ingest;
77

8+
import static org.opensearch.ml.common.transport.batch.MLBatchIngestionInput.INGEST_FIELDS;
89
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
910
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
1011

11-
import java.util.Collection;
1212
import java.util.HashMap;
1313
import java.util.List;
1414
import java.util.Map;
15-
import java.util.Optional;
1615
import java.util.concurrent.CompletableFuture;
1716
import java.util.concurrent.atomic.AtomicInteger;
1817
import java.util.stream.Collectors;
@@ -34,12 +33,6 @@
3433

3534
@Log4j2
3635
public class AbstractIngestion implements Ingestable {
37-
public static final String OUTPUT = "output";
38-
public static final String INPUT = "input";
39-
public static final String OUTPUT_FIELD_NAMES = "output_names";
40-
public static final String INPUT_FIELD_NAMES = "input_names";
41-
public static final String INGEST_FIELDS = "ingest_fields";
42-
public static final String ID_FIELD = "id_field";
4336

4437
private final Client client;
4538

@@ -85,12 +78,11 @@ protected double calculateSuccessRate(List<Double> successRates) {
8578
* Filters fields in the map where the value contains the specified source index as a prefix.
8679
*
8780
* @param mlBatchIngestionInput The MLBatchIngestionInput.
88-
* @param index The source index to filter by.
89-
* @return A new map with only the entries that match the specified source index.
81+
* @param indexInFieldMap The source index to filter by.
82+
* @return A new map with only the entries that match the specified source index and correctly mapped to JsonPath.
9083
*/
91-
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int index) {
84+
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int indexInFieldMap) {
9285
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
93-
int indexInFieldMap = index + 1;
9486
String prefix = "source[" + indexInFieldMap + "]";
9587

9688
Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
@@ -104,19 +96,28 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
10496
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
10597
Object value = entry.getValue();
10698
if (value instanceof String) {
107-
return value;
99+
return getJsonPath((String) value);
108100
} else if (value instanceof List) {
109-
return ((List<String>) value).stream().filter(val -> val.contains(prefix)).collect(Collectors.toList());
101+
return ((List<String>) value)
102+
.stream()
103+
.filter(val -> val.contains(prefix))
104+
.map(StringUtils::getJsonPath)
105+
.collect(Collectors.toList());
110106
}
111107
return null;
112108
}));
113109

114-
if (filteredFieldMap.containsKey(OUTPUT)) {
115-
filteredFieldMap.put(OUTPUT_FIELD_NAMES, fieldMap.get(OUTPUT_FIELD_NAMES));
116-
}
117-
if (filteredFieldMap.containsKey(INPUT)) {
118-
filteredFieldMap.put(INPUT_FIELD_NAMES, fieldMap.get(INPUT_FIELD_NAMES));
110+
Map<String, Object> ingestFields = mlBatchIngestionInput.getIngestFields();
111+
if (ingestFields != null && ingestFields.get(INGEST_FIELDS) instanceof List) {
112+
((List<String>) ingestFields.get(INGEST_FIELDS))
113+
.stream()
114+
.filter(val -> val.contains(prefix))
115+
.map(StringUtils::getJsonPath)
116+
.forEach(jsonPath -> {
117+
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
118+
});
119119
}
120+
120121
return filteredFieldMap;
121122
}
122123

@@ -128,42 +129,21 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
128129
* @return A new map that contains all the fields and data for ingestion.
129130
*/
130131
protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
131-
String inputJsonPath = fieldMapping.containsKey(INPUT) ? getJsonPath((String) fieldMapping.get(INPUT)) : null;
132-
List<String> remoteModelInput = inputJsonPath != null ? (List<String>) JsonPath.read(jsonStr, inputJsonPath) : null;
133-
List<String> inputFieldNames = inputJsonPath != null ? (List<String>) fieldMapping.get(INPUT_FIELD_NAMES) : null;
134-
135-
String outputJsonPath = fieldMapping.containsKey(OUTPUT) ? getJsonPath((String) fieldMapping.get(OUTPUT)) : null;
136-
List<List> remoteModelOutput = outputJsonPath != null ? (List<List>) JsonPath.read(jsonStr, outputJsonPath) : null;
137-
List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUT_FIELD_NAMES) : null;
138-
139-
List<String> ingestFieldsJsonPath = Optional
140-
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
141-
.stream()
142-
.flatMap(Collection::stream)
143-
.map(StringUtils::getJsonPath)
144-
.collect(Collectors.toList());
145-
146132
Map<String, Object> jsonMap = new HashMap<>();
147-
148-
populateJsonMap(jsonMap, inputFieldNames, remoteModelInput);
149-
populateJsonMap(jsonMap, outputFieldNames, remoteModelOutput);
150-
151-
for (String fieldPath : ingestFieldsJsonPath) {
152-
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
133+
if (fieldMapping == null || fieldMapping.isEmpty()) {
134+
return jsonMap;
153135
}
154136

155-
if (fieldMapping.containsKey(ID_FIELD)) {
156-
List<String> docIdJsonPath = Optional
157-
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
158-
.stream()
159-
.flatMap(Collection::stream)
160-
.map(StringUtils::getJsonPath)
161-
.collect(Collectors.toList());
162-
if (docIdJsonPath.size() != 1) {
163-
throw new IllegalArgumentException("The Id field must contains only 1 jsonPath for each source");
137+
fieldMapping.entrySet().stream().forEach(entry -> {
138+
Object value = entry.getValue();
139+
if (value instanceof String) {
140+
String jsonPath = (String) value;
141+
jsonMap.put(entry.getKey(), JsonPath.read(jsonStr, jsonPath));
142+
} else if (value instanceof List) {
143+
((List<String>) value).stream().forEach(jsonPath -> { jsonMap.put(entry.getKey(), JsonPath.read(jsonStr, jsonPath)); });
164144
}
165-
jsonMap.put("_id", JsonPath.read(jsonStr, docIdJsonPath.get(0)));
166-
}
145+
});
146+
167147
return jsonMap;
168148
}
169149

@@ -180,12 +160,11 @@ protected void batchIngest(
180160
? mlBatchIngestionInput.getFieldMapping()
181161
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
182162
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
183-
if (isSoleSource || sourceIndex == 0) {
163+
if (jsonMap.isEmpty()) {
164+
return;
165+
}
166+
if (isSoleSource && !jsonMap.containsKey("_id")) {
184167
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName());
185-
if (jsonMap.containsKey("_id")) {
186-
String id = (String) jsonMap.remove("_id");
187-
indexRequest.id(id);
188-
}
189168
indexRequest.source(jsonMap);
190169
bulkRequest.add(indexRequest);
191170
} else {

ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java

-6
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@
1515
import static org.mockito.Mockito.mock;
1616
import static org.mockito.Mockito.verify;
1717
import static org.mockito.Mockito.when;
18-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.ID_FIELD;
19-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.INGEST_FIELDS;
20-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT;
21-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.INPUT_FIELD_NAMES;
22-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT;
23-
import static org.opensearch.ml.engine.ingest.AbstractIngestion.OUTPUT_FIELD_NAMES;
2418

2519
import java.util.Arrays;
2620
import java.util.Collections;

0 commit comments

Comments
 (0)