Skip to content

Commit 9e7e2bb

Browse files
committed
update interphase and address comments
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent c8d7b8a commit 9e7e2bb

File tree

6 files changed

+217
-32
lines changed

6 files changed

+217
-32
lines changed

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+15-10
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.ml.common.transport.batch;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9-
import static org.opensearch.ml.common.utils.StringUtils.getOrderedMap;
109

1110
import java.io.IOException;
1211
import java.util.HashMap;
@@ -28,13 +27,13 @@
2827
public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2928

3029
public static final String INDEX_NAME_FIELD = "index_name";
31-
public static final String TEXT_EMBEDDING_FIELD_MAP_FIELD = "text_embedding_field_map";
30+
public static final String FIELD_MAP_FIELD = "field_map";
3231
public static final String DATA_SOURCE_FIELD = "data_source";
3332
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
3433
@Getter
3534
private String indexName;
3635
@Getter
37-
private Map<String, String> fieldMapping;
36+
private Map<String, Object> fieldMapping;
3837
@Getter
3938
private Map<String, String> dataSources;
4039
@Getter
@@ -43,10 +42,16 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
4342
@Builder(toBuilder = true)
4443
public MLBatchIngestionInput(
4544
String indexName,
46-
Map<String, String> fieldMapping,
45+
Map<String, Object> fieldMapping,
4746
Map<String, String> dataSources,
4847
Map<String, String> credential
4948
) {
49+
if (indexName == null) {
50+
throw new IllegalArgumentException("index name for ingestion is null");
51+
}
52+
if (dataSources == null) {
53+
throw new IllegalArgumentException("dataSources for ingestion is null");
54+
}
5055
this.indexName = indexName;
5156
this.fieldMapping = fieldMapping;
5257
this.dataSources = dataSources;
@@ -55,7 +60,7 @@ public MLBatchIngestionInput(
5560

5661
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
5762
String indexName = null;
58-
Map<String, String> fieldMapping = null;
63+
Map<String, Object> fieldMapping = null;
5964
Map<String, String> dataSources = null;
6065
Map<String, String> credential = new HashMap<>();
6166

@@ -68,8 +73,8 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
6873
case INDEX_NAME_FIELD:
6974
indexName = parser.text();
7075
break;
71-
case TEXT_EMBEDDING_FIELD_MAP_FIELD:
72-
fieldMapping = getOrderedMap(parser.mapOrdered());
76+
case FIELD_MAP_FIELD:
77+
fieldMapping = parser.map();
7378
break;
7479
case CONNECTOR_CREDENTIAL_FIELD:
7580
credential = parser.mapStrings();
@@ -92,7 +97,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
9297
builder.field(INDEX_NAME_FIELD, indexName);
9398
}
9499
if (fieldMapping != null) {
95-
builder.field(TEXT_EMBEDDING_FIELD_MAP_FIELD, fieldMapping);
100+
builder.field(FIELD_MAP_FIELD, fieldMapping);
96101
}
97102
if (dataSources != null) {
98103
builder.field(DATA_SOURCE_FIELD, dataSources);
@@ -109,7 +114,7 @@ public void writeTo(StreamOutput output) throws IOException {
109114
output.writeOptionalString(indexName);
110115
if (fieldMapping != null) {
111116
output.writeBoolean(true);
112-
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeString);
117+
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeGenericValue);
113118
} else {
114119
output.writeBoolean(false);
115120
}
@@ -132,7 +137,7 @@ public void writeTo(StreamOutput output) throws IOException {
132137
public MLBatchIngestionInput(StreamInput input) throws IOException {
133138
indexName = input.readOptionalString();
134139
if (input.readBoolean()) {
135-
fieldMapping = input.readMap(s -> s.readString(), s -> s.readString());
140+
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
136141
}
137142
if (input.readBoolean()) {
138143
dataSources = input.readMap(s -> s.readString(), s -> s.readString());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
11+
import java.io.IOException;
12+
import java.util.Collections;
13+
import java.util.HashMap;
14+
import java.util.Map;
15+
import java.util.function.Consumer;
16+
17+
import org.junit.Before;
18+
import org.junit.Rule;
19+
import org.junit.Test;
20+
import org.junit.rules.ExpectedException;
21+
import org.opensearch.common.io.stream.BytesStreamOutput;
22+
import org.opensearch.common.settings.Settings;
23+
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
24+
import org.opensearch.common.xcontent.XContentFactory;
25+
import org.opensearch.common.xcontent.XContentType;
26+
import org.opensearch.core.common.io.stream.StreamInput;
27+
import org.opensearch.core.xcontent.NamedXContentRegistry;
28+
import org.opensearch.core.xcontent.ToXContent;
29+
import org.opensearch.core.xcontent.XContentBuilder;
30+
import org.opensearch.core.xcontent.XContentParser;
31+
import org.opensearch.search.SearchModule;
32+
33+
public class MLBatchIngestionInputTests {
34+
35+
private MLBatchIngestionInput mlBatchIngestionInput;
36+
37+
private Map<String, String> dataSource;
38+
39+
@Rule
40+
public final ExpectedException exceptionRule = ExpectedException.none();
41+
42+
private final String expectedInputStr = "{"
43+
+ "\"index_name\":\"test index\","
44+
+ "\"text_embedding_field_map\":{"
45+
+ "\"chapter\":\"chapter_embedding\""
46+
+ "},"
47+
+ "\"data_source\":{"
48+
+ "\"source\":\"s3://samplebucket/output/sampleresults.json.out\","
49+
+ "\"type\":\"s3\""
50+
+ "},"
51+
+ "\"credential\":{"
52+
+ "\"region\":\"test region\""
53+
+ "}"
54+
+ "}";
55+
56+
@Before
57+
public void setUp() {
58+
dataSource = new HashMap<>();
59+
dataSource.put("type", "s3");
60+
dataSource.put("source", "s3://samplebucket/output/sampleresults.json.out");
61+
62+
Map<String, String> credentials = Map.of("region", "test region");
63+
Map<String, String> fieldMapping = Map.of("chapter", "chapter_embedding");
64+
65+
mlBatchIngestionInput = MLBatchIngestionInput
66+
.builder()
67+
.indexName("test index")
68+
.credential(credentials)
69+
.fieldMapping(fieldMapping)
70+
.dataSources(dataSource)
71+
.build();
72+
}
73+
74+
@Test
75+
public void constructorMLBatchIngestionInput_NullName() {
76+
exceptionRule.expect(IllegalArgumentException.class);
77+
exceptionRule.expectMessage("index name for ingestion is null");
78+
79+
MLBatchIngestionInput.builder().indexName(null).dataSources(dataSource).build();
80+
}
81+
82+
@Test
83+
public void constructorMLBatchIngestionInput_NullSource() {
84+
exceptionRule.expect(IllegalArgumentException.class);
85+
exceptionRule.expectMessage("dataSources for ingestion is null");
86+
MLBatchIngestionInput.builder().indexName("test index").dataSources(null).build();
87+
}
88+
89+
@Test
90+
public void testToXContent_FullFields() throws Exception {
91+
XContentBuilder builder = XContentFactory.jsonBuilder();
92+
mlBatchIngestionInput.toXContent(builder, ToXContent.EMPTY_PARAMS);
93+
assertNotNull(builder);
94+
String jsonStr = builder.toString();
95+
assertEquals(expectedInputStr, jsonStr);
96+
}
97+
98+
@Test
99+
public void testParse() throws Exception {
100+
testParseFromJsonString(expectedInputStr, parsedInput -> {
101+
assertEquals("test index", parsedInput.getIndexName());
102+
assertEquals("test region", parsedInput.getCredential().get("region"));
103+
assertEquals("chapter_embedding", parsedInput.getFieldMapping().get("chapter"));
104+
assertEquals("s3", parsedInput.getDataSources().get("type"));
105+
});
106+
}
107+
108+
private void testParseFromJsonString(String expectedInputString, Consumer<MLBatchIngestionInput> verify) throws Exception {
109+
XContentParser parser = XContentType.JSON
110+
.xContent()
111+
.createParser(
112+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
113+
LoggingDeprecationHandler.INSTANCE,
114+
expectedInputString
115+
);
116+
parser.nextToken();
117+
MLBatchIngestionInput parsedInput = MLBatchIngestionInput.parse(parser);
118+
verify.accept(parsedInput);
119+
}
120+
121+
@Test
122+
public void readInputStream_Success() throws IOException {
123+
readInputStream(
124+
mlBatchIngestionInput,
125+
parsedInput -> assertEquals(mlBatchIngestionInput.getIndexName(), parsedInput.getIndexName())
126+
);
127+
}
128+
129+
private void readInputStream(MLBatchIngestionInput input, Consumer<MLBatchIngestionInput> verify) throws IOException {
130+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
131+
input.writeTo(bytesStreamOutput);
132+
StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
133+
MLBatchIngestionInput parsedInput = new MLBatchIngestionInput(streamInput);
134+
verify.accept(parsedInput);
135+
}
136+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
public class MLBatchIngestionRequestTests {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
public class MLBatchIngestionResponseTests {}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/openAIDataIngestion.java ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.engine.ingest;
27

38
import java.io.BufferedReader;
@@ -30,14 +35,14 @@
3035

3136
@Log4j2
3237
@Ingester("openai")
33-
public class openAIDataIngestion implements Ingestable {
38+
public class OpenAIDataIngestion implements Ingestable {
3439
private static final String API_KEY = "openAI_key";
3540
private static final String API_URL = "https://api.openai.com/v1/files/";
3641

3742
public static final String SOURCE = "source";
3843
private final Client client;
3944

40-
public openAIDataIngestion(Client client) {
45+
public OpenAIDataIngestion(Client client) {
4146
this.client = client;
4247
}
4348

@@ -145,14 +150,14 @@ private void batchIngest(
145150
client.bulk(bulkRequest, bulkResponseListener);
146151
}
147152

148-
private Map<String, Object> processFieldMapping(String customId, JSONArray dataArray, Map<String, String> fieldMapping) {
153+
private Map<String, Object> processFieldMapping(String customId, JSONArray dataArray, Map<String, Object> fieldMapping) {
149154
Map<String, Object> jsonMap = new HashMap<>();
150155
if (dataArray.length() == fieldMapping.size()) {
151156
int index = 0;
152-
for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
157+
for (Map.Entry<String, Object> mapping : fieldMapping.entrySet()) {
153158
// key is the field name for input String, value is the field name for embedded output
154159
JSONObject dataItem = dataArray.getJSONObject(index);
155-
jsonMap.put(mapping.getValue(), dataItem.getJSONArray("embedding"));
160+
jsonMap.put((String) mapping.getValue(), dataItem.getJSONArray("embedding"));
156161
index++;
157162
}
158163
jsonMap.put("id", customId);

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+40-17
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
1010
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
1111
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
12-
import static org.opensearch.ml.common.utils.StringUtils.fromJson;
1312

1413
import java.io.BufferedReader;
1514
import java.io.InputStreamReader;
@@ -18,6 +17,7 @@
1817
import java.security.PrivilegedActionException;
1918
import java.security.PrivilegedExceptionAction;
2019
import java.util.ArrayList;
20+
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.concurrent.CompletableFuture;
@@ -33,6 +33,8 @@
3333
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
3434
import org.opensearch.ml.engine.annotation.Ingester;
3535

36+
import com.jayway.jsonpath.JsonPath;
37+
3638
import lombok.extern.log4j.Log4j2;
3739
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
3840
import software.amazon.awssdk.auth.credentials.AwsCredentials;
@@ -49,6 +51,11 @@
4951
@Ingester("s3")
5052
public class S3DataIngestion implements Ingestable {
5153
public static final String SOURCE = "source";
54+
public static final String OUTPUT = "output";
55+
public static final String INPUT = "input";
56+
public static final String OUTPUTIELDS = "output_fields";
57+
public static final String INPUTFIELDS = "input_fields";
58+
public static final String INGESTFIELDS = "ingest_fields";
5259
private final Client client;
5360

5461
public S3DataIngestion(Client client) {
@@ -154,31 +161,40 @@ private void batchIngest(
154161
) {
155162
BulkRequest bulkRequest = new BulkRequest();
156163
sourceLines.stream().forEach(jsonStr -> {
157-
Map<String, Object> jsonMap = fromJson(jsonStr, "SageMakerOutput");
158-
processFieldMapping(jsonMap, mlBatchIngestionInput.getFieldMapping());
164+
// Map<String, Object> jsonMap = fromJson(jsonStr, outputFieldName);
165+
Map<String, Object> jsonMap = processFieldMapping(jsonStr, mlBatchIngestionInput.getFieldMapping());
159166
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
160167

161168
bulkRequest.add(indexRequest);
162169
});
163170
client.bulk(bulkRequest, bulkResponseListener);
164171
}
165172

166-
private void processFieldMapping(Map<String, Object> jsonMap, Map<String, String> fieldMapping) {
167-
List<List> smOutput = (List<List>) jsonMap.get("SageMakerOutput");
168-
List<String> smInput = (List<String>) jsonMap.get("content");
169-
if (smInput.size() == smOutput.size() && smInput.size() == fieldMapping.size()) {
170-
int index = 0;
171-
for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
172-
// key is the field name for input String, value is the field name for embedded output
173-
jsonMap.put(mapping.getKey(), smInput.get(index));
174-
jsonMap.put(mapping.getValue(), smOutput.get(index));
175-
index++;
176-
}
177-
jsonMap.remove("content");
178-
jsonMap.remove("SageMakerOutput");
179-
} else {
173+
private Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
174+
String outputJsonPath = (String) fieldMapping.get(OUTPUT);
175+
String inputJsonPath = (String) fieldMapping.get(INPUT);
176+
List<List> smOutput = (List<List>) JsonPath.read(jsonStr, outputJsonPath);
177+
List<String> smInput = (List<String>) JsonPath.read(jsonStr, inputJsonPath);
178+
List<String> inputFields = (List<String>) fieldMapping.get(INPUTFIELDS);
179+
List<String> outputFields = (List<String>) fieldMapping.get(OUTPUTIELDS);
180+
List<String> ingestFieldsJsonPath = (List<String>) fieldMapping.get(INGESTFIELDS);
181+
182+
if (smInput.size() != smOutput.size() || inputFields.size() != outputFields.size() || smInput.size() != inputFields.size()) {
180183
throw new IllegalArgumentException("the fieldMapping and source data do not match");
181184
}
185+
Map<String, Object> jsonMap = new HashMap<>();
186+
187+
for (int index = 0; index < smInput.size(); index++) {
188+
jsonMap.put(inputFields.get(index), smInput.get(index));
189+
jsonMap.put(outputFields.get(index), smOutput.get(index));
190+
}
191+
jsonMap.remove(obtainFieldNameFromJsonPath(inputJsonPath));
192+
jsonMap.remove(obtainFieldNameFromJsonPath(outputJsonPath));
193+
194+
for (String fieldPath : ingestFieldsJsonPath) {
195+
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
196+
}
197+
return jsonMap;
182198
}
183199

184200
private String getS3BucketName(String s3Uri) {
@@ -230,4 +246,11 @@ private S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
230246
throw new RuntimeException("Can't load credentials", e);
231247
}
232248
}
249+
250+
private String obtainFieldNameFromJsonPath(String jsonPath) {
251+
String[] parts = jsonPath.split("\\.");
252+
253+
// Get the last part which is the field name
254+
return parts[parts.length - 1];
255+
}
233256
}

0 commit comments

Comments
 (0)