Skip to content

Commit b09b8f6

Browse files
[Backport 2.x] fix optional mappings in ml inference search processors (#3595)
* fix optional mappings in ml inference search processors (#3587) * fix optional mappings Signed-off-by: Mingshi Liu <mingshl@amazon.com> * use collections and add more assertion tests Signed-off-by: Mingshi Liu <mingshl@amazon.com> * validate query return false Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com> (cherry picked from commit b22e61a) * fix flaky test (#3598) Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com> Co-authored-by: Mingshi Liu <mingshl@amazon.com>
1 parent 88ad8cf commit b09b8f6

7 files changed

+1934
-155
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

+135-34
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
1212
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID;
1313
import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP;
14+
import static org.opensearch.ml.processor.ModelExecutor.combineMaps;
1415

1516
import java.io.IOException;
1617
import java.util.Collection;
@@ -30,6 +31,7 @@
3031
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
3132
import org.opensearch.common.xcontent.XContentType;
3233
import org.opensearch.core.action.ActionListener;
34+
import org.opensearch.core.common.util.CollectionUtils;
3335
import org.opensearch.core.xcontent.NamedXContentRegistry;
3436
import org.opensearch.core.xcontent.XContentParser;
3537
import org.opensearch.ingest.ConfigurationUtils;
@@ -50,6 +52,8 @@
5052
import com.jayway.jsonpath.PathNotFoundException;
5153
import com.jayway.jsonpath.ReadContext;
5254

55+
import lombok.Getter;
56+
5357
/**
5458
* MLInferenceSearchRequestProcessor requires a modelId string to call model inferences
5559
* maps fields from query string for model input, and maps model inference output to the query strings or query template
@@ -58,9 +62,14 @@
5862
public class MLInferenceSearchRequestProcessor extends AbstractProcessor implements SearchRequestProcessor, ModelExecutor {
5963
private final NamedXContentRegistry xContentRegistry;
6064
private static final Logger logger = LogManager.getLogger(MLInferenceSearchRequestProcessor.class);
65+
@Getter
6166
private final InferenceProcessorAttributes inferenceProcessorAttributes;
6267
private final boolean ignoreMissing;
6368
private final String functionName;
69+
@Getter
70+
private final List<Map<String, String>> optionalInputMaps;
71+
@Getter
72+
private final List<Map<String, String>> optionalOutputMaps;
6473
private String queryTemplate;
6574
private final boolean fullResponsePath;
6675
private final boolean ignoreFailure;
@@ -78,12 +87,16 @@ public class MLInferenceSearchRequestProcessor extends AbstractProcessor impleme
7887
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
7988
// it can be overwritten using max_prediction_tasks when creating processor
8089
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
90+
public static final String OPTIONAL_INPUT_MAP = "optional_input_map";
91+
public static final String OPTIONAL_OUTPUT_MAP = "optional_output_map";
8192

8293
protected MLInferenceSearchRequestProcessor(
8394
String modelId,
8495
String queryTemplate,
8596
List<Map<String, String>> inputMaps,
8697
List<Map<String, String>> outputMaps,
98+
List<Map<String, String>> optionalInputMaps,
99+
List<Map<String, String>> optionalOutputMaps,
87100
Map<String, String> modelConfigMaps,
88101
int maxPredictionTask,
89102
String tag,
@@ -104,6 +117,8 @@ protected MLInferenceSearchRequestProcessor(
104117
modelConfigMaps,
105118
maxPredictionTask
106119
);
120+
this.optionalInputMaps = optionalInputMaps;
121+
this.optionalOutputMaps = optionalOutputMaps;
107122
this.ignoreMissing = ignoreMissing;
108123
this.functionName = functionName;
109124
this.fullResponsePath = fullResponsePath;
@@ -179,19 +194,25 @@ private void rewriteQueryString(
179194
) throws IOException {
180195
List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
181196
List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
182-
int inputMapSize = (processInputMap != null) ? processInputMap.size() : 0;
183197

184-
if (inputMapSize == 0) {
198+
// Combine processInputMap and optionalInputMaps
199+
List<Map<String, String>> combinedInputMaps = combineMaps(processInputMap, optionalInputMaps);
200+
// Combine processOutputMap and optionalOutputMaps
201+
List<Map<String, String>> combinedOutputMaps = combineMaps(processOutputMap, optionalOutputMaps);
202+
203+
int combinedInputMapSize = (combinedInputMaps != null) ? combinedInputMaps.size() : 0;
204+
205+
if (combinedInputMapSize == 0) {
185206
requestListener.onResponse(request);
186207
return;
187208
}
188209

189210
try {
190-
if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString)) {
211+
if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString, ignoreMissing)) {
191212
requestListener.onResponse(request);
192213
}
193214
} catch (Exception e) {
194-
if (ignoreMissing) {
215+
if (ignoreFailure) {
195216
requestListener.onResponse(request);
196217
return;
197218
} else {
@@ -204,16 +225,16 @@ private void rewriteQueryString(
204225
request,
205226
queryString,
206227
requestListener,
207-
processOutputMap,
228+
combinedOutputMaps,
208229
requestContext
209230
);
210231
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
211232
rewriteRequestListener,
212-
inputMapSize
233+
combinedInputMapSize
213234
);
214235

215-
for (int inputMapIndex = 0; inputMapIndex < inputMapSize; inputMapIndex++) {
216-
processPredictions(queryString, processInputMap, inputMapIndex, batchPredictionListener);
236+
for (int inputMapIndex = 0; inputMapIndex < combinedInputMapSize; inputMapIndex++) {
237+
processPredictions(queryString, combinedInputMaps, inputMapIndex, batchPredictionListener);
217238
}
218239

219240
}
@@ -376,33 +397,45 @@ public void onFailure(Exception e) {
376397
}, Math.max(inputMapSize, 1));
377398
}
378399

379-
/**
380-
* Validates that the query fields specified in the input and output mappings exist in the query string.
381-
* @param processInputMap the list of input mappings
382-
* @param processOutputMap the list of output mappings
383-
* @param queryString the query string to be validated
384-
* @return true if all query fields exist in the query string, false otherwise
385-
*/
386-
private boolean validateQueryFieldInQueryString(
400+
private boolean validateRequiredInputMappingFields(
387401
List<Map<String, String>> processInputMap,
388-
List<Map<String, String>> processOutputMap,
389-
String queryString
402+
String queryString,
403+
boolean ignoreMissing
390404
) {
391405
// Suppress errors thrown by JsonPath and instead return null if a path does not exist in a JSON blob.
392406
Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
393407
ReadContext jsonData = JsonPath.using(suppressExceptionConfiguration).parse(queryString);
394408

395-
// check all values if exists in query
409+
// check all values if exists in query for required fields
396410
for (Map<String, String> inputMap : processInputMap) {
397411
for (Map.Entry<String, String> entry : inputMap.entrySet()) {
398412
// the inputMap takes in model input as keys and query fields as value
399413
String queryField = entry.getValue();
400414
Object pathData = jsonData.read(queryField);
401415
if (pathData == null) {
402-
throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString());
416+
if (!ignoreMissing) {
417+
throw new IllegalArgumentException(
418+
"cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
419+
);
420+
} else {
421+
return false;
422+
}
403423
}
404424
}
405425
}
426+
return true;
427+
}
428+
429+
private boolean validateRequiredOutputMappingFields(
430+
List<Map<String, String>> processOutputMap,
431+
String queryString,
432+
boolean ignoreMissing
433+
) {
434+
// Suppress errors thrown by JsonPath and instead return null if a path does not exist in a JSON blob.
435+
Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
436+
ReadContext jsonData = JsonPath.using(suppressExceptionConfiguration).parse(queryString);
437+
438+
// check all values if exists in query for required fields
406439
if (queryTemplate == null) {
407440
for (Map<String, String> outputMap : processOutputMap) {
408441
for (Map.Entry<String, String> entry : outputMap.entrySet()) {
@@ -411,25 +444,57 @@ private boolean validateQueryFieldInQueryString(
411444
if (queryField.startsWith("query.") || queryField.startsWith("$.query.")) {
412445
Object pathData = jsonData.read(queryField);
413446
if (pathData == null) {
414-
throw new IllegalArgumentException(
415-
"cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
416-
);
447+
if (!ignoreMissing) {
448+
throw new IllegalArgumentException(
449+
"cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
450+
);
451+
} else {
452+
return false;
453+
}
417454
}
418455
}
419456
}
420457
}
421458
}
422-
423459
return true;
460+
}
461+
462+
/**
463+
* Validates that the query fields specified in the input and output mappings exist in the query string.
464+
*
465+
* @param processInputMap the list of input mappings
466+
* @param processOutputMap the list of output mappings
467+
* @param queryString the query string to be validated
468+
* @param ignoreMissing
469+
* @return true if all query fields exist in the query string, false otherwise
470+
*/
471+
private boolean validateQueryFieldInQueryString(
472+
List<Map<String, String>> processInputMap,
473+
List<Map<String, String>> processOutputMap,
474+
String queryString,
475+
boolean ignoreMissing
476+
) {
477+
if (!CollectionUtils.isEmpty(processInputMap)) {
478+
if (!validateRequiredInputMappingFields(processInputMap, queryString, ignoreMissing)) {
479+
return false;
480+
}
481+
}
424482

483+
if (!CollectionUtils.isEmpty(processOutputMap)) {
484+
if (!validateRequiredOutputMappingFields(processOutputMap, queryString, ignoreMissing)) {
485+
return false;
486+
}
487+
}
488+
489+
return true;
425490
}
426491

427492
/**
428493
* Processes the ML model inference for a given input mapping index.
429494
*
430-
* @param queryString the original query string
431-
* @param processInputMap the list of input mappings
432-
* @param inputMapIndex the index of the input mapping to be processed
495+
* @param queryString the original query string
496+
* @param processInputMap the list of input mappings
497+
* @param inputMapIndex the index of the input mapping to be processed
433498
* @param batchPredictionListener the {@link GroupedActionListener} to be notified when the ML model inference is complete
434499
* @throws IOException if an I/O error occurs during the processing
435500
*/
@@ -455,8 +520,10 @@ private void processPredictions(
455520
// model field as key, query field name as value
456521
String modelInputFieldName = entry.getKey();
457522
String queryFieldName = entry.getValue();
458-
String queryFieldValue = toJson(JsonPath.parse(newQuery).read(queryFieldName));
459-
modelParameters.put(modelInputFieldName, queryFieldValue);
523+
if (hasField(newQuery, queryFieldName)) {
524+
String queryFieldValue = toJson(JsonPath.parse(newQuery).read(queryFieldName));
525+
modelParameters.put(modelInputFieldName, queryFieldValue);
526+
}
460527
}
461528
}
462529

@@ -577,8 +644,26 @@ public MLInferenceSearchRequestProcessor create(
577644
String queryTemplate = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, QUERY_TEMPLATE);
578645
Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, MODEL_CONFIG);
579646

580-
List<Map<String, String>> inputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, INPUT_MAP);
581-
List<Map<String, String>> outputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, OUTPUT_MAP);
647+
List<Map<String, String>> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, INPUT_MAP);
648+
List<Map<String, String>> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OUTPUT_MAP);
649+
650+
List<Map<String, String>> optionalInputMaps = ConfigurationUtils
651+
.readOptionalList(TYPE, processorTag, config, OPTIONAL_INPUT_MAP);
652+
List<Map<String, String>> optionalOutputMaps = ConfigurationUtils
653+
.readOptionalList(TYPE, processorTag, config, OPTIONAL_OUTPUT_MAP);
654+
655+
if (CollectionUtils.isEmpty(inputMaps) && CollectionUtils.isEmpty(optionalInputMaps)) {
656+
throw new IllegalArgumentException(
657+
"Please provide at least one non-empty input_map or optional_input_map for ML Inference Search Request Processor"
658+
);
659+
}
660+
661+
if (CollectionUtils.isEmpty(outputMaps) && CollectionUtils.isEmpty(optionalOutputMaps)) {
662+
throw new IllegalArgumentException(
663+
"Please provide at least one non-empty output_map or optional_output_map for ML Inference Search Request Processor"
664+
);
665+
}
666+
582667
int maxPredictionTask = ConfigurationUtils
583668
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
584669
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
@@ -602,14 +687,28 @@ public MLInferenceSearchRequestProcessor create(
602687
if (modelConfigInput != null) {
603688
modelConfigMaps = StringUtils.getParameterMap(modelConfigInput);
604689
}
690+
// Combine processInputMap and optionalInputMaps
691+
List<Map<String, String>> combinedInputMaps = ModelExecutor.combineMaps(inputMaps, optionalInputMaps);
692+
// Combine processOutputMap and optionalOutputMaps
693+
List<Map<String, String>> combinedOutputMaps = ModelExecutor.combineMaps(outputMaps, optionalOutputMaps);
694+
605695
// check if the number of prediction tasks exceeds max prediction tasks
606-
if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
696+
if (combinedInputMaps != null && combinedInputMaps.size() > maxPredictionTask) {
607697
throw new IllegalArgumentException(
608698
"The number of prediction task setting in this process is "
609-
+ inputMaps.size()
699+
+ combinedInputMaps.size()
610700
+ ". It exceeds the max_prediction_tasks of "
611701
+ maxPredictionTask
612-
+ ". Please reduce the size of input_map or increase max_prediction_tasks."
702+
+ ". Please reduce the size of input_map or optional_input_map or increase max_prediction_tasks."
703+
);
704+
}
705+
if (combinedOutputMaps != null && combinedInputMaps != null && combinedOutputMaps.size() != combinedInputMaps.size()) {
706+
throw new IllegalArgumentException(
707+
"when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of "
708+
+ combinedInputMaps.size()
709+
+ ", while output_maps is in the length of "
710+
+ combinedInputMaps.size()
711+
+ ". Please adjust mappings."
613712
);
614713
}
615714

@@ -618,6 +717,8 @@ public MLInferenceSearchRequestProcessor create(
618717
queryTemplate,
619718
inputMaps,
620719
outputMaps,
720+
optionalInputMaps,
721+
optionalOutputMaps,
621722
modelConfigMaps,
622723
maxPredictionTask,
623724
processorTag,

0 commit comments

Comments
 (0)