Skip to content

Commit 85df45a

Browse files
committed
add batch predict job actiontype in connector
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 865a424 commit 85df45a

File tree

5 files changed

+38
-4
lines changed

5 files changed

+38
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package org.opensearch.ml.common;
2+
3+
import java.util.Locale;
4+
5+
public enum PredictMode {
6+
REAL_TIME,
7+
BATCH,
8+
ASYNC,
9+
STREAMING;
10+
11+
public static PredictMode from(String value) {
12+
try {
13+
return PredictMode.valueOf(value.toUpperCase(Locale.ROOT));
14+
} catch (Exception e) {
15+
throw new IllegalArgumentException("Wrong Predict mode");
16+
}
17+
}
18+
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
182182
}
183183

184184
public enum ActionType {
185-
PREDICT
185+
PREDICT,
186+
BATCH_PREDICT_JOB
186187
}
187188
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import lombok.Setter;
1111
import org.opensearch.core.common.io.stream.StreamInput;
1212
import org.opensearch.core.common.io.stream.StreamOutput;
13+
import org.opensearch.ml.common.PredictMode;
1314
import org.opensearch.ml.common.annotation.InputDataSet;
1415
import org.opensearch.ml.common.dataset.MLInputDataType;
1516
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -23,11 +24,17 @@ public class RemoteInferenceInputDataSet extends MLInputDataset {
2324

2425
@Setter
2526
private Map<String, String> parameters;
27+
private PredictMode predictMode;
2628

2729
@Builder(toBuilder = true)
28-
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
30+
public RemoteInferenceInputDataSet(Map<String, String> parameters, PredictMode predictMode) {
2931
super(MLInputDataType.REMOTE);
3032
this.parameters = parameters;
33+
this.predictMode = predictMode;
34+
}
35+
36+
public RemoteInferenceInputDataSet(Map<String, String> parameters) {
37+
this(parameters, null);
3138
}
3239

3340
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {

common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.core.common.io.stream.StreamOutput;
1010
import org.opensearch.core.xcontent.XContentParser;
1111
import org.opensearch.ml.common.FunctionName;
12+
import org.opensearch.ml.common.PredictMode;
1213
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1314
import org.opensearch.ml.common.input.MLInput;
1415
import org.opensearch.ml.common.utils.StringUtils;
@@ -21,6 +22,7 @@
2122
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
2223
public class RemoteInferenceMLInput extends MLInput {
2324
public static final String PARAMETERS_FIELD = "parameters";
25+
public static final String PREDICT_MODE_FIELD = "mode";
2426

2527
public RemoteInferenceMLInput(StreamInput in) throws IOException {
2628
super(in);
@@ -34,21 +36,26 @@ public void writeTo(StreamOutput out) throws IOException {
3436
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
3537
super();
3638
this.algorithm = functionName;
39+
Map<String, String> parameters = null;
40+
PredictMode predictMode = null;
3741
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
3842
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
3943
String fieldName = parser.currentName();
4044
parser.nextToken();
4145

4246
switch (fieldName) {
4347
case PARAMETERS_FIELD:
44-
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
45-
inputDataset = new RemoteInferenceInputDataSet(parameters);
48+
parameters = StringUtils.getParameterMap(parser.map());
4649
break;
50+
case PREDICT_MODE_FIELD:
51+
predictMode = PredictMode.from(parser.text());
4752
default:
4853
parser.skipChildren();
4954
break;
5055
}
5156
}
57+
predictMode = predictMode == null? PredictMode.REAL_TIME:predictMode;
58+
inputDataset = new RemoteInferenceInputDataSet(parameters, predictMode);
5259
}
5360

5461
}

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ private MLCommonsSettings() {}
132132
ImmutableList
133133
.of(
134134
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
135+
"^https://api\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
135136
"^https://api\\.openai\\.com/.*$",
136137
"^https://api\\.cohere\\.ai/.*$",
137138
"^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$"

0 commit comments

Comments
 (0)