Skip to content

Commit dcded30

Browse files
committed
batch ingest API rest and transport actions
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 2a33c65 commit dcded30

File tree

17 files changed

+877
-5
lines changed

17 files changed

+877
-5
lines changed

common/src/main/java/org/opensearch/ml/common/MLTaskType.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ public enum MLTaskType {
1515
@Deprecated
1616
LOAD_MODEL,
1717
REGISTER_MODEL,
18-
DEPLOY_MODEL
18+
DEPLOY_MODEL,
19+
BATCH_INGEST
1920
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package org.opensearch.ml.common.transport.batch;
2+
3+
import org.opensearch.action.ActionType;
4+
5+
public class MLBatchIngestionAction extends ActionType<MLBatchIngestionResponse> {
6+
public static MLBatchIngestionAction INSTANCE = new MLBatchIngestionAction();
7+
public static final String NAME = "cluster:admin/opensearch/ml/batch_ingestion";
8+
9+
private MLBatchIngestionAction() {
10+
super(NAME, MLBatchIngestionResponse::new);
11+
}
12+
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.batch;
7+
8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.utils.StringUtils.getOrderedMap;
10+
11+
import java.io.IOException;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
15+
import org.opensearch.core.common.io.stream.StreamInput;
16+
import org.opensearch.core.common.io.stream.StreamOutput;
17+
import org.opensearch.core.common.io.stream.Writeable;
18+
import org.opensearch.core.xcontent.ToXContentObject;
19+
import org.opensearch.core.xcontent.XContentBuilder;
20+
import org.opensearch.core.xcontent.XContentParser;
21+
22+
import lombok.Builder;
23+
import lombok.Getter;
24+
25+
/**
26+
* ML batch ingestion data: index, field mapping and input and out files.
27+
*/
28+
public class MLBatchIngestionInput implements ToXContentObject, Writeable {
29+
30+
public static final String INDEX_NAME_FIELD = "index_name";
31+
public static final String TEXT_EMBEDDING_FIELD_MAP_FIELD = "text_embedding_field_map";
32+
public static final String DATA_SOURCE_FIELD = "data_source";
33+
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
34+
@Getter
35+
private String indexName;
36+
@Getter
37+
private Map<String, String> fieldMapping;
38+
@Getter
39+
private Map<String, String> dataSources;
40+
@Getter
41+
private Map<String, String> credential;
42+
43+
@Builder(toBuilder = true)
44+
public MLBatchIngestionInput(
45+
String indexName,
46+
Map<String, String> fieldMapping,
47+
Map<String, String> dataSources,
48+
Map<String, String> credential
49+
) {
50+
this.indexName = indexName;
51+
this.fieldMapping = fieldMapping;
52+
this.dataSources = dataSources;
53+
this.credential = credential;
54+
}
55+
56+
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
57+
String indexName = null;
58+
Map<String, String> fieldMapping = null;
59+
Map<String, String> dataSources = null;
60+
Map<String, String> credential = new HashMap<>();
61+
62+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
63+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
64+
String fieldName = parser.currentName();
65+
parser.nextToken();
66+
67+
switch (fieldName) {
68+
case INDEX_NAME_FIELD:
69+
indexName = parser.text();
70+
break;
71+
case TEXT_EMBEDDING_FIELD_MAP_FIELD:
72+
fieldMapping = getOrderedMap(parser.mapOrdered());
73+
break;
74+
case CONNECTOR_CREDENTIAL_FIELD:
75+
credential = parser.mapStrings();
76+
break;
77+
case DATA_SOURCE_FIELD:
78+
dataSources = parser.mapStrings();
79+
break;
80+
default:
81+
parser.skipChildren();
82+
break;
83+
}
84+
}
85+
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential);
86+
}
87+
88+
@Override
89+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
90+
builder.startObject();
91+
if (indexName != null) {
92+
builder.field(INDEX_NAME_FIELD, indexName);
93+
}
94+
if (fieldMapping != null) {
95+
builder.field(TEXT_EMBEDDING_FIELD_MAP_FIELD, fieldMapping);
96+
}
97+
if (dataSources != null) {
98+
builder.field(DATA_SOURCE_FIELD, dataSources);
99+
}
100+
if (credential != null) {
101+
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
102+
}
103+
builder.endObject();
104+
return builder;
105+
}
106+
107+
@Override
108+
public void writeTo(StreamOutput output) throws IOException {
109+
output.writeOptionalString(indexName);
110+
if (fieldMapping != null) {
111+
output.writeBoolean(true);
112+
output.writeMap(fieldMapping, StreamOutput::writeString, StreamOutput::writeString);
113+
} else {
114+
output.writeBoolean(false);
115+
}
116+
117+
if (dataSources != null) {
118+
output.writeBoolean(true);
119+
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeString);
120+
} else {
121+
output.writeBoolean(false);
122+
}
123+
124+
if (credential != null) {
125+
output.writeBoolean(true);
126+
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
127+
} else {
128+
output.writeBoolean(false);
129+
}
130+
}
131+
132+
public MLBatchIngestionInput(StreamInput input) throws IOException {
133+
indexName = input.readOptionalString();
134+
if (input.readBoolean()) {
135+
fieldMapping = input.readMap(s -> s.readString(), s -> s.readString());
136+
}
137+
if (input.readBoolean()) {
138+
dataSources = input.readMap(s -> s.readString(), s -> s.readString());
139+
}
140+
if (input.readBoolean()) {
141+
credential = input.readMap(s -> s.readString(), s -> s.readString());
142+
}
143+
}
144+
145+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package org.opensearch.ml.common.transport.batch;
2+
3+
import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
21+
22+
@Getter
23+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
24+
@ToString
25+
public class MLBatchIngestionRequest extends ActionRequest {
26+
27+
private MLBatchIngestionInput mlBatchIngestionInput;
28+
29+
@Builder
30+
public MLBatchIngestionRequest(MLBatchIngestionInput mlBatchIngestionInput) {
31+
this.mlBatchIngestionInput = mlBatchIngestionInput;
32+
}
33+
34+
public MLBatchIngestionRequest(StreamInput in) throws IOException {
35+
super(in);
36+
this.mlBatchIngestionInput = new MLBatchIngestionInput(in);
37+
}
38+
39+
@Override
40+
public ActionRequestValidationException validate() {
41+
ActionRequestValidationException exception = null;
42+
if (mlBatchIngestionInput == null) {
43+
exception = addValidationError("ML batch ingestion input can't be null", exception);
44+
}
45+
if (mlBatchIngestionInput.getCredential() == null) {
46+
exception = addValidationError("ML batch ingestion credentials can't be null", exception);
47+
}
48+
if (mlBatchIngestionInput.getDataSources() == null) {
49+
exception = addValidationError("ML batch ingestion data sources can't be null", exception);
50+
}
51+
52+
return exception;
53+
}
54+
55+
public static MLBatchIngestionRequest fromActionRequest(ActionRequest actionRequest) {
56+
if (actionRequest instanceof MLBatchIngestionRequest) {
57+
return (MLBatchIngestionRequest) actionRequest;
58+
}
59+
60+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
61+
actionRequest.writeTo(osso);
62+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
63+
return new MLBatchIngestionRequest(input);
64+
}
65+
} catch (IOException e) {
66+
throw new UncheckedIOException("failed to parse ActionRequest into MLBatchIngestionRequest", e);
67+
}
68+
69+
}
70+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package org.opensearch.ml.common.transport.batch;
2+
3+
import java.io.ByteArrayInputStream;
4+
import java.io.ByteArrayOutputStream;
5+
import java.io.IOException;
6+
import java.io.UncheckedIOException;
7+
8+
import org.opensearch.core.action.ActionResponse;
9+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
10+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
11+
import org.opensearch.core.common.io.stream.StreamInput;
12+
import org.opensearch.core.common.io.stream.StreamOutput;
13+
import org.opensearch.core.xcontent.ToXContent;
14+
import org.opensearch.core.xcontent.ToXContentObject;
15+
import org.opensearch.core.xcontent.XContentBuilder;
16+
import org.opensearch.ml.common.MLTaskType;
17+
18+
public class MLBatchIngestionResponse extends ActionResponse implements ToXContentObject {
19+
public static final String TASK_ID_FIELD = "task_id";
20+
public static final String TASK_TYPE_FIELD = "task_type";
21+
public static final String STATUS_FIELD = "status";
22+
23+
private String taskId;
24+
private MLTaskType taskType;
25+
private String status;
26+
27+
public MLBatchIngestionResponse(StreamInput in) throws IOException {
28+
super(in);
29+
this.taskId = in.readString();
30+
this.taskType = in.readEnum(MLTaskType.class);
31+
this.status = in.readString();
32+
}
33+
34+
public MLBatchIngestionResponse(String taskId, MLTaskType mlTaskType, String status) {
35+
this.taskId = taskId;
36+
this.taskType = mlTaskType;
37+
this.status = status;
38+
}
39+
40+
@Override
41+
public void writeTo(StreamOutput out) throws IOException {
42+
out.writeString(taskId);
43+
out.writeEnum(taskType);
44+
out.writeString(status);
45+
}
46+
47+
@Override
48+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
49+
builder.startObject();
50+
builder.field(TASK_ID_FIELD, taskId);
51+
if (taskType != null) {
52+
builder.field(TASK_TYPE_FIELD, taskType);
53+
}
54+
builder.field(STATUS_FIELD, status);
55+
builder.endObject();
56+
return builder;
57+
}
58+
59+
public static MLBatchIngestionResponse fromActionResponse(ActionResponse actionResponse) {
60+
if (actionResponse instanceof MLBatchIngestionResponse) {
61+
return (MLBatchIngestionResponse) actionResponse;
62+
}
63+
64+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
65+
actionResponse.writeTo(osso);
66+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
67+
return new MLBatchIngestionResponse(input);
68+
}
69+
} catch (IOException e) {
70+
throw new UncheckedIOException("failed to parse ActionResponse into MLBatchIngestionResponse", e);
71+
}
72+
}
73+
}

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+22
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.ArrayList;
1414
import java.util.HashMap;
1515
import java.util.HashSet;
16+
import java.util.LinkedHashMap;
1617
import java.util.List;
1718
import java.util.Map;
1819
import java.util.Set;
@@ -142,6 +143,27 @@ public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs)
142143
return parameters;
143144
}
144145

146+
@SuppressWarnings("removal")
147+
public static LinkedHashMap<String, String> getOrderedMap(Map<String, ?> parameterObjs) {
148+
LinkedHashMap<String, String> parameters = new LinkedHashMap<>();
149+
for (String key : parameterObjs.keySet()) {
150+
Object value = parameterObjs.get(key);
151+
try {
152+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
153+
if (value instanceof String) {
154+
parameters.put(key, (String) value);
155+
} else {
156+
parameters.put(key, gson.toJson(value));
157+
}
158+
return null;
159+
});
160+
} catch (PrivilegedActionException e) {
161+
throw new RuntimeException(e);
162+
}
163+
}
164+
return parameters;
165+
}
166+
145167
@SuppressWarnings("removal")
146168
public static String toJson(Object value) {
147169
try {

ml-algorithms/build.gradle

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import org.gradle.nativeplatform.platform.internal.DefaultNativePlatform
77

88
plugins {
99
id 'java'
10+
id 'java-library'
1011
id 'jacoco'
1112
id "io.freefair.lombok"
1213
id 'com.diffplug.spotless' version '6.25.0'
@@ -62,9 +63,12 @@ dependencies {
6263
}
6364

6465
implementation platform('software.amazon.awssdk:bom:2.25.40')
65-
implementation 'software.amazon.awssdk:auth'
66+
api 'software.amazon.awssdk:auth:2.25.40'
6667
implementation 'software.amazon.awssdk:apache-client'
6768
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
69+
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.25.40'
70+
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.25.40'
71+
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.25.40'
6872
implementation 'com.jayway.jsonpath:json-path:2.9.0'
6973
implementation group: 'org.json', name: 'json', version: '20231013'
7074
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40'

0 commit comments

Comments
 (0)