Skip to content

Commit 6abfd54

Browse files
committed
add bwx for actiontype
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent b98fe08 commit 6abfd54

File tree

5 files changed

+35
-15
lines changed

5 files changed

+35
-15
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1
Original file line numberDiff line numberDiff line change
@@ -537,4 +537,5 @@ public class CommonValue {
537537
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
538538
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
539539
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
540+
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
540541
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ public static ActionType from(String value) {
193193
try {
194194
return ActionType.valueOf(value.toUpperCase(Locale.ROOT));
195195
} catch (Exception e) {
196-
throw new IllegalArgumentException("Wrong Action Type");
196+
throw new IllegalArgumentException("Wrong Action Type of " + value);
197197
}
198198
}
199199

@@ -205,5 +205,14 @@ public static ActionType from(String value) {
205205
public static boolean isValidActionInModelPrediction(ActionType actionType) {
206206
return MODEL_SUPPORT_ACTIONS.contains(actionType);
207207
}
208+
209+
public static boolean isValidAction(String action) {
210+
try {
211+
ActionType.valueOf(action.toUpperCase());
212+
return true;
213+
} catch (IllegalArgumentException e) {
214+
return false;
215+
}
216+
}
208217
}
209218
}

common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java

+18-10
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import lombok.Builder;
99
import lombok.Getter;
1010
import lombok.Setter;
11+
import org.opensearch.Version;
1112
import org.opensearch.core.common.io.stream.StreamInput;
1213
import org.opensearch.core.common.io.stream.StreamOutput;
14+
import org.opensearch.ml.common.CommonValue;
1315
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
1416
import org.opensearch.ml.common.annotation.InputDataSet;
1517
import org.opensearch.ml.common.dataset.MLInputDataType;
@@ -21,7 +23,7 @@
2123
@Getter
2224
@InputDataSet(MLInputDataType.REMOTE)
2325
public class RemoteInferenceInputDataSet extends MLInputDataset {
24-
26+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
2527
@Setter
2628
private Map<String, String> parameters;
2729
@Setter
@@ -40,30 +42,36 @@ public RemoteInferenceInputDataSet(Map<String, String> parameters) {
4042

4143
public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
4244
super(MLInputDataType.REMOTE);
45+
Version streamInputVersion = streamInput.getVersion();
4346
if (streamInput.readBoolean()) {
4447
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
4548
}
46-
if (streamInput.readBoolean()) {
47-
actionType = streamInput.readEnum(ActionType.class);
48-
} else {
49-
this.actionType = null;
49+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
50+
if (streamInput.readBoolean()) {
51+
actionType = streamInput.readEnum(ActionType.class);
52+
} else {
53+
this.actionType = null;
54+
}
5055
}
5156
}
5257

5358
@Override
5459
public void writeTo(StreamOutput streamOutput) throws IOException {
5560
super.writeTo(streamOutput);
61+
Version streamOutputVersion = streamOutput.getVersion();
5662
if (parameters != null) {
5763
streamOutput.writeBoolean(true);
5864
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
5965
} else {
6066
streamOutput.writeBoolean(false);
6167
}
62-
if (actionType != null) {
63-
streamOutput.writeBoolean(true);
64-
streamOutput.writeEnum(actionType);
65-
} else {
66-
streamOutput.writeBoolean(false);
68+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
69+
if (actionType != null) {
70+
streamOutput.writeBoolean(true);
71+
streamOutput.writeEnum(actionType);
72+
} else {
73+
streamOutput.writeBoolean(false);
74+
}
6775
}
6876
}
6977

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
import java.util.Arrays;
2020
import java.util.UUID;
2121

22-
import javax.swing.*;
23-
2422
import org.opensearch.OpenSearchException;
2523
import org.opensearch.OpenSearchStatusException;
2624
import org.opensearch.ResourceNotFoundException;

plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.opensearch.core.common.Strings;
4040
import org.opensearch.core.rest.RestStatus;
4141
import org.opensearch.index.IndexNotFoundException;
42+
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
4243
import org.opensearch.rest.BytesRestResponse;
4344
import org.opensearch.rest.RestChannel;
4445
import org.opensearch.rest.RestRequest;
@@ -313,8 +314,11 @@ public static String getActionTypeFromRestRequest(RestRequest request) {
313314
String path = request.path();
314315
String[] segments = path.split("/");
315316
String methodName = segments[segments.length - 1];
316-
if (methodName.contains("_")) {
317-
methodName = methodName.split("_")[1];
317+
methodName = methodName.contains("_") ? methodName.split("_")[1] : methodName;
318+
319+
// find the action type for "/_plugins/_ml/_predict/<algorithm>/<model_id>"
320+
if (!ActionType.isValidAction(methodName) && segments.length > 3) {
321+
methodName = segments[3];
318322
}
319323
return methodName;
320324
}

0 commit comments

Comments
 (0)