6
6
7
7
import static org .opensearch .ml .processor .InferenceProcessorAttributes .*;
8
8
9
+ import java .io .IOException ;
9
10
import java .util .ArrayList ;
10
11
import java .util .Collection ;
11
12
import java .util .HashMap ;
13
+ import java .util .HashSet ;
12
14
import java .util .List ;
13
15
import java .util .Map ;
14
16
import java .util .Set ;
15
17
import java .util .function .BiConsumer ;
16
18
19
+ import org .apache .logging .log4j .LogManager ;
20
+ import org .apache .logging .log4j .Logger ;
17
21
import org .opensearch .action .ActionRequest ;
18
22
import org .opensearch .action .support .GroupedActionListener ;
19
23
import org .opensearch .client .Client ;
20
24
import org .opensearch .core .action .ActionListener ;
21
25
import org .opensearch .core .common .Strings ;
26
+ import org .opensearch .core .xcontent .NamedXContentRegistry ;
22
27
import org .opensearch .ingest .AbstractProcessor ;
23
28
import org .opensearch .ingest .ConfigurationUtils ;
24
29
import org .opensearch .ingest .IngestDocument ;
25
30
import org .opensearch .ingest .Processor ;
26
31
import org .opensearch .ingest .ValueSource ;
27
- import org .opensearch .ml .common .output .model .ModelTensorOutput ;
32
+ import org .opensearch .ml .common .FunctionName ;
33
+ import org .opensearch .ml .common .output .MLOutput ;
28
34
import org .opensearch .ml .common .transport .MLTaskResponse ;
29
35
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
30
36
import org .opensearch .ml .common .utils .StringUtils ;
42
48
*/
43
49
public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor {
44
50
51
+ private static final Logger logger = LogManager .getLogger (MLInferenceIngestProcessor .class );
52
+
45
53
public static final String DOT_SYMBOL = "." ;
46
54
private final InferenceProcessorAttributes inferenceProcessorAttributes ;
47
55
private final boolean ignoreMissing ;
56
+ private final String functionName ;
57
+ private final boolean fullResponsePath ;
48
58
private final boolean ignoreFailure ;
59
+ private final boolean override ;
60
+ private final String modelInput ;
49
61
private final ScriptService scriptService ;
50
62
private static Client client ;
51
63
public static final String TYPE = "ml_inference" ;
52
64
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results" ;
53
65
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
54
66
// prediction outcomes, return the whole prediction outcome by skipping filtering
55
67
public static final String IGNORE_MISSING = "ignore_missing" ;
68
+ public static final String OVERRIDE = "override" ;
69
+ public static final String FUNCTION_NAME = "function_name" ;
70
+ public static final String FULL_RESPONSE_PATH = "full_response_path" ;
71
+ public static final String MODEL_INPUT = "model_input" ;
56
72
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
57
73
// it can be overwritten using max_prediction_tasks when creating processor
58
74
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10 ;
75
+ private final NamedXContentRegistry xContentRegistry ;
59
76
60
77
private Configuration suppressExceptionConfiguration = Configuration
61
78
.builder ()
@@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor(
71
88
String tag ,
72
89
String description ,
73
90
boolean ignoreMissing ,
91
+ String functionName ,
92
+ boolean fullResponsePath ,
74
93
boolean ignoreFailure ,
94
+ boolean override ,
95
+ String modelInput ,
75
96
ScriptService scriptService ,
76
- Client client
97
+ Client client ,
98
+ NamedXContentRegistry xContentRegistry
77
99
) {
78
100
super (tag , description );
79
101
this .inferenceProcessorAttributes = new InferenceProcessorAttributes (
@@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor(
84
106
maxPredictionTask
85
107
);
86
108
this .ignoreMissing = ignoreMissing ;
109
+ this .functionName = functionName ;
110
+ this .fullResponsePath = fullResponsePath ;
87
111
this .ignoreFailure = ignoreFailure ;
112
+ this .override = override ;
113
+ this .modelInput = modelInput ;
88
114
this .scriptService = scriptService ;
89
115
this .client = client ;
116
+ this .xContentRegistry = xContentRegistry ;
90
117
}
91
118
92
119
/**
@@ -162,10 +189,48 @@ private void processPredictions(
162
189
List <Map <String , String >> processOutputMap ,
163
190
int inputMapIndex ,
164
191
int inputMapSize
165
- ) {
192
+ ) throws IOException {
166
193
Map <String , String > modelParameters = new HashMap <>();
194
+ Map <String , String > modelConfigs = new HashMap <>();
195
+
167
196
if (inferenceProcessorAttributes .getModelConfigMaps () != null ) {
168
197
modelParameters .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
198
+ modelConfigs .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
199
+ }
200
+
201
+ Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
202
+ ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
203
+ ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
204
+
205
+ Map <String , List <String >> newOutputMapping = new HashMap <>();
206
+ if (processOutputMap != null ) {
207
+
208
+ Map <String , String > outputMapping = processOutputMap .get (inputMapIndex );
209
+ for (Map .Entry <String , String > entry : outputMapping .entrySet ()) {
210
+ String newDocumentFieldName = entry .getKey ();
211
+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
212
+ newOutputMapping .put (newDocumentFieldName , dotPathsInArray );
213
+ }
214
+
215
+ for (Map .Entry <String , String > entry : outputMapping .entrySet ()) {
216
+ String newDocumentFieldName = entry .getKey ();
217
+ List <String > dotPaths = newOutputMapping .get (newDocumentFieldName );
218
+
219
+ int existingFields = 0 ;
220
+ for (String path : dotPaths ) {
221
+ if (ingestDocument .hasField (path )) {
222
+ existingFields ++;
223
+ }
224
+ }
225
+ if (!override && existingFields == dotPaths .size ()) {
226
+ logger .debug ("{} already exists in the ingest document. Removing it from output mapping" , newDocumentFieldName );
227
+ newOutputMapping .remove (newDocumentFieldName );
228
+ }
229
+ }
230
+ if (newOutputMapping .size () == 0 ) {
231
+ batchPredictionListener .onResponse (null );
232
+ return ;
233
+ }
169
234
}
170
235
// when no input mapping is provided, default to read all fields from documents as model input
171
236
if (inputMapSize == 0 ) {
@@ -184,15 +249,30 @@ private void processPredictions(
184
249
}
185
250
}
186
251
187
- ActionRequest request = getRemoteModelInferenceRequest (modelParameters , inferenceProcessorAttributes .getModelId ());
252
+ Set <String > inputMapKeys = new HashSet <>(modelParameters .keySet ());
253
+ inputMapKeys .removeAll (modelConfigs .keySet ());
254
+
255
+ Map <String , String > inputMappings = new HashMap <>();
256
+ for (String k : inputMapKeys ) {
257
+ inputMappings .put (k , modelParameters .get (k ));
258
+ }
259
+ ActionRequest request = getMLModelInferenceRequest (
260
+ xContentRegistry ,
261
+ modelParameters ,
262
+ modelConfigs ,
263
+ inputMappings ,
264
+ inferenceProcessorAttributes .getModelId (),
265
+ functionName ,
266
+ modelInput
267
+ );
188
268
189
269
client .execute (MLPredictionTaskAction .INSTANCE , request , new ActionListener <>() {
190
270
191
271
@ Override
192
272
public void onResponse (MLTaskResponse mlTaskResponse ) {
193
- ModelTensorOutput modelTensorOutput = ( ModelTensorOutput ) mlTaskResponse .getOutput ();
273
+ MLOutput mlOutput = mlTaskResponse .getOutput ();
194
274
if (processOutputMap == null || processOutputMap .isEmpty ()) {
195
- appendFieldValue (modelTensorOutput , null , DEFAULT_OUTPUT_FIELD_NAME , ingestDocument );
275
+ appendFieldValue (mlOutput , null , DEFAULT_OUTPUT_FIELD_NAME , ingestDocument );
196
276
} else {
197
277
// outMapping serves as a filter to modelTensorOutput, the fields that are not specified
198
278
// in the outputMapping will not write to document
@@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
202
282
// document field as key, model field as value
203
283
String newDocumentFieldName = entry .getKey ();
204
284
String modelOutputFieldName = entry .getValue ();
205
- if (ingestDocument .hasField (newDocumentFieldName )) {
206
- throw new IllegalArgumentException (
207
- "document already has field name "
208
- + newDocumentFieldName
209
- + ". Not allow to overwrite the same field name, please check output_map."
210
- );
285
+ if (!newOutputMapping .containsKey (newDocumentFieldName )) {
286
+ continue ;
211
287
}
212
- appendFieldValue (modelTensorOutput , modelOutputFieldName , newDocumentFieldName , ingestDocument );
288
+ appendFieldValue (mlOutput , modelOutputFieldName , newDocumentFieldName , ingestDocument );
213
289
}
214
290
}
215
291
batchPredictionListener .onResponse (null );
@@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN
305
381
/**
306
382
* Appends the model output value to the specified field in the IngestDocument without modifying the source.
307
383
*
308
- * @param modelTensorOutput the ModelTensorOutput containing the model output
384
+ * @param mlOutput the MLOutput containing the model output
309
385
* @param modelOutputFieldName the name of the field in the model output
310
386
* @param newDocumentFieldName the name of the field in the IngestDocument to append the value to
311
387
* @param ingestDocument the IngestDocument to append the value to
312
388
*/
313
389
private void appendFieldValue (
314
- ModelTensorOutput modelTensorOutput ,
390
+ MLOutput mlOutput ,
315
391
String modelOutputFieldName ,
316
392
String newDocumentFieldName ,
317
393
IngestDocument ingestDocument
318
394
) {
319
- Object modelOutputValue = null ;
320
395
321
- if (modelTensorOutput .getMlModelOutputs () != null && modelTensorOutput .getMlModelOutputs ().size () > 0 ) {
396
+ if (mlOutput == null ) {
397
+ throw new RuntimeException ("model inference output is null" );
398
+ }
322
399
323
- modelOutputValue = getModelOutputValue (modelTensorOutput , modelOutputFieldName , ignoreMissing );
400
+ Object modelOutputValue = getModelOutputValue (mlOutput , modelOutputFieldName , ignoreMissing , fullResponsePath );
324
401
325
- Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
326
- ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
327
- ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
328
- List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
402
+ Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
403
+ ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
404
+ ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
405
+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
329
406
330
- if (dotPathsInArray .size () == 1 ) {
331
- ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
407
+ if (dotPathsInArray .size () == 1 ) {
408
+ ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
409
+ TemplateScript .Factory ingestField = ConfigurationUtils
410
+ .compileTemplate (TYPE , tag , dotPathsInArray .get (0 ), dotPathsInArray .get (0 ), scriptService );
411
+ ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
412
+ } else {
413
+ if (!(modelOutputValue instanceof List )) {
414
+ throw new IllegalArgumentException ("Model output is not an array, cannot assign to array in documents." );
415
+ }
416
+ List <?> modelOutputValueArray = (List <?>) modelOutputValue ;
417
+ // check length of the prediction array to be the same of the document array
418
+ if (dotPathsInArray .size () != modelOutputValueArray .size ()) {
419
+ throw new RuntimeException (
420
+ "the prediction field: "
421
+ + modelOutputFieldName
422
+ + " is an array in size of "
423
+ + modelOutputValueArray .size ()
424
+ + " but the document field array from field "
425
+ + newDocumentFieldName
426
+ + " is in size of "
427
+ + dotPathsInArray .size ()
428
+ );
429
+ }
430
+ // Iterate over dotPathInArray
431
+ for (int i = 0 ; i < dotPathsInArray .size (); i ++) {
432
+ String dotPathInArray = dotPathsInArray .get (i );
433
+ Object modelOutputValueInArray = modelOutputValueArray .get (i );
434
+ ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
332
435
TemplateScript .Factory ingestField = ConfigurationUtils
333
- .compileTemplate (TYPE , tag , dotPathsInArray . get ( 0 ), dotPathsInArray . get ( 0 ) , scriptService );
436
+ .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
334
437
ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
335
- } else {
336
- if (!(modelOutputValue instanceof List )) {
337
- throw new IllegalArgumentException ("Model output is not an array, cannot assign to array in documents." );
338
- }
339
- List <?> modelOutputValueArray = (List <?>) modelOutputValue ;
340
- // check length of the prediction array to be the same of the document array
341
- if (dotPathsInArray .size () != modelOutputValueArray .size ()) {
342
- throw new RuntimeException (
343
- "the prediction field: "
344
- + modelOutputFieldName
345
- + " is an array in size of "
346
- + modelOutputValueArray .size ()
347
- + " but the document field array from field "
348
- + newDocumentFieldName
349
- + " is in size of "
350
- + dotPathsInArray .size ()
351
- );
352
- }
353
- // Iterate over dotPathInArray
354
- for (int i = 0 ; i < dotPathsInArray .size (); i ++) {
355
- String dotPathInArray = dotPathsInArray .get (i );
356
- Object modelOutputValueInArray = modelOutputValueArray .get (i );
357
- ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
358
- TemplateScript .Factory ingestField = ConfigurationUtils
359
- .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
360
- ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
361
- }
362
438
}
363
- } else {
364
- throw new RuntimeException ("model inference output cannot be null" );
365
439
}
366
440
}
367
441
@@ -374,16 +448,18 @@ public static class Factory implements Processor.Factory {
374
448
375
449
private final ScriptService scriptService ;
376
450
private final Client client ;
451
+ private final NamedXContentRegistry xContentRegistry ;
377
452
378
453
/**
379
454
* Constructs a new instance of the Factory class.
380
455
*
381
456
* @param scriptService the ScriptService instance to be used by the Factory
382
457
* @param client the Client instance to be used by the Factory
383
458
*/
384
- public Factory (ScriptService scriptService , Client client ) {
459
+ public Factory (ScriptService scriptService , Client client , NamedXContentRegistry xContentRegistry ) {
385
460
this .scriptService = scriptService ;
386
461
this .client = client ;
462
+ this .xContentRegistry = xContentRegistry ;
387
463
}
388
464
389
465
/**
@@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create(
410
486
int maxPredictionTask = ConfigurationUtils
411
487
.readIntProperty (TYPE , processorTag , config , MAX_PREDICTION_TASKS , DEFAULT_MAX_PREDICTION_TASKS );
412
488
boolean ignoreMissing = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , IGNORE_MISSING , false );
489
+ boolean override = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , OVERRIDE , false );
490
+ String functionName = ConfigurationUtils
491
+ .readStringProperty (TYPE , processorTag , config , FUNCTION_NAME , FunctionName .REMOTE .name ());
492
+ String modelInput = ConfigurationUtils
493
+ .readStringProperty (TYPE , processorTag , config , MODEL_INPUT , "{ \" parameters\" : ${ml_inference.parameters} }" );
494
+ boolean defaultValue = !functionName .equalsIgnoreCase ("remote" );
495
+ boolean fullResponsePath = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , FULL_RESPONSE_PATH , defaultValue );
496
+
413
497
boolean ignoreFailure = ConfigurationUtils
414
498
.readBooleanProperty (TYPE , processorTag , config , ConfigurationUtils .IGNORE_FAILURE_KEY , false );
415
499
// convert model config user input data structure to Map<String, String>
@@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create(
440
524
processorTag ,
441
525
description ,
442
526
ignoreMissing ,
527
+ functionName ,
528
+ fullResponsePath ,
443
529
ignoreFailure ,
530
+ override ,
531
+ modelInput ,
444
532
scriptService ,
445
- client
533
+ client ,
534
+ xContentRegistry
446
535
);
447
536
}
448
537
}
0 commit comments