6
6
7
7
import static org .opensearch .ml .processor .InferenceProcessorAttributes .*;
8
8
9
+ import java .io .IOException ;
9
10
import java .util .ArrayList ;
10
11
import java .util .Collection ;
11
12
import java .util .HashMap ;
13
+ import java .util .HashSet ;
12
14
import java .util .List ;
13
15
import java .util .Map ;
14
16
import java .util .Set ;
19
21
import org .opensearch .client .Client ;
20
22
import org .opensearch .core .action .ActionListener ;
21
23
import org .opensearch .core .common .Strings ;
24
+ import org .opensearch .core .xcontent .NamedXContentRegistry ;
22
25
import org .opensearch .ingest .AbstractProcessor ;
23
26
import org .opensearch .ingest .ConfigurationUtils ;
24
27
import org .opensearch .ingest .IngestDocument ;
25
28
import org .opensearch .ingest .Processor ;
26
29
import org .opensearch .ingest .ValueSource ;
30
+ import org .opensearch .ml .common .FunctionName ;
31
+ import org .opensearch .ml .common .output .MLOutput ;
27
32
import org .opensearch .ml .common .output .model .ModelTensorOutput ;
28
33
import org .opensearch .ml .common .transport .MLTaskResponse ;
29
34
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
@@ -45,17 +50,26 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
45
50
public static final String DOT_SYMBOL = "." ;
46
51
private final InferenceProcessorAttributes inferenceProcessorAttributes ;
47
52
private final boolean ignoreMissing ;
53
+ private final String functionName ;
54
+ private final boolean fullResponsePath ;
48
55
private final boolean ignoreFailure ;
56
+ private final boolean override ;
57
+ private final String modelInput ;
49
58
private final ScriptService scriptService ;
50
59
private static Client client ;
51
60
public static final String TYPE = "ml_inference" ;
52
61
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results" ;
53
62
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
54
63
// prediction outcomes, return the whole prediction outcome by skipping filtering
55
64
public static final String IGNORE_MISSING = "ignore_missing" ;
65
+ public static final String OVERRIDE = "override" ;
66
+ public static final String FUNCTION_NAME = "function_name" ;
67
+ public static final String FULL_RESPONSE_PATH = "full_response_path" ;
68
+ public static final String MODEL_INPUT = "model_input" ;
56
69
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
57
70
// it can be overwritten using max_prediction_tasks when creating processor
58
71
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10 ;
72
+ private final NamedXContentRegistry xContentRegistry ;
59
73
60
74
private Configuration suppressExceptionConfiguration = Configuration
61
75
.builder ()
@@ -71,9 +85,14 @@ protected MLInferenceIngestProcessor(
71
85
String tag ,
72
86
String description ,
73
87
boolean ignoreMissing ,
88
+ String functionName ,
89
+ boolean fullResponsePath ,
74
90
boolean ignoreFailure ,
91
+ boolean override ,
92
+ String modelInput ,
75
93
ScriptService scriptService ,
76
- Client client
94
+ Client client ,
95
+ NamedXContentRegistry xContentRegistry
77
96
) {
78
97
super (tag , description );
79
98
this .inferenceProcessorAttributes = new InferenceProcessorAttributes (
@@ -84,9 +103,14 @@ protected MLInferenceIngestProcessor(
84
103
maxPredictionTask
85
104
);
86
105
this .ignoreMissing = ignoreMissing ;
106
+ this .functionName = functionName ;
107
+ this .fullResponsePath = fullResponsePath ;
87
108
this .ignoreFailure = ignoreFailure ;
109
+ this .override = override ;
110
+ this .modelInput = modelInput ;
88
111
this .scriptService = scriptService ;
89
112
this .client = client ;
113
+ this .xContentRegistry = xContentRegistry ;
90
114
}
91
115
92
116
/**
@@ -162,10 +186,44 @@ private void processPredictions(
162
186
List <Map <String , String >> processOutputMap ,
163
187
int inputMapIndex ,
164
188
int inputMapSize
165
- ) {
189
+ ) throws IOException {
166
190
Map <String , String > modelParameters = new HashMap <>();
191
+ Map <String , String > modelConfigs = new HashMap <>();
192
+
167
193
if (inferenceProcessorAttributes .getModelConfigMaps () != null ) {
168
194
modelParameters .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
195
+ modelConfigs .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
196
+ }
197
+ Map <String , String > outputMapping = processOutputMap .get (inputMapIndex );
198
+
199
+ Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
200
+ ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
201
+ ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
202
+
203
+ Map <String , List <String >> newOutputMapping = new HashMap <>();
204
+ for (Map .Entry <String , String > entry : outputMapping .entrySet ()) {
205
+ String newDocumentFieldName = entry .getKey ();
206
+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
207
+ newOutputMapping .put (newDocumentFieldName , dotPathsInArray );
208
+ }
209
+
210
+ for (Map .Entry <String , String > entry : outputMapping .entrySet ()) {
211
+ String newDocumentFieldName = entry .getKey ();
212
+ List <String > dotPaths = newOutputMapping .get (newDocumentFieldName );
213
+
214
+ int existingFields = 0 ;
215
+ for (String path : dotPaths ) {
216
+ if (ingestDocument .hasField (path )) {
217
+ existingFields ++;
218
+ }
219
+ }
220
+ if (!override && existingFields == dotPaths .size ()) {
221
+ newOutputMapping .remove (newDocumentFieldName );
222
+ }
223
+ }
224
+ if (newOutputMapping .size () == 0 ) {
225
+ batchPredictionListener .onResponse (null );
226
+ return ;
169
227
}
170
228
// when no input mapping is provided, default to read all fields from documents as model input
171
229
if (inputMapSize == 0 ) {
@@ -184,15 +242,30 @@ private void processPredictions(
184
242
}
185
243
}
186
244
187
- ActionRequest request = getRemoteModelInferenceRequest (modelParameters , inferenceProcessorAttributes .getModelId ());
245
+ Set <String > inputMapKeys = new HashSet <>(modelParameters .keySet ());
246
+ inputMapKeys .removeAll (modelConfigs .keySet ());
247
+
248
+ Map <String , String > inputMappings = new HashMap <>();
249
+ for (String k : inputMapKeys ) {
250
+ inputMappings .put (k , modelParameters .get (k ));
251
+ }
252
+ ActionRequest request = getRemoteModelInferenceRequest (
253
+ xContentRegistry ,
254
+ modelParameters ,
255
+ modelConfigs ,
256
+ inputMappings ,
257
+ inferenceProcessorAttributes .getModelId (),
258
+ functionName ,
259
+ modelInput
260
+ );
188
261
189
262
client .execute (MLPredictionTaskAction .INSTANCE , request , new ActionListener <>() {
190
263
191
264
@ Override
192
265
public void onResponse (MLTaskResponse mlTaskResponse ) {
193
- ModelTensorOutput modelTensorOutput = ( ModelTensorOutput ) mlTaskResponse .getOutput ();
266
+ MLOutput mlOutput = mlTaskResponse .getOutput ();
194
267
if (processOutputMap == null || processOutputMap .isEmpty ()) {
195
- appendFieldValue (modelTensorOutput , null , DEFAULT_OUTPUT_FIELD_NAME , ingestDocument );
268
+ appendFieldValue (mlOutput , null , DEFAULT_OUTPUT_FIELD_NAME , ingestDocument );
196
269
} else {
197
270
// outMapping serves as a filter to modelTensorOutput, the fields that are not specified
198
271
// in the outputMapping will not write to document
@@ -202,14 +275,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
202
275
// document field as key, model field as value
203
276
String newDocumentFieldName = entry .getKey ();
204
277
String modelOutputFieldName = entry .getValue ();
205
- if (ingestDocument .hasField (newDocumentFieldName )) {
206
- throw new IllegalArgumentException (
207
- "document already has field name "
208
- + newDocumentFieldName
209
- + ". Not allow to overwrite the same field name, please check output_map."
210
- );
278
+ if (!newOutputMapping .containsKey (newDocumentFieldName )) {
279
+ continue ;
211
280
}
212
- appendFieldValue (modelTensorOutput , modelOutputFieldName , newDocumentFieldName , ingestDocument );
281
+ appendFieldValue (mlOutput , modelOutputFieldName , newDocumentFieldName , ingestDocument );
213
282
}
214
283
}
215
284
batchPredictionListener .onResponse (null );
@@ -322,16 +391,16 @@ private void appendFieldValue(
322
391
323
392
modelOutputValue = getModelOutputValue (modelTensorOutput , modelOutputFieldName , ignoreMissing );
324
393
325
- Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
326
- ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
327
- ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
328
- List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
394
+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocument .getSourceAndMetadata (), newDocumentFieldName );
329
395
330
396
if (dotPathsInArray .size () == 1 ) {
331
- ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
332
- TemplateScript .Factory ingestField = ConfigurationUtils
333
- .compileTemplate (TYPE , tag , dotPathsInArray .get (0 ), dotPathsInArray .get (0 ), scriptService );
334
- ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
397
+ if (!ingestDocument .hasField (dotPathsInArray .get (0 )) || override ) {
398
+ ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
399
+ TemplateScript .Factory ingestField = ConfigurationUtils
400
+ .compileTemplate (TYPE , tag , dotPathsInArray .get (0 ), dotPathsInArray .get (0 ), scriptService );
401
+
402
+ ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
403
+ }
335
404
} else {
336
405
if (!(modelOutputValue instanceof List )) {
337
406
throw new IllegalArgumentException ("Model output is not an array, cannot assign to array in documents." );
@@ -353,18 +422,73 @@ private void appendFieldValue(
353
422
// Iterate over dotPathInArray
354
423
for (int i = 0 ; i < dotPathsInArray .size (); i ++) {
355
424
String dotPathInArray = dotPathsInArray .get (i );
356
- Object modelOutputValueInArray = modelOutputValueArray .get (i );
357
- ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
358
- TemplateScript .Factory ingestField = ConfigurationUtils
359
- .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
360
- ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
425
+ if (!ingestDocument .hasField (dotPathInArray ) || override ) {
426
+ Object modelOutputValueInArray = modelOutputValueArray .get (i );
427
+ ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
428
+ TemplateScript .Factory ingestField = ConfigurationUtils
429
+ .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
430
+ ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
431
+ }
361
432
}
362
433
}
363
434
} else {
364
435
throw new RuntimeException ("model inference output cannot be null" );
365
436
}
366
437
}
367
438
439
+ private void appendFieldValue (
440
+ MLOutput mlOutput ,
441
+ String modelOutputFieldName ,
442
+ String newDocumentFieldName ,
443
+ IngestDocument ingestDocument
444
+ ) {
445
+
446
+ if (mlOutput == null ) {
447
+ throw new RuntimeException ("model inference output is null" );
448
+ }
449
+
450
+ Object modelOutputValue = getModelOutputValue (mlOutput , modelOutputFieldName , ignoreMissing , fullResponsePath );
451
+
452
+ Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
453
+ ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
454
+ ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
455
+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
456
+
457
+ if (dotPathsInArray .size () == 1 ) {
458
+ ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
459
+ TemplateScript .Factory ingestField = ConfigurationUtils
460
+ .compileTemplate (TYPE , tag , dotPathsInArray .get (0 ), dotPathsInArray .get (0 ), scriptService );
461
+ ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
462
+ } else {
463
+ if (!(modelOutputValue instanceof List )) {
464
+ throw new IllegalArgumentException ("Model output is not an array, cannot assign to array in documents." );
465
+ }
466
+ List <?> modelOutputValueArray = (List <?>) modelOutputValue ;
467
+ // check length of the prediction array to be the same of the document array
468
+ if (dotPathsInArray .size () != modelOutputValueArray .size ()) {
469
+ throw new RuntimeException (
470
+ "the prediction field: "
471
+ + modelOutputFieldName
472
+ + " is an array in size of "
473
+ + modelOutputValueArray .size ()
474
+ + " but the document field array from field "
475
+ + newDocumentFieldName
476
+ + " is in size of "
477
+ + dotPathsInArray .size ()
478
+ );
479
+ }
480
+ // Iterate over dotPathInArray
481
+ for (int i = 0 ; i < dotPathsInArray .size (); i ++) {
482
+ String dotPathInArray = dotPathsInArray .get (i );
483
+ Object modelOutputValueInArray = modelOutputValueArray .get (i );
484
+ ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
485
+ TemplateScript .Factory ingestField = ConfigurationUtils
486
+ .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
487
+ ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
488
+ }
489
+ }
490
+ }
491
+
368
492
@ Override
369
493
public String getType () {
370
494
return TYPE ;
@@ -374,16 +498,18 @@ public static class Factory implements Processor.Factory {
374
498
375
499
private final ScriptService scriptService ;
376
500
private final Client client ;
501
+ private final NamedXContentRegistry xContentRegistry ;
377
502
378
503
/**
379
504
* Constructs a new instance of the Factory class.
380
505
*
381
506
* @param scriptService the ScriptService instance to be used by the Factory
382
507
* @param client the Client instance to be used by the Factory
383
508
*/
384
- public Factory (ScriptService scriptService , Client client ) {
509
+ public Factory (ScriptService scriptService , Client client , NamedXContentRegistry xContentRegistry ) {
385
510
this .scriptService = scriptService ;
386
511
this .client = client ;
512
+ this .xContentRegistry = xContentRegistry ;
387
513
}
388
514
389
515
/**
@@ -410,6 +536,14 @@ public MLInferenceIngestProcessor create(
410
536
int maxPredictionTask = ConfigurationUtils
411
537
.readIntProperty (TYPE , processorTag , config , MAX_PREDICTION_TASKS , DEFAULT_MAX_PREDICTION_TASKS );
412
538
boolean ignoreMissing = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , IGNORE_MISSING , false );
539
+ boolean override = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , OVERRIDE , false );
540
+ String functionName = ConfigurationUtils
541
+ .readStringProperty (TYPE , processorTag , config , FUNCTION_NAME , FunctionName .REMOTE .name ());
542
+ String modelInput = ConfigurationUtils
543
+ .readStringProperty (TYPE , processorTag , config , MODEL_INPUT , "{ \" parameters\" : ${ml_inference.parameters} }" );
544
+ boolean defaultValue = !functionName .equals ("remote" );
545
+ boolean fullResponsePath = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , FULL_RESPONSE_PATH , defaultValue );
546
+
413
547
boolean ignoreFailure = ConfigurationUtils
414
548
.readBooleanProperty (TYPE , processorTag , config , ConfigurationUtils .IGNORE_FAILURE_KEY , false );
415
549
// convert model config user input data structure to Map<String, String>
@@ -440,11 +574,16 @@ public MLInferenceIngestProcessor create(
440
574
processorTag ,
441
575
description ,
442
576
ignoreMissing ,
577
+ functionName ,
578
+ fullResponsePath ,
443
579
ignoreFailure ,
580
+ override ,
581
+ modelInput ,
444
582
scriptService ,
445
- client
583
+ client ,
584
+ xContentRegistry
446
585
);
447
586
}
448
587
}
449
588
450
- }
589
+ }
0 commit comments