|
45 | 45 | import lombok.experimental.FieldDefaults;
|
46 | 46 | import lombok.extern.log4j.Log4j2;
|
47 | 47 |
|
| 48 | +import java.io.IOException; |
| 49 | + |
48 | 50 | @Log4j2
|
49 | 51 | @FieldDefaults(level = AccessLevel.PRIVATE)
|
50 | 52 | public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
|
@@ -160,8 +162,6 @@ public void onResponse(MLModel mlModel) {
|
160 | 162 | executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
|
161 | 163 | }
|
162 | 164 | } else {
|
163 |
| - log.info("ModelId " + modelId + " is enabled for prediction in remote layer."); |
164 |
| - log.info(mlPredictionTaskRequest.getMlInput().toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); |
165 | 165 | validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput());
|
166 | 166 | executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
|
167 | 167 | }
|
@@ -236,21 +236,19 @@ private void executePredict(
|
236 | 236 | );
|
237 | 237 | }
|
238 | 238 |
|
239 |
| - private void validateInputSchema(String modelId, MLInput mlInput) { |
240 |
| - log.info("Went to validate input schema"); |
241 |
| - log.info(modelCacheHelper.getModelInterface(modelId).containsKey("input")); |
242 |
| - log.info(modelCacheHelper.getModelInterface(modelId).get("input")); |
| 239 | + private void validateInputSchema(String modelId, MLInput mlInput) throws IOException { |
243 | 240 | if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) {
|
244 | 241 | String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input");
|
245 | 242 | log.info(inputSchemaString);
|
| 243 | + log.info(mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); |
246 | 244 | try {
|
247 | 245 | MLNodeUtils
|
248 | 246 | .validateSchema(
|
249 | 247 | inputSchemaString,
|
250 | 248 | mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()
|
251 | 249 | );
|
252 | 250 | } catch (Exception e) {
|
253 |
| - throw new IllegalArgumentException("Error validating input schema: " + e.getMessage()); |
| 251 | + throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST); |
254 | 252 | }
|
255 | 253 | }
|
256 | 254 | }
|
|
0 commit comments