Skip to content

Commit 685cd31

Browse files
committed
add batch predict job actiontype in connector
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 818391b commit 685cd31

File tree

7 files changed

+48
-5
lines changed

7 files changed

+48
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package org.opensearch.ml.common;
2+
3+
import java.util.Locale;
4+
5+
public enum PredictMode {
6+
PREDICT,
7+
BATCH,
8+
ASYNC,
9+
STREAMING;
10+
11+
public static PredictMode from(String value) {
12+
try {
13+
return PredictMode.valueOf(value.toUpperCase(Locale.ROOT));
14+
} catch (Exception e) {
15+
throw new IllegalArgumentException("Wrong Predict mode");
16+
}
17+
}
18+
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
183183

184184
public enum ActionType {
185185
PREDICT,
186-
EXECUTE
186+
EXECUTE,
187+
BATCH
187188
}
188189
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import lombok.Setter;
1111
import org.opensearch.core.common.io.stream.StreamInput;
1212
import org.opensearch.core.common.io.stream.StreamOutput;
13+
import org.opensearch.ml.common.PredictMode;
1314
import org.opensearch.ml.common.annotation.InputDataSet;
1415
import org.opensearch.ml.common.dataset.MLInputDataType;
1516
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -23,11 +24,17 @@ public class RemoteInferenceInputDataSet extends MLInputDataset {
2324

2425
@Setter
2526
private Map<String, String> parameters;
27+
private PredictMode predictMode;
2628

2729
@Builder(toBuilder = true)
28-
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
30+
public RemoteInferenceInputDataSet(Map<String, String> parameters, PredictMode predictMode) {
2931
super(MLInputDataType.REMOTE);
3032
this.parameters = parameters;
33+
this.predictMode = predictMode;
34+
}
35+
36+
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
37+
this(parameters, null);
3138
}
3239

3340
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {

common/src/main/java/org/opensearch/ml/common/input/MLInput.java

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.Map;
3636

3737
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
38+
import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.PREDICT_MODE_FIELD;
3839

3940
/**
4041
* ML input data: algorithm name, parameters and input data set.
@@ -196,6 +197,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196197
RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset;
197198
Map<String, String> parameters = remoteInferenceInputDataSet.getParameters();
198199
builder.field(PARAMETERS_FIELD, parameters);
200+
builder.field(PREDICT_MODE_FIELD, remoteInferenceInputDataSet.getPredictMode());
199201
break;
200202
default:
201203
break;

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.core.common.io.stream.StreamOutput;
1010
import org.opensearch.core.xcontent.XContentParser;
1111
import org.opensearch.ml.common.FunctionName;
12+
import org.opensearch.ml.common.PredictMode;
1213
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1314
import org.opensearch.ml.common.input.MLInput;
1415
import org.opensearch.ml.common.utils.StringUtils;
@@ -21,6 +22,7 @@
2122
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
2223
public class RemoteInferenceMLInput extends MLInput {
2324
public static final String PARAMETERS_FIELD = "parameters";
25+
public static final String PREDICT_MODE_FIELD = "mode";
2426

2527
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2628
super(in);
@@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
3436
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
3537
super();
3638
this.algorithm = functionName;
39+
Map<String, String> parameters = null;
40+
PredictMode predictMode = null;
3741
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
3842
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
3943
String fieldName = parser.currentName();
4044
parser.nextToken();
4145

4246
switch (fieldName) {
4347
case PARAMETERS_FIELD:
44-
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
45-
inputDataset = new RemoteInferenceInputDataSet(parameters);
48+
parameters = StringUtils.getParameterMap(parser.map());
4649
break;
50+
case PREDICT_MODE_FIELD:
51+
predictMode = PredictMode.from(parser.text());
4752
default:
4853
parser.skipChildren();
4954
break;
5055
}
5156
}
57+
predictMode = predictMode == null? PredictMode.PREDICT:predictMode;
58+
inputDataset = new RemoteInferenceInputDataSet(parameters, predictMode);
5259
}
5360

5461
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import org.opensearch.core.xcontent.NamedXContentRegistry;
1818
import org.opensearch.ml.common.FunctionName;
1919
import org.opensearch.ml.common.MLModel;
20+
import org.opensearch.ml.common.PredictMode;
2021
import org.opensearch.ml.common.connector.Connector;
22+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
2123
import org.opensearch.ml.common.exception.MLException;
2224
import org.opensearch.ml.common.input.MLInput;
2325
import org.opensearch.ml.common.model.MLGuard;
@@ -70,7 +72,12 @@ public void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionL
7072
return;
7173
}
7274
try {
73-
connectorExecutor.executeAction(PREDICT.name(), mlInput, actionListener);
75+
PredictMode predictMode = null;
76+
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
77+
predictMode = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getPredictMode();
78+
}
79+
predictMode = predictMode == null?PredictMode.PREDICT:predictMode;
80+
connectorExecutor.executeAction(predictMode.toString(), mlInput, actionListener);
7481
} catch (RuntimeException e) {
7582
log.error("Failed to call remote model.", e);
7683
actionListener.onFailure(e);

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ private MLCommonsSettings() {}
132132
ImmutableList
133133
.of(
134134
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
135+
"^https://api\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
135136
"^https://api\\.openai\\.com/.*$",
136137
"^https://api\\.cohere\\.ai/.*$",
137138
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"

0 commit comments

Comments
 (0)