Skip to content

Commit 4e3b97a

Browse files
committed
support multiple data sources as ingestion inputs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 5c574ae commit 4e3b97a

File tree

10 files changed

+327
-207
lines changed

10 files changed

+327
-207
lines changed

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionAction.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.transport.batch;
27

38
import org.opensearch.action.ActionType;

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+13-15
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
3535
@Getter
3636
private Map<String, Object> fieldMapping;
3737
@Getter
38-
private Map<String, String> dataSources;
38+
private Map<String, Object> dataSources;
3939
@Getter
4040
private Map<String, String> credential;
4141

4242
@Builder(toBuilder = true)
4343
public MLBatchIngestionInput(
4444
String indexName,
4545
Map<String, Object> fieldMapping,
46-
Map<String, String> dataSources,
46+
Map<String, Object> dataSources,
4747
Map<String, String> credential
4848
) {
4949
if (indexName == null) {
@@ -61,7 +61,7 @@ public MLBatchIngestionInput(
6161
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
6262
String indexName = null;
6363
Map<String, Object> fieldMapping = null;
64-
Map<String, String> dataSources = null;
64+
Map<String, Object> dataSources = null;
6565
Map<String, String> credential = new HashMap<>();
6666

6767
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
@@ -80,7 +80,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
8080
credential = parser.mapStrings();
8181
break;
8282
case DATA_SOURCE_FIELD:
83-
dataSources = parser.mapStrings();
83+
dataSources = parser.map();
8484
break;
8585
default:
8686
parser.skipChildren();
@@ -99,12 +99,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
9999
if (fieldMapping != null) {
100100
builder.field(FIELD_MAP_FIELD, fieldMapping);
101101
}
102-
if (dataSources != null) {
103-
builder.field(DATA_SOURCE_FIELD, dataSources);
104-
}
105102
if (credential != null) {
106103
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
107104
}
105+
if (dataSources != null) {
106+
builder.field(DATA_SOURCE_FIELD, dataSources);
107+
}
108108
builder.endObject();
109109
return builder;
110110
}
@@ -118,17 +118,15 @@ public void writeTo(StreamOutput output) throws IOException {
118118
} else {
119119
output.writeBoolean(false);
120120
}
121-
122-
if (dataSources != null) {
121+
if (credential != null) {
123122
output.writeBoolean(true);
124-
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeString);
123+
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
125124
} else {
126125
output.writeBoolean(false);
127126
}
128-
129-
if (credential != null) {
127+
if (dataSources != null) {
130128
output.writeBoolean(true);
131-
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
129+
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
132130
} else {
133131
output.writeBoolean(false);
134132
}
@@ -140,10 +138,10 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
140138
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
141139
}
142140
if (input.readBoolean()) {
143-
dataSources = input.readMap(s -> s.readString(), s -> s.readString());
141+
credential = input.readMap(s -> s.readString(), s -> s.readString());
144142
}
145143
if (input.readBoolean()) {
146-
credential = input.readMap(s -> s.readString(), s -> s.readString());
144+
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
147145
}
148146
}
149147

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionRequest.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.transport.batch;
27

38
import static org.opensearch.action.ValidateActions.addValidationError;

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionResponse.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.transport.batch;
27

38
import java.io.ByteArrayInputStream;

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+8-22
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import java.util.ArrayList;
1414
import java.util.HashMap;
1515
import java.util.HashSet;
16-
import java.util.LinkedHashMap;
1716
import java.util.List;
1817
import java.util.Map;
1918
import java.util.Set;
@@ -143,27 +142,6 @@ public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs)
143142
return parameters;
144143
}
145144

146-
@SuppressWarnings("removal")
147-
public static LinkedHashMap<String, String> getOrderedMap(Map<String, ?> parameterObjs) {
148-
LinkedHashMap<String, String> parameters = new LinkedHashMap<>();
149-
for (String key : parameterObjs.keySet()) {
150-
Object value = parameterObjs.get(key);
151-
try {
152-
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
153-
if (value instanceof String) {
154-
parameters.put(key, (String) value);
155-
} else {
156-
parameters.put(key, gson.toJson(value));
157-
}
158-
return null;
159-
});
160-
} catch (PrivilegedActionException e) {
161-
throw new RuntimeException(e);
162-
}
163-
}
164-
return parameters;
165-
}
166-
167145
@SuppressWarnings("removal")
168146
public static String toJson(Object value) {
169147
try {
@@ -262,4 +240,12 @@ public static String obtainFieldNameFromJsonPath(String jsonPath) {
262240
// Get the last part which is the field name
263241
return parts[parts.length - 1];
264242
}
243+
244+
public static String getJsonPath(String jsonPathWithSource) {
245+
// Find the index of the first occurrence of "$."
246+
int startIndex = jsonPathWithSource.indexOf("$.");
247+
248+
// Extract the substring from the startIndex to the end of the input string
249+
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
250+
}
265251
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ public void invokeRemoteService(
8686
ActionListener<Tuple<Integer, ModelTensors>> actionListener
8787
) {
8888
try {
89-
connector.getDecryptedCredential();
9089
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
9190
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
9291
.builder()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.ingest;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
9+
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
10+
11+
import java.util.Arrays;
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.concurrent.CompletableFuture;
16+
import java.util.concurrent.atomic.AtomicInteger;
17+
import java.util.stream.Collectors;
18+
19+
import org.opensearch.OpenSearchStatusException;
20+
import org.opensearch.action.bulk.BulkRequest;
21+
import org.opensearch.action.bulk.BulkResponse;
22+
import org.opensearch.action.index.IndexRequest;
23+
import org.opensearch.action.update.UpdateRequest;
24+
import org.opensearch.client.Client;
25+
import org.opensearch.core.action.ActionListener;
26+
import org.opensearch.core.rest.RestStatus;
27+
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
28+
import org.opensearch.ml.common.utils.StringUtils;
29+
30+
import com.jayway.jsonpath.JsonPath;
31+
32+
import lombok.extern.log4j.Log4j2;
33+
34+
@Log4j2
35+
public class AbstractIngestion implements Ingestable {
36+
public static final String OUTPUT = "output";
37+
public static final String INPUT = "input";
38+
public static final String OUTPUTIELDS = "output_names";
39+
public static final String INPUTFIELDS = "input_names";
40+
public static final String INGESTFIELDS = "ingest_fields";
41+
public static final String IDFIELD = "id_field";
42+
43+
private final Client client;
44+
45+
public AbstractIngestion(Client client) {
46+
this.client = client;
47+
}
48+
49+
protected ActionListener<BulkResponse> getBulkResponseListener(
50+
AtomicInteger successfulBatches,
51+
AtomicInteger failedBatches,
52+
CompletableFuture<Void> future
53+
) {
54+
return ActionListener.wrap(bulkResponse -> {
55+
if (bulkResponse.hasFailures()) {
56+
failedBatches.incrementAndGet();
57+
future.completeExceptionally(new RuntimeException(bulkResponse.buildFailureMessage())); // Mark the future as completed
58+
// with an exception
59+
}
60+
log.debug("Batch Ingestion successfully");
61+
successfulBatches.incrementAndGet();
62+
future.complete(null); // Mark the future as completed successfully
63+
}, e -> {
64+
log.error("Failed to Batch Ingestion", e);
65+
failedBatches.incrementAndGet();
66+
future.completeExceptionally(e); // Mark the future as completed with an exception
67+
});
68+
}
69+
70+
protected double calcualteSuccessRate(List<Double> successRates) {
71+
return successRates
72+
.stream()
73+
.min(Double::compare)
74+
.orElseThrow(
75+
() -> new OpenSearchStatusException(
76+
"Failed to batch ingest data as not success rate is returned",
77+
RestStatus.INTERNAL_SERVER_ERROR
78+
)
79+
);
80+
}
81+
82+
/**
83+
* Filters fields in the map where the value contains the specified source index as a prefix.
84+
*
85+
* @param mlBatchIngestionInput The MLBatchIngestionInput.
86+
* @param index The source index to filter by.
87+
* @return A new map with only the entries that match the specified source index.
88+
*/
89+
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int index) {
90+
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
91+
int indexInFieldMap = index + 1;
92+
String prefix = "source[" + indexInFieldMap + "]";
93+
94+
Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
95+
Object value = entry.getValue();
96+
if (value instanceof String) {
97+
return ((String) value).contains(prefix);
98+
} else if (value instanceof String[]) {
99+
return Arrays.stream((String[]) value).anyMatch(val -> val.contains(prefix));
100+
}
101+
return false;
102+
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
103+
Object value = entry.getValue();
104+
if (value instanceof String) {
105+
return value;
106+
} else if (value instanceof String[]) {
107+
return Arrays.stream((String[]) value).filter(val -> val.contains(prefix)).toArray(String[]::new);
108+
}
109+
return null;
110+
}));
111+
112+
if (filteredFieldMap.containsKey(OUTPUT)) {
113+
filteredFieldMap.put(OUTPUTIELDS, fieldMap.get(OUTPUTIELDS));
114+
}
115+
if (filteredFieldMap.containsKey(INPUT)) {
116+
filteredFieldMap.put(INPUTFIELDS, fieldMap.get(INPUTFIELDS));
117+
}
118+
return filteredFieldMap;
119+
}
120+
121+
/**
122+
* Produce the source as a Map to be ingested in to OpenSearch.
123+
*
124+
* @param jsonStr The MLBatchIngestionInput.
125+
* @param fieldMapping The field mapping that includes all the field name and Json Path for the data.
126+
* @return A new map that contains all the fields and data for ingestion.
127+
*/
128+
protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
129+
String inputJsonPath = fieldMapping.containsKey(INPUT) ? getJsonPath((String) fieldMapping.get(INPUT)) : null;
130+
List<String> remoteModelInput = (List<String>) JsonPath.read(jsonStr, inputJsonPath);
131+
List<String> inputFieldNames = inputJsonPath != null ? (List<String>) fieldMapping.get(INPUTFIELDS) : null;
132+
133+
String outputJsonPath = fieldMapping.containsKey(OUTPUT) ? getJsonPath((String) fieldMapping.get(OUTPUT)) : null;
134+
List<List> remoteModelOutput = (List<List>) JsonPath.read(jsonStr, outputJsonPath);
135+
List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUTIELDS) : null;
136+
137+
List<String> ingestFieldsJsonPath = ((List<String>) fieldMapping.get(INGESTFIELDS))
138+
.stream()
139+
.map(StringUtils::getJsonPath)
140+
.collect(Collectors.toList());
141+
142+
if (remoteModelInput.size() != inputFieldNames.size() || remoteModelOutput.size() != outputFieldNames.size()) {
143+
throw new IllegalArgumentException("the fieldMapping and source data do not match");
144+
}
145+
Map<String, Object> jsonMap = new HashMap<>();
146+
147+
for (int index = 0; index < remoteModelInput.size(); index++) {
148+
jsonMap.put(inputFieldNames.get(index), remoteModelInput.get(index));
149+
jsonMap.put(outputFieldNames.get(index), remoteModelOutput.get(index));
150+
}
151+
152+
for (String fieldPath : ingestFieldsJsonPath) {
153+
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
154+
}
155+
156+
if (fieldMapping.containsKey(IDFIELD)) {
157+
List<String> docIdJsonPath = ((List<String>) fieldMapping.get(IDFIELD))
158+
.stream()
159+
.map(StringUtils::getJsonPath)
160+
.collect(Collectors.toList());
161+
if (docIdJsonPath.size() != 1) {
162+
throw new IllegalArgumentException("The Id field must contains only 1 jsonPath for each source");
163+
}
164+
jsonMap.put("_id", JsonPath.read(jsonStr, docIdJsonPath.get(0)));
165+
}
166+
return jsonMap;
167+
}
168+
169+
protected void batchIngest(
170+
List<String> sourceLines,
171+
MLBatchIngestionInput mlBatchIngestionInput,
172+
ActionListener<BulkResponse> bulkResponseListener,
173+
int sourceIndex,
174+
boolean isSoleSource
175+
) {
176+
BulkRequest bulkRequest = new BulkRequest();
177+
sourceLines.stream().forEach(jsonStr -> {
178+
Map<String, Object> filteredMapping = isSoleSource
179+
? mlBatchIngestionInput.getFieldMapping()
180+
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
181+
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
182+
if (isSoleSource || sourceIndex == 0) {
183+
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
184+
if (jsonMap.containsKey("_id")) {
185+
indexRequest.id((String) jsonMap.get("_id"));
186+
}
187+
bulkRequest.add(indexRequest);
188+
} else {
189+
// bulk update docs as they were partially ingested
190+
if (!jsonMap.containsKey("_id")) {
191+
throw new IllegalArgumentException("The id filed must be provided to match documents for multiple sources");
192+
}
193+
String id = (String) jsonMap.get("_id");
194+
UpdateRequest updateRequest = new UpdateRequest(mlBatchIngestionInput.getIndexName(), id).doc(jsonMap).upsert(jsonMap);
195+
bulkRequest.add(updateRequest);
196+
}
197+
});
198+
client.bulk(bulkRequest, bulkResponseListener);
199+
}
200+
}

0 commit comments

Comments
 (0)