Skip to content

Commit 5c574ae

Browse files
committed
update interphase and address comments
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent c8d7b8a commit 5c574ae

File tree

7 files changed

+235
-47
lines changed

7 files changed

+235
-47
lines changed

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+15-10
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.ml.common.transport.batch;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9-
import static org.opensearch.ml.common.utils.StringUtils.getOrderedMap;
109

1110
import java.io.IOException;
1211
import java.util.HashMap;
@@ -28,13 +27,13 @@
2827
public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2928

3029
public static final String INDEX_NAME_FIELD = "index_name";
31-
public static final String TEXT_EMBEDDING_FIELD_MAP_FIELD = "text_embedding_field_map";
30+
public static final String FIELD_MAP_FIELD = "field_map";
3231
public static final String DATA_SOURCE_FIELD = "data_source";
3332
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
3433
@Getter
3534
private String indexName;
3635
@Getter
37-
private Map<String, String> fieldMapping;
36+
private Map<String, Object> fieldMapping;
3837
@Getter
3938
private Map<String, String> dataSources;
4039
@Getter
@@ -43,10 +42,16 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
4342
@Builder(toBuilder = true)
4443
public MLBatchIngestionInput(
4544
String indexName,
46-
Map<String, String> fieldMapping,
45+
Map<String, Object> fieldMapping,
4746
Map<String, String> dataSources,
4847
Map<String, String> credential
4948
) {
49+
if (indexName == null) {
50+
throw new IllegalArgumentException("index name for ingestion is null");
51+
}
52+
if (dataSources == null) {
53+
throw new IllegalArgumentException("dataSources for ingestion is null");
54+
}
5055
this.indexName = indexName;
5156
this.fieldMapping = fieldMapping;
5257
this.dataSources = dataSources;
@@ -55,7 +60,7 @@ public MLBatchIngestionInput(
5560

5661
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
5762
String indexName = null;
58-
Map<String, String> fieldMapping = null;
63+
Map<String, Object> fieldMapping = null;
5964
Map<String, String> dataSources = null;
6065
Map<String, String> credential = new HashMap<>();
6166

@@ -68,8 +73,8 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
6873
case INDEX_NAME_FIELD:
6974
indexName = parser.text();
7075
break;
71-
case TEXT_EMBEDDING_FIELD_MAP_FIELD:
72-
fieldMapping = getOrderedMap(parser.mapOrdered());
76+
case FIELD_MAP_FIELD:
77+
fieldMapping = parser.map();
7378
break;
7479
case CONNECTOR_CREDENTIAL_FIELD:
7580
credential = parser.mapStrings();
@@ -92,7 +97,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
9297
builder.field(INDEX_NAME_FIELD, indexName);
9398
}
9499
if (fieldMapping != null) {
95-
builder.field(TEXT_EMBEDDING_FIELD_MAP_FIELD, fieldMapping);
100+
builder.field(FIELD_MAP_FIELD, fieldMapping);
96101
}
97102
if (dataSources != null) {
98103
builder.field(DATA_SOURCE_FIELD, dataSources);
@@ -109,7 +114,7 @@ public void writeTo(StreamOutput output) throws IOException {
109114
output.writeOptionalString(indexName);
110115
if (fieldMapping != null) {
111116
output.writeBoolean(true);
112-
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeString);
117+
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeGenericValue);
113118
} else {
114119
output.writeBoolean(false);
115120
}
@@ -132,7 +137,7 @@ public void writeTo(StreamOutput output) throws IOException {
132137
public MLBatchIngestionInput(StreamInput input) throws IOException {
133138
indexName = input.readOptionalString();
134139
if (input.readBoolean()) {
135-
fieldMapping = input.readMap(s -> s.readString(), s -> s.readString());
140+
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
136141
}
137142
if (input.readBoolean()) {
138143
dataSources = input.readMap(s -> s.readString(), s -> s.readString());

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+7
Original file line numberDiff line numberDiff line change
@@ -255,4 +255,11 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
255255
return errorMessage + " Model ID: " + modelId;
256256
}
257257
}
258+
259+
public static String obtainFieldNameFromJsonPath(String jsonPath) {
260+
String[] parts = jsonPath.split("\\.");
261+
262+
// Get the last part which is the field name
263+
return parts[parts.length - 1];
264+
}
258265
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
11+
import java.io.IOException;
12+
import java.util.Collections;
13+
import java.util.HashMap;
14+
import java.util.Map;
15+
import java.util.function.Consumer;
16+
17+
import org.junit.Before;
18+
import org.junit.Rule;
19+
import org.junit.Test;
20+
import org.junit.rules.ExpectedException;
21+
import org.opensearch.common.io.stream.BytesStreamOutput;
22+
import org.opensearch.common.settings.Settings;
23+
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
24+
import org.opensearch.common.xcontent.XContentFactory;
25+
import org.opensearch.common.xcontent.XContentType;
26+
import org.opensearch.core.common.io.stream.StreamInput;
27+
import org.opensearch.core.xcontent.NamedXContentRegistry;
28+
import org.opensearch.core.xcontent.ToXContent;
29+
import org.opensearch.core.xcontent.XContentBuilder;
30+
import org.opensearch.core.xcontent.XContentParser;
31+
import org.opensearch.search.SearchModule;
32+
33+
public class MLBatchIngestionInputTests {
34+
35+
private MLBatchIngestionInput mlBatchIngestionInput;
36+
37+
private Map<String, String> dataSource;
38+
39+
@Rule
40+
public final ExpectedException exceptionRule = ExpectedException.none();
41+
42+
private final String expectedInputStr = "{"
43+
+ "\"index_name\":\"test index\","
44+
+ "\"text_embedding_field_map\":{"
45+
+ "\"chapter\":\"chapter_embedding\""
46+
+ "},"
47+
+ "\"data_source\":{"
48+
+ "\"source\":\"s3://samplebucket/output/sampleresults.json.out\","
49+
+ "\"type\":\"s3\""
50+
+ "},"
51+
+ "\"credential\":{"
52+
+ "\"region\":\"test region\""
53+
+ "}"
54+
+ "}";
55+
56+
@Before
57+
public void setUp() {
58+
dataSource = new HashMap<>();
59+
dataSource.put("type", "s3");
60+
dataSource.put("source", "s3://samplebucket/output/sampleresults.json.out");
61+
62+
Map<String, String> credentials = Map.of("region", "test region");
63+
Map<String, String> fieldMapping = Map.of("chapter", "chapter_embedding");
64+
65+
mlBatchIngestionInput = MLBatchIngestionInput
66+
.builder()
67+
.indexName("test index")
68+
.credential(credentials)
69+
.fieldMapping(fieldMapping)
70+
.dataSources(dataSource)
71+
.build();
72+
}
73+
74+
@Test
75+
public void constructorMLBatchIngestionInput_NullName() {
76+
exceptionRule.expect(IllegalArgumentException.class);
77+
exceptionRule.expectMessage("index name for ingestion is null");
78+
79+
MLBatchIngestionInput.builder().indexName(null).dataSources(dataSource).build();
80+
}
81+
82+
@Test
83+
public void constructorMLBatchIngestionInput_NullSource() {
84+
exceptionRule.expect(IllegalArgumentException.class);
85+
exceptionRule.expectMessage("dataSources for ingestion is null");
86+
MLBatchIngestionInput.builder().indexName("test index").dataSources(null).build();
87+
}
88+
89+
@Test
90+
public void testToXContent_FullFields() throws Exception {
91+
XContentBuilder builder = XContentFactory.jsonBuilder();
92+
mlBatchIngestionInput.toXContent(builder, ToXContent.EMPTY_PARAMS);
93+
assertNotNull(builder);
94+
String jsonStr = builder.toString();
95+
assertEquals(expectedInputStr, jsonStr);
96+
}
97+
98+
@Test
99+
public void testParse() throws Exception {
100+
testParseFromJsonString(expectedInputStr, parsedInput -> {
101+
assertEquals("test index", parsedInput.getIndexName());
102+
assertEquals("test region", parsedInput.getCredential().get("region"));
103+
assertEquals("chapter_embedding", parsedInput.getFieldMapping().get("chapter"));
104+
assertEquals("s3", parsedInput.getDataSources().get("type"));
105+
});
106+
}
107+
108+
private void testParseFromJsonString(String expectedInputString, Consumer<MLBatchIngestionInput> verify) throws Exception {
109+
XContentParser parser = XContentType.JSON
110+
.xContent()
111+
.createParser(
112+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
113+
LoggingDeprecationHandler.INSTANCE,
114+
expectedInputString
115+
);
116+
parser.nextToken();
117+
MLBatchIngestionInput parsedInput = MLBatchIngestionInput.parse(parser);
118+
verify.accept(parsedInput);
119+
}
120+
121+
@Test
122+
public void readInputStream_Success() throws IOException {
123+
readInputStream(
124+
mlBatchIngestionInput,
125+
parsedInput -> assertEquals(mlBatchIngestionInput.getIndexName(), parsedInput.getIndexName())
126+
);
127+
}
128+
129+
private void readInputStream(MLBatchIngestionInput input, Consumer<MLBatchIngestionInput> verify) throws IOException {
130+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
131+
input.writeTo(bytesStreamOutput);
132+
StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
133+
MLBatchIngestionInput parsedInput = new MLBatchIngestionInput(streamInput);
134+
verify.accept(parsedInput);
135+
}
136+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
public class MLBatchIngestionRequestTests {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
public class MLBatchIngestionResponseTests {}

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/openAIDataIngestion.java ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/OpenAIDataIngestion.java

+29-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.engine.ingest;
27

38
import java.io.BufferedReader;
@@ -14,8 +19,7 @@
1419
import java.util.concurrent.CompletableFuture;
1520
import java.util.concurrent.atomic.AtomicInteger;
1621

17-
import org.json.JSONArray;
18-
import org.json.JSONObject;
22+
import com.jayway.jsonpath.JsonPath;
1923
import org.opensearch.OpenSearchStatusException;
2024
import org.opensearch.action.bulk.BulkRequest;
2125
import org.opensearch.action.bulk.BulkResponse;
@@ -28,16 +32,21 @@
2832

2933
import lombok.extern.log4j.Log4j2;
3034

35+
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
36+
import static org.opensearch.ml.engine.ingest.S3DataIngestion.INGESTFIELDS;
37+
import static org.opensearch.ml.engine.ingest.S3DataIngestion.OUTPUT;
38+
import static org.opensearch.ml.engine.ingest.S3DataIngestion.OUTPUTIELDS;
39+
3140
@Log4j2
3241
@Ingester("openai")
33-
public class openAIDataIngestion implements Ingestable {
42+
public class OpenAIDataIngestion implements Ingestable {
3443
private static final String API_KEY = "openAI_key";
3544
private static final String API_URL = "https://api.openai.com/v1/files/";
3645

3746
public static final String SOURCE = "source";
3847
private final Client client;
3948

40-
public openAIDataIngestion(Client client) {
49+
public OpenAIDataIngestion(Client client) {
4150
this.client = client;
4251
}
4352

@@ -133,32 +142,32 @@ private void batchIngest(
133142
) {
134143
BulkRequest bulkRequest = new BulkRequest();
135144
sourceLines.stream().forEach(jsonStr -> {
136-
JSONObject jsonObject = new JSONObject(jsonStr);
137-
String customId = jsonObject.getString("custom_id");
138-
JSONObject responseBody = jsonObject.getJSONObject("response").getJSONObject("body");
139-
JSONArray dataArray = responseBody.getJSONArray("data");
140-
Map<String, Object> jsonMap = processFieldMapping(customId, dataArray, mlBatchIngestionInput.getFieldMapping());
145+
Map<String, Object> jsonMap = processFieldMapping(jsonStr, mlBatchIngestionInput.getFieldMapping());
141146
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
142147

143148
bulkRequest.add(indexRequest);
144149
});
145150
client.bulk(bulkRequest, bulkResponseListener);
146151
}
147152

148-
private Map<String, Object> processFieldMapping(String customId, JSONArray dataArray, Map<String, String> fieldMapping) {
153+
private Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
154+
String outputJsonPath = (String) fieldMapping.get(OUTPUT);
155+
List<List> outputs = (List<List>) JsonPath.read(jsonStr, outputJsonPath);
156+
List<String> outputFields = (List<String>) fieldMapping.get(OUTPUTIELDS);
157+
List<String> ingestFieldsJsonPath = (List<String>) fieldMapping.get(INGESTFIELDS);
158+
149159
Map<String, Object> jsonMap = new HashMap<>();
150-
if (dataArray.length() == fieldMapping.size()) {
151-
int index = 0;
152-
for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
153-
// key is the field name for input String, value is the field name for embedded output
154-
JSONObject dataItem = dataArray.getJSONObject(index);
155-
jsonMap.put(mapping.getValue(), dataItem.getJSONArray("embedding"));
156-
index++;
157-
}
158-
jsonMap.put("id", customId);
159-
} else {
160+
if (outputs.size() != outputFields.size()) {
160161
throw new IllegalArgumentException("the fieldMapping and source data do not match");
161162
}
163+
for (int index = 0; index < outputs.size();index++) {
164+
jsonMap.put(outputFields.get(index), outputs.get(index));
165+
}
166+
167+
for (String fieldPath : ingestFieldsJsonPath) {
168+
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
169+
}
170+
162171
return jsonMap;
163172
}
164173
}

0 commit comments

Comments
 (0)