Skip to content

Commit c4b1696

Browse files
committed
Add Batch Prediction Mode in the Connector Framework for batch inference (opensearch-project#2661)
* add batch predict job actiontype in connector Signed-off-by: Xun Zhang <xunzh@amazon.com> * remove async and streaming mode temporarily Signed-off-by: Xun Zhang <xunzh@amazon.com> * rename predict mode to action type Signed-off-by: Xun Zhang <xunzh@amazon.com> * use method name in the url path for action type Signed-off-by: Xun Zhang <xunzh@amazon.com> * add stats for actionType and more UTs Signed-off-by: Xun Zhang <xunzh@amazon.com> * add BWC (backward compatibility) for actiontype Signed-off-by: Xun Zhang <xunzh@amazon.com> * address more comments Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 9b2e5f1 commit c4b1696

File tree

19 files changed

+314
-31
lines changed

19 files changed

+314
-31
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -538,4 +538,5 @@ public class CommonValue {
538538
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
539539
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
540540
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
541+
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
541542
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+31-1
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
import org.opensearch.core.xcontent.ToXContentObject;
1515
import org.opensearch.core.xcontent.XContentBuilder;
1616
import org.opensearch.core.xcontent.XContentParser;
17+
import org.opensearch.ml.common.FunctionName;
1718

1819
import java.io.IOException;
20+
import java.util.HashSet;
1921
import java.util.Locale;
2022
import java.util.Map;
23+
import java.util.Set;
2124

2225
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
2326

@@ -183,6 +186,33 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
183186

184187
/**
 * Kinds of action a connector can declare.
 */
public enum ActionType {
    PREDICT,
    EXECUTE,
    BATCH_PREDICT;

    /** Action types accepted by {@link #isValidActionInModelPrediction(ActionType)}. */
    private static final Set<ActionType> MODEL_SUPPORT_ACTIONS = Set.of(PREDICT, BATCH_PREDICT);

    /**
     * Resolves an action type from its name, ignoring case.
     *
     * @param value the action type name, e.g. {@code "predict"} or {@code "BATCH_PREDICT"}
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is {@code null} or matches no constant
     */
    public static ActionType from(String value) {
        if (value == null) {
            throw new IllegalArgumentException("Wrong Action Type of " + value);
        }
        try {
            return ActionType.valueOf(value.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Keep the original failure as the cause so the stack trace stays useful.
            throw new IllegalArgumentException("Wrong Action Type of " + value, e);
        }
    }

    /**
     * Tells whether the given action type is one of {@link #PREDICT} or {@link #BATCH_PREDICT}.
     *
     * @param actionType the action type to check; may be {@code null}
     * @return {@code true} for {@code PREDICT} and {@code BATCH_PREDICT}, {@code false} otherwise
     *         (including {@code null})
     */
    public static boolean isValidActionInModelPrediction(ActionType actionType) {
        // Set.of(...) rejects null lookups, so guard explicitly.
        return actionType != null && MODEL_SUPPORT_ACTIONS.contains(actionType);
    }

    /**
     * Tells whether the given name matches one of the enum constants, ignoring case.
     *
     * @param action the action type name; may be {@code null}
     * @return {@code true} if {@link #from(String)} would succeed for {@code action}
     */
    public static boolean isValidAction(String action) {
        if (action == null) {
            return false;
        }
        try {
            // Locale.ROOT keeps the mapping locale-independent ("i" -> "I" even under tr-TR),
            // matching what from(String) does.
            ActionType.valueOf(action.toUpperCase(Locale.ROOT));
            return true;
        } catch (IllegalArgumentException e) {
            return false;
        }
    }
}
188218
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+29-2
Original file line numberDiff line numberDiff line change
@@ -10,41 +10,68 @@
1010
import lombok.Builder;
1111
import lombok.Getter;
1212
import lombok.Setter;
13+
import org.opensearch.Version;
1314
import org.opensearch.core.common.io.stream.StreamInput;
1415
import org.opensearch.core.common.io.stream.StreamOutput;
16+
import org.opensearch.ml.common.CommonValue;
17+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1518
import org.opensearch.ml.common.annotation.InputDataSet;
1619
import org.opensearch.ml.common.dataset.MLInputDataType;
1720
import org.opensearch.ml.common.dataset.MLInputDataset;
1821

1922
@Getter
2023
@InputDataSet(MLInputDataType.REMOTE)
2124
public class RemoteInferenceInputDataSet extends MLInputDataset {
22-
25+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
2326
@Setter
2427
private Map<String, String> parameters;
28+
@Setter
29+
private ActionType actionType;
2530

2631
@Builder(toBuilder = true)
27-
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
32+
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
2833
super(MLInputDataType.REMOTE);
2934
this.parameters = parameters;
35+
this.actionType = actionType;
36+
}
37+
38+
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
39+
this(parameters, null);
3040
}
3141

3242
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
3343
super(MLInputDataType.REMOTE);
44+
Version streamInputVersion = streamInput.getVersion();
3445
if (streamInput.readBoolean()) {
3546
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
3647
}
48+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
49+
if (streamInput.readBoolean()) {
50+
actionType = streamInput.readEnum(ActionType.class);
51+
} else {
52+
this.actionType = null;
53+
}
54+
}
3755
}
3856

3957
@Override
4058
public void writeTo(StreamOutput streamOutput) throws IOException {
4159
super.writeTo(streamOutput);
60+
Version streamOutputVersion = streamOutput.getVersion();
4261
if (parameters != null) {
4362
streamOutput.writeBoolean(true);
4463
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
4564
} else {
4665
streamOutput.writeBoolean(false);
4766
}
67+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
68+
if (actionType != null) {
69+
streamOutput.writeBoolean(true);
70+
streamOutput.writeEnum(actionType);
71+
} else {
72+
streamOutput.writeBoolean(false);
73+
}
74+
}
4875
}
4976

5077
}

common/src/main/java/org/opensearch/ml/common/input/MLInput.java

+14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.core.common.io.stream.StreamOutput;
1414
import org.opensearch.core.xcontent.XContentBuilder;
1515
import org.opensearch.core.xcontent.XContentParser;
16+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1617
import org.opensearch.ml.common.MLCommonsClassLoader;
1718
import org.opensearch.ml.common.dataframe.DataFrame;
1819
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
@@ -35,6 +36,7 @@
3536
import java.util.Map;
3637

3738
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
39+
import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD;
3840

3941
/**
4042
* ML input data: algorithm name, parameters and input data set.
@@ -196,6 +198,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196198
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset;
197199
Map<String, String> parameters = remoteInferenceInputDataSet.getParameters();
198200
builder.field(PARAMETERS_FIELD, parameters);
201+
builder.field(ACTION_TYPE_FIELD, remoteInferenceInputDataSet.getActionType());
199202
break;
200203
default:
201204
break;
@@ -206,6 +209,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
206209
return builder;
207210
}
208211

212+
public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException {
213+
MLInput mlInput = parse(parser, inputAlgoName);
214+
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
215+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet)mlInput.getInputDataset();
216+
if (remoteInferenceInputDataSet.getActionType() == null) {
217+
remoteInferenceInputDataSet.setActionType(actionType);
218+
}
219+
}
220+
return mlInput;
221+
}
222+
209223
public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
210224
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
211225
FunctionName algorithm = FunctionName.from(algorithmName);

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.core.common.io.stream.StreamOutput;
1010
import org.opensearch.core.xcontent.XContentParser;
1111
import org.opensearch.ml.common.FunctionName;
12+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1213
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1314
import org.opensearch.ml.common.input.MLInput;
1415
import org.opensearch.ml.common.utils.StringUtils;
@@ -21,6 +22,7 @@
2122
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
2223
public class RemoteInferenceMLInput extends MLInput {
2324
public static final String PARAMETERS_FIELD = "parameters";
25+
public static final String ACTION_TYPE_FIELD = "action_type";
2426

2527
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2628
super(in);
@@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
3436
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
3537
super();
3638
this.algorithm = functionName;
39+
Map<String, String> parameters = null;
40+
ActionType actionType = null;
3741
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
3842
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
3943
String fieldName = parser.currentName();
4044
parser.nextToken();
4145

4246
switch (fieldName) {
4347
case PARAMETERS_FIELD:
44-
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
45-
inputDataset = new RemoteInferenceInputDataSet(parameters);
48+
parameters = StringUtils.getParameterMap(parser.map());
49+
break;
50+
case ACTION_TYPE_FIELD:
51+
actionType = ActionType.from(parser.text());
4652
break;
4753
default:
4854
parser.skipChildren();
4955
break;
5056
}
5157
}
58+
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
5259
}
5360

5461
}

common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java

+15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.HashMap;
2525
import java.util.Map;
2626

27+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;
28+
2729
public class ConnectorActionTest {
2830
@Rule
2931
public ExpectedException exceptionRule = ExpectedException.none();
@@ -140,4 +142,17 @@ public void parse() throws IOException {
140142
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
141143
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
142144
}
145+
146+
@Test
147+
public void test_wrongActionType() {
148+
exceptionRule.expect(IllegalArgumentException.class);
149+
exceptionRule.expectMessage("Wrong Action Type");
150+
ConnectorAction.ActionType.from("badAction");
151+
}
152+
153+
@Test
154+
public void test_invalidActionInModelPrediction() {
155+
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
156+
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
157+
}
143158
}

common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java

+21
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.junit.Test;
1010
import org.opensearch.common.io.stream.BytesStreamOutput;
1111
import org.opensearch.core.common.io.stream.StreamInput;
12+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1213
import org.opensearch.ml.common.dataset.MLInputDataset;
1314

1415
public class RemoteInferenceInputDataSetTest {
@@ -44,4 +45,24 @@ public void writeTo() throws IOException {
4445
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
4546
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
4647
}
48+
49+
@Test
50+
public void writeTo_withActionType() throws IOException {
51+
Map<String, String> parameters = new HashMap<>();
52+
parameters.put("key1", "test value1");
53+
parameters.put("key2", "test value2");
54+
ActionType actionType = ActionType.from("predict");
55+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build();
56+
57+
BytesStreamOutput output = new BytesStreamOutput();
58+
inputDataSet.writeTo(output);
59+
StreamInput streamInput = output.bytes().streamInput();
60+
61+
RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
62+
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
63+
Assert.assertEquals(2, inputDataSet2.getParameters().size());
64+
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
65+
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
66+
Assert.assertEquals("PREDICT", inputDataSet2.getActionType().toString());
67+
}
4768
}

common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

+56-1
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@
2727
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
2828
import org.opensearch.ml.common.dataframe.DoubleValue;
2929
import org.opensearch.ml.common.dataframe.Row;
30-
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
3130
import org.opensearch.ml.common.FunctionName;
31+
import org.opensearch.ml.common.connector.ConnectorAction;
32+
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
3233
import org.opensearch.ml.common.dataset.MLInputDataset;
3334
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
3435
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
3536
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
37+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
3638
import org.opensearch.ml.common.input.nlp.TextSimilarityMLInput;
3739
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
3840
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -44,6 +46,7 @@
4446
import java.util.Arrays;
4547
import java.util.Collections;
4648
import java.util.List;
49+
import java.util.Map;
4750
import java.util.function.Consumer;
4851
import java.util.function.Function;
4952

@@ -168,6 +171,40 @@ public void parse_NLPRelated_NullResultFilter() throws IOException {
168171
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
169172
}
170173

174+
@Test
175+
public void parse_Remote_Model() throws IOException {
176+
Map<String, String> parameters = Map.of("TransformJobName", "new name");
177+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
178+
.parameters(parameters)
179+
.actionType(ConnectorAction.ActionType.PREDICT)
180+
.build();
181+
182+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}";
183+
184+
testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> {
185+
assertNotNull(parsedInput.getInputDataset());
186+
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
187+
assertEquals(ConnectorAction.ActionType.PREDICT, parsedInputDataSet.getActionType());
188+
});
189+
}
190+
191+
@Test
192+
public void parse_Remote_Model_With_ActionType() throws IOException {
193+
Map<String, String> parameters = Map.of("TransformJobName", "new name");
194+
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder()
195+
.parameters(parameters)
196+
.actionType(ConnectorAction.ActionType.BATCH_PREDICT)
197+
.build();
198+
199+
String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}";
200+
201+
testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> {
202+
assertNotNull(parsedInput.getInputDataset());
203+
RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset();
204+
assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType());
205+
});
206+
}
207+
171208
private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
172209
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
173210
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
@@ -186,6 +223,24 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri
186223
verify.accept(parsedInput);
187224
}
188225

226+
private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
227+
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
228+
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
229+
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
230+
assertNotNull(builder);
231+
String jsonStr = builder.toString();
232+
assertEquals(expectedInputStr, jsonStr);
233+
234+
XContentParser parser = XContentType.JSON.xContent()
235+
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
236+
Collections.emptyList()).getNamedXContents()), null, jsonStr);
237+
parser.nextToken();
238+
MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType);
239+
assertEquals(input.getFunctionName(), parsedInput.getFunctionName());
240+
assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType());
241+
verify.accept(parsedInput);
242+
}
243+
189244
@Test
190245
public void readInputStream_Success() throws IOException {
191246
readInputStream(input, parsedInput -> {

common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ public void constructor_stream() throws IOException {
3939
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
4040
Assert.assertEquals(1, inputDataSet.getParameters().size());
4141
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
42+
Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString());
4243
}
4344

4445
private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
45-
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
46+
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }";
4647
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
4748
Collections.emptyList()).getNamedXContents()), null, jsonStr);
4849
parser.nextToken();

memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java

-1
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,6 @@ public void testGetSg_NoIndex_ThenFail() {
750750
interactionsIndex.getInteraction("iid", getListener);
751751
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
752752
verify(getListener, times(1)).onFailure(argCaptor.capture());
753-
System.out.println(argCaptor.getValue().getMessage());
754753
assert (argCaptor
755754
.getValue()
756755
.getMessage()

0 commit comments

Comments
 (0)