Skip to content

Commit 39efbe7

Browse files
authored
Add Batch Prediction Mode in the Connector Framework for batch inference (opensearch-project#2661)
* add batch predict job actiontype in connector Signed-off-by: Xun Zhang <xunzh@amazon.com> * remove async and streaming mode temporarily Signed-off-by: Xun Zhang <xunzh@amazon.com> * rename predict mode to action type Signed-off-by: Xun Zhang <xunzh@amazon.com> * use method name in the url path for action type Signed-off-by: Xun Zhang <xunzh@amazon.com> * add stats for actionType and more UTs Signed-off-by: Xun Zhang <xunzh@amazon.com> * add bwx for actiontype Signed-off-by: Xun Zhang <xunzh@amazon.com> * address more comments Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent b980199 commit 39efbe7

File tree

17 files changed

+317
-27
lines changed

17 files changed

+317
-27
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -538,4 +538,5 @@ public class CommonValue {
538538
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
539539
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
540540
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
541+
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
541542
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+31-1
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
import org.opensearch.core.xcontent.ToXContentObject;
1515
import org.opensearch.core.xcontent.XContentBuilder;
1616
import org.opensearch.core.xcontent.XContentParser;
17+
import org.opensearch.ml.common.FunctionName;
1718

1819
import java.io.IOException;
20+
import java.util.HashSet;
1921
import java.util.Locale;
2022
import java.util.Map;
23+
import java.util.Set;
2124

2225
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
2326

@@ -183,6 +186,33 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
183186

184187
public enum ActionType {
185188
PREDICT,
186-
EXECUTE
189+
EXECUTE,
190+
BATCH_PREDICT;
191+
192+
public static ActionType from(String value) {
193+
try {
194+
return ActionType.valueOf(value.toUpperCase(Locale.ROOT));
195+
} catch (Exception e) {
196+
throw new IllegalArgumentException("Wrong Action Type of " + value);
197+
}
198+
}
199+
200+
private static final HashSet<ActionType> MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of(
201+
PREDICT,
202+
BATCH_PREDICT
203+
));
204+
205+
public static boolean isValidActionInModelPrediction(ActionType actionType) {
206+
return MODEL_SUPPORT_ACTIONS.contains(actionType);
207+
}
208+
209+
public static boolean isValidAction(String action) {
210+
try {
211+
ActionType.valueOf(action.toUpperCase());
212+
return true;
213+
} catch (IllegalArgumentException e) {
214+
return false;
215+
}
216+
}
187217
}
188218
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+29-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
import lombok.Builder;
99
import lombok.Getter;
1010
import lombok.Setter;
11+
import org.opensearch.Version;
1112
import org.opensearch.core.common.io.stream.StreamInput;
1213
import org.opensearch.core.common.io.stream.StreamOutput;
14+
import org.opensearch.ml.common.CommonValue;
15+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1316
import org.opensearch.ml.common.annotation.InputDataSet;
1417
import org.opensearch.ml.common.dataset.MLInputDataType;
1518
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -20,32 +23,56 @@
2023
@Getter
2124
@InputDataSet(MLInputDataType.REMOTE)
2225
public class RemoteInferenceInputDataSet extends MLInputDataset {
23-
26+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
2427
@Setter
2528
private Map<String, String> parameters;
29+
@Setter
30+
private ActionType actionType;
2631

2732
@Builder(toBuilder = true)
28-
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
33+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
2934
super(MLInputDataType.REMOTE);
3035
this.parameters = parameters;
36+
this.actionType = actionType;
37+
}
38+
39+
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
40+
this(parameters, null);
3141
}
3242

3343
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
3444
super(MLInputDataType.REMOTE);
45+
Version streamInputVersion = streamInput.getVersion();
3546
if (streamInput.readBoolean()) {
3647
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
3748
}
49+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
50+
if (streamInput.readBoolean()) {
51+
actionType = streamInput.readEnum(ActionType.class);
52+
} else {
53+
this.actionType = null;
54+
}
55+
}
3856
}
3957

4058
@Override
4159
public void writeTo(StreamOutput streamOutput) throws IOException {
4260
super.writeTo(streamOutput);
61+
Version streamOutputVersion = streamOutput.getVersion();
4362
if (parameters != null) {
4463
streamOutput.writeBoolean(true);
4564
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
4665
} else {
4766
streamOutput.writeBoolean(false);
4867
}
68+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
69+
if (actionType != null) {
70+
streamOutput.writeBoolean(true);
71+
streamOutput.writeEnum(actionType);
72+
} else {
73+
streamOutput.writeBoolean(false);
74+
}
75+
}
4976
}
5077

5178
}

common/src/main/java/org/opensearch/ml/common/input/MLInput.java

+14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.core.common.io.stream.StreamOutput;
1414
import org.opensearch.core.xcontent.XContentBuilder;
1515
import org.opensearch.core.xcontent.XContentParser;
16+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1617
import org.opensearch.ml.common.MLCommonsClassLoader;
1718
import org.opensearch.ml.common.dataframe.DataFrame;
1819
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
@@ -35,6 +36,7 @@
3536
import java.util.Map;
3637

3738
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
39+
import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD;
3840

3941
/**
4042
* ML input data: algorithm name, parameters and input data set.
@@ -196,6 +198,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196198
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset;
197199
Map<String, String> parameters = remoteInferenceInputDataSet.getParameters();
198200
builder.field(PARAMETERS_FIELD, parameters);
201+
builder.field(ACTION_TYPE_FIELD, remoteInferenceInputDataSet.getActionType());
199202
break;
200203
default:
201204
break;
@@ -206,6 +209,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
206209
return builder;
207210
}
208211

212+
public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException {
213+
MLInput mlInput = parse(parser, inputAlgoName);
214+
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
215+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet)mlInput.getInputDataset();
216+
if (remoteInferenceInputDataSet.getActionType() == null) {
217+
remoteInferenceInputDataSet.setActionType(actionType);
218+
}
219+
}
220+
return mlInput;
221+
}
222+
209223
public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
210224
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
211225
FunctionName algorithm = FunctionName.from(algorithmName);

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.core.common.io.stream.StreamOutput;
1010
import org.opensearch.core.xcontent.XContentParser;
1111
import org.opensearch.ml.common.FunctionName;
12+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1213
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1314
import org.opensearch.ml.common.input.MLInput;
1415
import org.opensearch.ml.common.utils.StringUtils;
@@ -21,6 +22,7 @@
2122
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
2223
public class RemoteInferenceMLInput extends MLInput {
2324
public static final String PARAMETERS_FIELD = "parameters";
25+
public static final String ACTION_TYPE_FIELD = "action_type";
2426

2527
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2628
super(in);
@@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
3436
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
3537
super();
3638
this.algorithm = functionName;
39+
Map<String, String> parameters = null;
40+
ActionType actionType = null;
3741
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
3842
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
3943
String fieldName = parser.currentName();
4044
parser.nextToken();
4145

4246
switch (fieldName) {
4347
case PARAMETERS_FIELD:
44-
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
45-
inputDataset = new RemoteInferenceInputDataSet(parameters);
48+
parameters = StringUtils.getParameterMap(parser.map());
49+
break;
50+
case ACTION_TYPE_FIELD:
51+
actionType = ActionType.from(parser.text());
4652
break;
4753
default:
4854
parser.skipChildren();
4955
break;
5056
}
5157
}
58+
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
5259
}
5360

5461
}

common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java

+15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.HashMap;
2525
import java.util.Map;
2626

27+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;
28+
2729
public class ConnectorActionTest {
2830
@Rule
2931
public ExpectedException exceptionRule = ExpectedException.none();
@@ -140,4 +142,17 @@ public void parse() throws IOException {
140142
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
141143
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
142144
}
145+
146+
@Test
147+
public void test_wrongActionType() {
148+
exceptionRule.expect(IllegalArgumentException.class);
149+
exceptionRule.expectMessage("Wrong Action Type");
150+
ConnectorAction.ActionType.from("badAction");
151+
}
152+
153+
@Test
154+
public void test_invalidActionInModelPrediction() {
155+
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
156+
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
157+
}
143158
}

common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java

+21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.junit.Test;
55
import org.opensearch.common.io.stream.BytesStreamOutput;
66
import org.opensearch.core.common.io.stream.StreamInput;
7+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
78
import org.opensearch.ml.common.dataset.MLInputDataset;
89

910
import java.io.IOException;
@@ -45,4 +46,24 @@ public void writeTo() throws IOException {
4546
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
4647
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
4748
}
49+
50+
@Test
51+
public void writeTo_withActionType() throws IOException {
52+
Map<String, String> parameters = new HashMap<>();
53+
parameters.put("key1", "test value1");
54+
parameters.put("key2", "test value2");
55+
ActionType actionType = ActionType.from("predict");
56+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build();
57+
58+
BytesStreamOutput output = new BytesStreamOutput();
59+
inputDataSet.writeTo(output);
60+
StreamInput streamInput = output.bytes().streamInput();
61+
62+
RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
63+
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
64+
Assert.assertEquals(2, inputDataSet2.getParameters().size());
65+
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
66+
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
67+
Assert.assertEquals("PREDICT", inputDataSet2.getActionType().toString());
68+
}
4869
}

common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

+56
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
import org.opensearch.core.xcontent.XContentParser;
2222
import org.opensearch.index.query.MatchAllQueryBuilder;
2323
import org.opensearch.ml.common.FunctionName;
24+
import org.opensearch.ml.common.connector.ConnectorAction;
2425
import org.opensearch.ml.common.dataframe.*;
2526
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
2627
import org.opensearch.ml.common.dataset.MLInputDataset;
2728
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
2829
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
2930
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
31+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
3032
import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput;
3133
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
3234
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -37,7 +39,9 @@
3739
import java.util.ArrayList;
3840
import java.util.Arrays;
3941
import java.util.Collections;
42+
import java.util.HashMap;
4043
import java.util.List;
44+
import java.util.Map;
4145
import java.util.function.Consumer;
4246
import java.util.function.Function;
4347

@@ -160,6 +164,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
160164
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
161165
}
162166

167+
@Test
168+
public void parse_Remote_Model() throws IOException {
169+
Map<String, String> parameters = Map.of("TransformJobName", "new name");
170+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
171+
.parameters(parameters)
172+
.actionType(ConnectorAction.ActionType.PREDICT)
173+
.build();
174+
175+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}";
176+
177+
testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> {
178+
assertNotNull(parsedInput.getInputDataset());
179+
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
180+
assertEquals(ConnectorAction.ActionType.PREDICT, parsedInputDataSet.getActionType());
181+
});
182+
}
183+
184+
@Test
185+
public void parse_Remote_Model_With_ActionType() throws IOException {
186+
Map<String, String> parameters = Map.of("TransformJobName", "new name");
187+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
188+
.parameters(parameters)
189+
.actionType(ConnectorAction.ActionType.BATCH_PREDICT)
190+
.build();
191+
192+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}";
193+
194+
testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> {
195+
assertNotNull(parsedInput.getInputDataset());
196+
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
197+
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType());
198+
});
199+
}
200+
163201
private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
164202
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
165203
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
@@ -178,6 +216,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
178216
verify.accept(parsedInput);
179217
}
180218

219+
private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
220+
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
221+
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
222+
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
223+
assertNotNull(builder);
224+
String jsonStr = builder.toString();
225+
assertEquals(expectedInputStr, jsonStr);
226+
227+
XContentParser parser = XContentType.JSON.xContent()
228+
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
229+
Collections.emptyList()).getNamedXContents()), null, jsonStr);
230+
parser.nextToken();
231+
MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType);
232+
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
233+
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
234+
verify.accept(parsedInput);
235+
}
236+
181237
@Test
182238
public void readInputStream_Success() throws IOException {
183239
readInputStream(input, parsedInput -> {

common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ public void constructor_stream() throws IOException {
3939
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
4040
Assert.assertEquals(1, inputDataSet.getParameters().size());
4141
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
42+
Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString());
4243
}
4344

4445
private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
45-
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
46+
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }";
4647
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
4748
Collections.emptyList()).getNamedXContents()), null, jsonStr);
4849
parser.nextToken();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import org.opensearch.ml.common.FunctionName;
1919
import org.opensearch.ml.common.MLModel;
2020
import org.opensearch.ml.common.connector.Connector;
21+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
22+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
2123
import org.opensearch.ml.common.exception.MLException;
2224
import org.opensearch.ml.common.input.MLInput;
2325
import org.opensearch.ml.common.model.MLGuard;
@@ -70,7 +72,12 @@ public void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionL
7072
return;
7173
}
7274
try {
73-
connectorExecutor.executeAction(PREDICT.name(), mlInput, actionListener);
75+
ActionType actionType = null;
76+
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
77+
actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
78+
}
79+
actionType = actionType == null ? ActionType.PREDICT : actionType;
80+
connectorExecutor.executeAction(actionType.toString(), mlInput, actionListener);
7481
} catch (RuntimeException e) {
7582
log.error("Failed to call remote model.", e);
7683
actionListener.onFailure(e);

0 commit comments

Comments
 (0)