import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP;
+ import static org.opensearch.ml.processor.ModelExecutor.combineMaps;

import java.io.IOException;
import java.util.Collection;
⋮
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
+ import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ingest.ConfigurationUtils;
⋮
import com.jayway.jsonpath.PathNotFoundException;
import com.jayway.jsonpath.ReadContext;

+ import lombok.Getter;
+
/**
 * MLInferenceSearchRequestProcessor requires a modelId string to call model inferences
 * maps fields from query string for model input, and maps model inference output to the query strings or query template
⋮
public class MLInferenceSearchRequestProcessor extends AbstractProcessor implements SearchRequestProcessor, ModelExecutor {
    private final NamedXContentRegistry xContentRegistry;
    private static final Logger logger = LogManager.getLogger(MLInferenceSearchRequestProcessor.class);
+     @Getter
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final String functionName;
+     @Getter
+     private final List<Map<String, String>> optionalInputMaps;
+     @Getter
+     private final List<Map<String, String>> optionalOutputMaps;
    private String queryTemplate;
    private final boolean fullResponsePath;
    private final boolean ignoreFailure;
@@ -78,12 +87,16 @@ public class MLInferenceSearchRequestProcessor extends AbstractProcessor impleme
    // At default, ml inference processor allows maximum 10 prediction tasks running in parallel
    // it can be overwritten using max_prediction_tasks when creating processor
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
+     public static final String OPTIONAL_INPUT_MAP = "optional_input_map";
+     public static final String OPTIONAL_OUTPUT_MAP = "optional_output_map";

    protected MLInferenceSearchRequestProcessor(
        String modelId,
        String queryTemplate,
        List<Map<String, String>> inputMaps,
        List<Map<String, String>> outputMaps,
+         List<Map<String, String>> optionalInputMaps,
+         List<Map<String, String>> optionalOutputMaps,
        Map<String, String> modelConfigMaps,
        int maxPredictionTask,
        String tag,
@@ -104,6 +117,8 @@ protected MLInferenceSearchRequestProcessor(
            modelConfigMaps,
            maxPredictionTask
        );
+         this.optionalInputMaps = optionalInputMaps;
+         this.optionalOutputMaps = optionalOutputMaps;
        this.ignoreMissing = ignoreMissing;
        this.functionName = functionName;
        this.fullResponsePath = fullResponsePath;
@@ -179,19 +194,25 @@ private void rewriteQueryString(
    ) throws IOException {
        List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
-         int inputMapSize = (processInputMap != null) ? processInputMap.size() : 0;

-         if (inputMapSize == 0) {
+         // Combine processInputMap and optionalInputMaps
+         List<Map<String, String>> combinedInputMaps = combineMaps(processInputMap, optionalInputMaps);
+         // Combine processOutputMap and optionalOutputMaps
+         List<Map<String, String>> combinedOutputMaps = combineMaps(processOutputMap, optionalOutputMaps);
+
+         int combinedInputMapSize = (combinedInputMaps != null) ? combinedInputMaps.size() : 0;
+
+         if (combinedInputMapSize == 0) {
            requestListener.onResponse(request);
            return;
        }

        try {
-             if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString)) {
+             if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString, ignoreMissing)) {
                requestListener.onResponse(request);
            }
        } catch (Exception e) {
-             if (ignoreMissing) {
+             if (ignoreFailure) {
                requestListener.onResponse(request);
                return;
            } else {
@@ -204,16 +225,16 @@ private void rewriteQueryString(
            request,
            queryString,
            requestListener,
-             processOutputMap,
+             combinedOutputMaps,
            requestContext
        );
        GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
            rewriteRequestListener,
-             inputMapSize
+             combinedInputMapSize
        );

-         for (int inputMapIndex = 0; inputMapIndex < inputMapSize; inputMapIndex++) {
-             processPredictions(queryString, processInputMap, inputMapIndex, batchPredictionListener);
+         for (int inputMapIndex = 0; inputMapIndex < combinedInputMapSize; inputMapIndex++) {
+             processPredictions(queryString, combinedInputMaps, inputMapIndex, batchPredictionListener);
        }

    }
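
Note on combineMaps: the rewritten rewriteQueryString above routes both required and optional mappings through combineMaps, which this diff only imports from ModelExecutor. A minimal sketch of what such a merge could look like, assuming it simply concatenates the two nullable lists (an illustration, not the actual ModelExecutor implementation):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;

    // Hypothetical helper mirroring how combineMaps is used in this diff: required mappings first,
    // then optional ones, with null lists treated as empty so the combined size can drive the
    // number of prediction calls.
    static List<Map<String, String>> combineMapsSketch(List<Map<String, String>> required, List<Map<String, String>> optional) {
        List<Map<String, String>> combined = new ArrayList<>();
        if (required != null) {
            combined.addAll(required);
        }
        if (optional != null) {
            combined.addAll(optional);
        }
        return combined;
    }

Under that assumption, one input_map entry plus two optional_input_map entries yields a combined size of three, which is the value compared against max_prediction_tasks in the factory further down.
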
@@ -376,33 +397,45 @@ public void onFailure(Exception e) {
        }, Math.max(inputMapSize, 1));
    }

-     /**
-      * Validates that the query fields specified in the input and output mappings exist in the query string.
-      * @param processInputMap the list of input mappings
-      * @param processOutputMap the list of output mappings
-      * @param queryString the query string to be validated
-      * @return true if all query fields exist in the query string, false otherwise
-      */
-     private boolean validateQueryFieldInQueryString(
+     private boolean validateRequiredInputMappingFields(
        List<Map<String, String>> processInputMap,
-         List<Map<String, String>> processOutputMap,
-         String queryString
+         String queryString,
+         boolean ignoreMissing
    ) {
        // Suppress errors thrown by JsonPath and instead return null if a path does not exist in a JSON blob.
        Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
        ReadContext jsonData = JsonPath.using(suppressExceptionConfiguration).parse(queryString);

-         // check all values if exists in query
+         // check all values if exists in query for required fields
        for (Map<String, String> inputMap : processInputMap) {
            for (Map.Entry<String, String> entry : inputMap.entrySet()) {
                // the inputMap takes in model input as keys and query fields as value
                String queryField = entry.getValue();
                Object pathData = jsonData.read(queryField);
                if (pathData == null) {
-                     throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString());
+                     if (!ignoreMissing) {
+                         throw new IllegalArgumentException(
+                             "cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
+                         );
+                     } else {
+                         return false;
+                     }
                }
            }
        }
+         return true;
+     }
+
+     private boolean validateRequiredOutputMappingFields(
+         List<Map<String, String>> processOutputMap,
+         String queryString,
+         boolean ignoreMissing
+     ) {
+         // Suppress errors thrown by JsonPath and instead return null if a path does not exist in a JSON blob.
+         Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
+         ReadContext jsonData = JsonPath.using(suppressExceptionConfiguration).parse(queryString);
+
+         // check all values if exists in query for required fields
        if (queryTemplate == null) {
            for (Map<String, String> outputMap : processOutputMap) {
                for (Map.Entry<String, String> entry : outputMap.entrySet()) {
@@ -411,25 +444,57 @@ private boolean validateQueryFieldInQueryString(
                    if (queryField.startsWith("query.") || queryField.startsWith("$.query.")) {
                        Object pathData = jsonData.read(queryField);
                        if (pathData == null) {
-                             throw new IllegalArgumentException(
-                                 "cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
-                             );
+                             if (!ignoreMissing) {
+                                 throw new IllegalArgumentException(
+                                     "cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
+                                 );
+                             } else {
+                                 return false;
+                             }
                        }
                    }
                }
            }
        }
-
        return true;
+     }
+
+     /**
+      * Validates that the query fields specified in the input and output mappings exist in the query string.
+      *
+      * @param processInputMap the list of input mappings
+      * @param processOutputMap the list of output mappings
+      * @param queryString the query string to be validated
+      * @param ignoreMissing whether missing query fields should be skipped (return false) instead of throwing an exception
+      * @return true if all query fields exist in the query string, false otherwise
+      */
+     private boolean validateQueryFieldInQueryString(
+         List<Map<String, String>> processInputMap,
+         List<Map<String, String>> processOutputMap,
+         String queryString,
+         boolean ignoreMissing
+     ) {
+         if (!CollectionUtils.isEmpty(processInputMap)) {
+             if (!validateRequiredInputMappingFields(processInputMap, queryString, ignoreMissing)) {
+                 return false;
+             }
+         }

+         if (!CollectionUtils.isEmpty(processOutputMap)) {
+             if (!validateRequiredOutputMappingFields(processOutputMap, queryString, ignoreMissing)) {
+                 return false;
+             }
+         }
+
+         return true;
    }

    /**
     * Processes the ML model inference for a given input mapping index.
     *
-      * @param queryString the original query string
-      * @param processInputMap the list of input mappings
-      * @param inputMapIndex the index of the input mapping to be processed
+      * @param queryString             the original query string
+      * @param processInputMap         the list of input mappings
+      * @param inputMapIndex           the index of the input mapping to be processed
     * @param batchPredictionListener the {@link GroupedActionListener} to be notified when the ML model inference is complete
     * @throws IOException if an I/O error occurs during the processing
     */
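
The validators above (now split into validateRequiredInputMappingFields and validateRequiredOutputMappingFields) hinge on JsonPath's SUPPRESS_EXCEPTIONS option: a missing path reads as null rather than throwing PathNotFoundException, and the ignoreMissing flag then decides between returning false and raising IllegalArgumentException. A small standalone illustration of that JsonPath behavior (the example JSON and class name are made up):

    import com.jayway.jsonpath.Configuration;
    import com.jayway.jsonpath.JsonPath;
    import com.jayway.jsonpath.Option;
    import com.jayway.jsonpath.ReadContext;

    public class SuppressExceptionsDemo {
        public static void main(String[] args) {
            Configuration conf = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
            ReadContext ctx = JsonPath.using(conf).parse("{\"query\":{\"term\":{\"text\":{\"value\":\"foo\"}}}}");

            // Existing path: resolves to the value.
            Object present = ctx.read("$.query.term.text.value"); // "foo"
            // Missing path: returns null instead of throwing PathNotFoundException.
            Object missing = ctx.read("$.query.term.text.boost"); // null
            System.out.println(present + " / " + missing);
        }
    }
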
@@ -455,8 +520,10 @@ private void processPredictions(
                // model field as key, query field name as value
                String modelInputFieldName = entry.getKey();
                String queryFieldName = entry.getValue();
-                 String queryFieldValue = toJson(JsonPath.parse(newQuery).read(queryFieldName));
-                 modelParameters.put(modelInputFieldName, queryFieldValue);
+                 if (hasField(newQuery, queryFieldName)) {
+                     String queryFieldValue = toJson(JsonPath.parse(newQuery).read(queryFieldName));
+                     modelParameters.put(modelInputFieldName, queryFieldValue);
+                 }
            }
        }

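
processPredictions now guards each read with hasField(newQuery, queryFieldName), so an optional mapping whose query field is absent is simply skipped instead of failing the whole request. The helper itself is not shown in this diff; a hypothetical equivalent built on the same JsonPath classes used elsewhere in the file (imports as in the demo above):

    // Hypothetical stand-in for hasField: true only when the JsonPath resolves to a non-null value,
    // so absent optional query fields never reach toJson()/modelParameters.put().
    private static boolean hasFieldSketch(String json, String path) {
        Configuration conf = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
        return JsonPath.using(conf).parse(json).read(path) != null;
    }
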
@@ -577,8 +644,26 @@ public MLInferenceSearchRequestProcessor create(
            String queryTemplate = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, QUERY_TEMPLATE);
            Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, MODEL_CONFIG);

-             List<Map<String, String>> inputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, INPUT_MAP);
-             List<Map<String, String>> outputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, OUTPUT_MAP);
+             List<Map<String, String>> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, INPUT_MAP);
+             List<Map<String, String>> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OUTPUT_MAP);
+
+             List<Map<String, String>> optionalInputMaps = ConfigurationUtils
+                 .readOptionalList(TYPE, processorTag, config, OPTIONAL_INPUT_MAP);
+             List<Map<String, String>> optionalOutputMaps = ConfigurationUtils
+                 .readOptionalList(TYPE, processorTag, config, OPTIONAL_OUTPUT_MAP);
+
+             if (CollectionUtils.isEmpty(inputMaps) && CollectionUtils.isEmpty(optionalInputMaps)) {
+                 throw new IllegalArgumentException(
+                     "Please provide at least one non-empty input_map or optional_input_map for ML Inference Search Request Processor"
+                 );
+             }
+
+             if (CollectionUtils.isEmpty(outputMaps) && CollectionUtils.isEmpty(optionalOutputMaps)) {
+                 throw new IllegalArgumentException(
+                     "Please provide at least one non-empty output_map or optional_output_map for ML Inference Search Request Processor"
+                 );
+             }
+
            int maxPredictionTask = ConfigurationUtils
                .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
@@ -602,14 +687,28 @@ public MLInferenceSearchRequestProcessor create(
            if (modelConfigInput != null) {
                modelConfigMaps = StringUtils.getParameterMap(modelConfigInput);
            }
+             // Combine inputMaps and optionalInputMaps
+             List<Map<String, String>> combinedInputMaps = ModelExecutor.combineMaps(inputMaps, optionalInputMaps);
+             // Combine outputMaps and optionalOutputMaps
+             List<Map<String, String>> combinedOutputMaps = ModelExecutor.combineMaps(outputMaps, optionalOutputMaps);
+
            // check if the number of prediction tasks exceeds max prediction tasks
-             if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
+             if (combinedInputMaps != null && combinedInputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException(
                    "The number of prediction task setting in this process is "
-                         + inputMaps.size()
+                         + combinedInputMaps.size()
                        + ". It exceeds the max_prediction_tasks of "
                        + maxPredictionTask
-                         + ". Please reduce the size of input_map or increase max_prediction_tasks."
+                         + ". Please reduce the size of input_map or optional_input_map or increase max_prediction_tasks."
+                 );
+             }
+             if (combinedOutputMaps != null && combinedInputMaps != null && combinedOutputMaps.size() != combinedInputMaps.size()) {
+                 throw new IllegalArgumentException(
+                     "when output_map/optional_output_map and input_map/optional_input_map are provided, their lengths need to match. The input is of length "
+                         + combinedInputMaps.size()
+                         + ", while the output is of length "
+                         + combinedOutputMaps.size()
+                         + ". Please adjust mappings."
                );
            }

@@ -618,6 +717,8 @@ public MLInferenceSearchRequestProcessor create(
                queryTemplate,
                inputMaps,
                outputMaps,
+                 optionalInputMaps,
+                 optionalOutputMaps,
                modelConfigMaps,
                maxPredictionTask,
                processorTag,