|
9 | 9 | import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
|
10 | 10 | import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
|
11 | 11 | import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
|
12 |
| -import static org.opensearch.ml.common.utils.StringUtils.fromJson; |
13 | 12 |
|
14 | 13 | import java.io.BufferedReader;
|
15 | 14 | import java.io.InputStreamReader;
|
|
18 | 17 | import java.security.PrivilegedActionException;
|
19 | 18 | import java.security.PrivilegedExceptionAction;
|
20 | 19 | import java.util.ArrayList;
|
| 20 | +import java.util.HashMap; |
21 | 21 | import java.util.List;
|
22 | 22 | import java.util.Map;
|
23 | 23 | import java.util.concurrent.CompletableFuture;
|
|
33 | 33 | import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
|
34 | 34 | import org.opensearch.ml.engine.annotation.Ingester;
|
35 | 35 |
|
| 36 | +import com.jayway.jsonpath.JsonPath; |
| 37 | + |
36 | 38 | import lombok.extern.log4j.Log4j2;
|
37 | 39 | import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
38 | 40 | import software.amazon.awssdk.auth.credentials.AwsCredentials;
|
|
49 | 51 | @Ingester("s3")
|
50 | 52 | public class S3DataIngestion implements Ingestable {
|
51 | 53 | public static final String SOURCE = "source";
|
| 54 | + public static final String OUTPUT = "output"; |
| 55 | + public static final String INPUT = "input"; |
| 56 | + public static final String OUTPUTIELDS = "output_fields"; |
| 57 | + public static final String INPUTFIELDS = "input_fields"; |
| 58 | + public static final String INGESTFIELDS = "ingest_fields"; |
52 | 59 | private final Client client;
|
53 | 60 |
|
54 | 61 | public S3DataIngestion(Client client) {
|
@@ -154,31 +161,40 @@ private void batchIngest(
|
154 | 161 | ) {
|
155 | 162 | BulkRequest bulkRequest = new BulkRequest();
|
156 | 163 | sourceLines.stream().forEach(jsonStr -> {
|
157 |
| - Map<String, Object> jsonMap = fromJson(jsonStr, "SageMakerOutput"); |
158 |
| - processFieldMapping(jsonMap, mlBatchIngestionInput.getFieldMapping()); |
| 164 | + // Map<String, Object> jsonMap = fromJson(jsonStr, outputFieldName); |
| 165 | + Map<String, Object> jsonMap = processFieldMapping(jsonStr, mlBatchIngestionInput.getFieldMapping()); |
159 | 166 | IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
|
160 | 167 |
|
161 | 168 | bulkRequest.add(indexRequest);
|
162 | 169 | });
|
163 | 170 | client.bulk(bulkRequest, bulkResponseListener);
|
164 | 171 | }
|
165 | 172 |
|
166 |
| - private void processFieldMapping(Map<String, Object> jsonMap, Map<String, String> fieldMapping) { |
167 |
| - List<List> smOutput = (List<List>) jsonMap.get("SageMakerOutput"); |
168 |
| - List<String> smInput = (List<String>) jsonMap.get("content"); |
169 |
| - if (smInput.size() == smOutput.size() && smInput.size() == fieldMapping.size()) { |
170 |
| - int index = 0; |
171 |
| - for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) { |
172 |
| - // key is the field name for input String, value is the field name for embedded output |
173 |
| - jsonMap.put(mapping.getKey(), smInput.get(index)); |
174 |
| - jsonMap.put(mapping.getValue(), smOutput.get(index)); |
175 |
| - index++; |
176 |
| - } |
177 |
| - jsonMap.remove("content"); |
178 |
| - jsonMap.remove("SageMakerOutput"); |
179 |
| - } else { |
| 173 | + private Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) { |
| 174 | + String outputJsonPath = (String) fieldMapping.get(OUTPUT); |
| 175 | + String inputJsonPath = (String) fieldMapping.get(INPUT); |
| 176 | + List<List> smOutput = (List<List>) JsonPath.read(jsonStr, outputJsonPath); |
| 177 | + List<String> smInput = (List<String>) JsonPath.read(jsonStr, inputJsonPath); |
| 178 | + List<String> inputFields = (List<String>) fieldMapping.get(INPUTFIELDS); |
| 179 | + List<String> outputFields = (List<String>) fieldMapping.get(OUTPUTIELDS); |
| 180 | + List<String> ingestFieldsJsonPath = (List<String>) fieldMapping.get(INGESTFIELDS); |
| 181 | + |
| 182 | + if (smInput.size() != smOutput.size() || inputFields.size() != outputFields.size() || smInput.size() != inputFields.size()) { |
180 | 183 | throw new IllegalArgumentException("the fieldMapping and source data do not match");
|
181 | 184 | }
|
| 185 | + Map<String, Object> jsonMap = new HashMap<>(); |
| 186 | + |
| 187 | + for (int index = 0; index < smInput.size(); index++) { |
| 188 | + jsonMap.put(inputFields.get(index), smInput.get(index)); |
| 189 | + jsonMap.put(outputFields.get(index), smOutput.get(index)); |
| 190 | + } |
| 191 | + jsonMap.remove(obtainFieldNameFromJsonPath(inputJsonPath)); |
| 192 | + jsonMap.remove(obtainFieldNameFromJsonPath(outputJsonPath)); |
| 193 | + |
| 194 | + for (String fieldPath : ingestFieldsJsonPath) { |
| 195 | + jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath)); |
| 196 | + } |
| 197 | + return jsonMap; |
182 | 198 | }
|
183 | 199 |
|
184 | 200 | private String getS3BucketName(String s3Uri) {
|
@@ -230,4 +246,11 @@ private S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
|
230 | 246 | throw new RuntimeException("Can't load credentials", e);
|
231 | 247 | }
|
232 | 248 | }
|
| 249 | + |
| 250 | + private String obtainFieldNameFromJsonPath(String jsonPath) { |
| 251 | + String[] parts = jsonPath.split("\\."); |
| 252 | + |
| 253 | + // Get the last part which is the field name |
| 254 | + return parts[parts.length - 1]; |
| 255 | + } |
233 | 256 | }
|
0 commit comments