Skip to content

Commit 4d7d36f

Browse files
committed
validate input
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 5cbb166 commit 4d7d36f

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.apache.commons.text.StringSubstitutor;
2828
import org.opensearch.OpenSearchParseException;
2929
import org.opensearch.OpenSearchStatusException;
30+
import org.opensearch.core.xcontent.XContent;
31+
import org.opensearch.core.xcontent.XContentBuilder;
3032
import org.opensearch.ml.common.connector.Connector;
3133
import org.opensearch.ml.common.connector.ConnectorAction;
3234
import org.opensearch.ml.common.connector.MLPostProcessFunction;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
99
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
10+
11+
import java.io.IOException;
1012
import java.util.ArrayList;
1113
import java.util.HashMap;
1214
import java.util.List;
@@ -16,10 +18,13 @@
1618
import org.opensearch.client.Client;
1719
import org.opensearch.cluster.service.ClusterService;
1820
import org.opensearch.common.util.TokenBucket;
21+
import org.opensearch.common.xcontent.XContentFactory;
1922
import org.opensearch.commons.ConfigConstants;
2023
import org.opensearch.commons.authuser.User;
2124
import org.opensearch.core.rest.RestStatus;
2225
import org.opensearch.core.xcontent.NamedXContentRegistry;
26+
import org.opensearch.core.xcontent.ToXContent;
27+
import org.opensearch.core.xcontent.XContentBuilder;
2328
import org.opensearch.ml.common.FunctionName;
2429
import org.opensearch.ml.common.connector.Connector;
2530
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -33,7 +38,7 @@
3338

3439
public interface RemoteConnectorExecutor {
3540

36-
default ModelTensorOutput executePredict(MLInput mlInput) {
41+
default ModelTensorOutput executePredict(MLInput mlInput) throws IOException {
3742
List<ModelTensors> tensorOutputs = new ArrayList<>();
3843
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
3944
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
@@ -151,10 +156,10 @@ && getUserRateLimiterMap().get(user.getName()) != null
151156

152157
void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs);
153158

154-
private void validateInputSchema(MLInput mlInput) {
159+
private void validateInputSchema(MLInput mlInput) throws IOException {
155160
if (getConnector().getModelInterface() != null && getConnector().getModelInterface().get("input") != null) {
156161
String schemaString = getConnector().getModelInterface().get("input");
157-
ConnectorUtils.validateSchema(schemaString, mlInput.toString());
162+
ConnectorUtils.validateSchema(schemaString, mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
158163
}
159164
}
160165

0 commit comments

Comments
 (0)