|
7 | 7 |
|
8 | 8 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
|
9 | 9 | import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
|
| 10 | + |
| 11 | +import java.io.IOException; |
10 | 12 | import java.util.ArrayList;
|
11 | 13 | import java.util.HashMap;
|
12 | 14 | import java.util.List;
|
|
16 | 18 | import org.opensearch.client.Client;
|
17 | 19 | import org.opensearch.cluster.service.ClusterService;
|
18 | 20 | import org.opensearch.common.util.TokenBucket;
|
| 21 | +import org.opensearch.common.xcontent.XContentFactory; |
19 | 22 | import org.opensearch.commons.ConfigConstants;
|
20 | 23 | import org.opensearch.commons.authuser.User;
|
21 | 24 | import org.opensearch.core.rest.RestStatus;
|
22 | 25 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
| 26 | +import org.opensearch.core.xcontent.ToXContent; |
| 27 | +import org.opensearch.core.xcontent.XContentBuilder; |
23 | 28 | import org.opensearch.ml.common.FunctionName;
|
24 | 29 | import org.opensearch.ml.common.connector.Connector;
|
25 | 30 | import org.opensearch.ml.common.dataset.MLInputDataset;
|
|
33 | 38 |
|
34 | 39 | public interface RemoteConnectorExecutor {
|
35 | 40 |
|
36 |
| - default ModelTensorOutput executePredict(MLInput mlInput) { |
| 41 | + default ModelTensorOutput executePredict(MLInput mlInput) throws IOException { |
37 | 42 | List<ModelTensors> tensorOutputs = new ArrayList<>();
|
38 | 43 | if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
|
39 | 44 | TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
|
@@ -151,10 +156,10 @@ && getUserRateLimiterMap().get(user.getName()) != null
|
151 | 156 |
|
152 | 157 | void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs);
|
153 | 158 |
|
154 |
| - private void validateInputSchema(MLInput mlInput) { |
| 159 | + private void validateInputSchema(MLInput mlInput) throws IOException { |
155 | 160 | if (getConnector().getModelInterface() != null && getConnector().getModelInterface().get("input") != null) {
|
156 | 161 | String schemaString = getConnector().getModelInterface().get("input");
|
157 |
| - ConnectorUtils.validateSchema(schemaString, mlInput.toString()); |
| 162 | + ConnectorUtils.validateSchema(schemaString, mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); |
158 | 163 | }
|
159 | 164 | }
|
160 | 165 |
|
|
0 commit comments