26
26
import org .apache .commons .lang3 .StringUtils ;
27
27
import org .apache .commons .lang3 .tuple .ImmutablePair ;
28
28
import org .apache .commons .lang3 .tuple .Pair ;
29
+ import org .opensearch .action .get .MultiGetItemResponse ;
30
+ import org .opensearch .action .get .MultiGetRequest ;
29
31
import org .opensearch .common .collect .Tuple ;
30
32
import org .opensearch .core .action .ActionListener ;
31
33
import org .opensearch .core .common .util .CollectionUtils ;
@@ -54,6 +56,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
54
56
55
57
public static final String MODEL_ID_FIELD = "model_id" ;
56
58
public static final String FIELD_MAP_FIELD = "field_map" ;
59
+ public static final String INDEX_FIELD = "_index" ;
60
+ public static final String ID_FIELD = "_id" ;
57
61
private static final BiFunction <Object , Object , Object > REMAPPING_FUNCTION = (v1 , v2 ) -> {
58
62
if (v1 instanceof Collection && v2 instanceof Collection ) {
59
63
((Collection ) v1 ).addAll ((Collection ) v2 );
@@ -169,49 +173,91 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
169
173
*/
170
174
abstract void doBatchExecute (List <String > inferenceList , Consumer <List <?>> handler , Consumer <Exception > onException );
171
175
176
+ /**
177
+ * This is the function which does actual inference work for subBatchExecute interface.
178
+ * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
179
+ * @param handler a callback handler to handle inference results which is a list of objects.
180
+ */
172
181
@ Override
173
182
public void subBatchExecute (List <IngestDocumentWrapper > ingestDocumentWrappers , Consumer <List <IngestDocumentWrapper >> handler ) {
174
- if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
175
- handler .accept (Collections .emptyList ());
176
- return ;
177
- }
183
+ try {
184
+ if (CollectionUtils .isEmpty (ingestDocumentWrappers )) {
185
+ handler .accept (Collections .emptyList ());
186
+ return ;
187
+ }
178
188
179
- List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
180
- List <String > inferenceList = constructInferenceTexts (dataForInferences );
181
- if (inferenceList .isEmpty ()) {
189
+ List <DataForInference > dataForInferences = getDataForInference (ingestDocumentWrappers );
190
+ List <String > inferenceList = constructInferenceTexts (dataForInferences );
191
+ if (inferenceList .isEmpty ()) {
192
+ handler .accept (ingestDocumentWrappers );
193
+ return ;
194
+ }
195
+ doSubBatchExecute (ingestDocumentWrappers , inferenceList , dataForInferences , handler );
196
+ } catch (Exception e ) {
197
+ updateWithExceptions (ingestDocumentWrappers , e );
182
198
handler .accept (ingestDocumentWrappers );
183
- return ;
184
199
}
185
- Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
186
- inferenceList = sortedResult .v1 ();
187
- Map <Integer , Integer > originalOrder = sortedResult .v2 ();
188
- doBatchExecute (inferenceList , results -> {
189
- int startIndex = 0 ;
190
- results = restoreToOriginalOrder (results , originalOrder );
191
- for (DataForInference dataForInference : dataForInferences ) {
192
- if (dataForInference .getIngestDocumentWrapper ().getException () != null
193
- || CollectionUtils .isEmpty (dataForInference .getInferenceList ())) {
194
- continue ;
200
+ }
201
+
202
+ /**
203
+ * This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
204
+ * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
205
+ * @param inferenceList a list of String for inference.
206
+ * @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
207
+ * @param handler a callback handler to handle inference results which is a list of objects.
208
+ */
209
+ protected void doSubBatchExecute (
210
+ List <IngestDocumentWrapper > ingestDocumentWrappers ,
211
+ List <String > inferenceList ,
212
+ List <DataForInference > dataForInferences ,
213
+ Consumer <List <IngestDocumentWrapper >> handler
214
+ ) {
215
+ try {
216
+ Tuple <List <String >, Map <Integer , Integer >> sortedResult = sortByLengthAndReturnOriginalOrder (inferenceList );
217
+ inferenceList = sortedResult .v1 ();
218
+ Map <Integer , Integer > originalOrder = sortedResult .v2 ();
219
+ doBatchExecute (inferenceList , results -> {
220
+ try {
221
+ int startIndex = 0 ;
222
+ results = restoreToOriginalOrder (results , originalOrder );
223
+ for (DataForInference dataForInference : dataForInferences ) {
224
+ if (dataForInference .getIngestDocumentWrapper ().getException () != null
225
+ || CollectionUtils .isEmpty (dataForInference .getInferenceList ())) {
226
+ continue ;
227
+ }
228
+ List <?> inferenceResults = results .subList (startIndex , startIndex + dataForInference .getInferenceList ().size ());
229
+ startIndex += dataForInference .getInferenceList ().size ();
230
+ setVectorFieldsToDocument (
231
+ dataForInference .getIngestDocumentWrapper ().getIngestDocument (),
232
+ dataForInference .getProcessMap (),
233
+ inferenceResults
234
+ );
235
+ }
236
+ handler .accept (ingestDocumentWrappers );
237
+ } catch (Exception e ) {
238
+ updateWithExceptions (ingestDocumentWrappers , e );
239
+ handler .accept (ingestDocumentWrappers );
195
240
}
196
- List <?> inferenceResults = results .subList (startIndex , startIndex + dataForInference .getInferenceList ().size ());
197
- startIndex += dataForInference .getInferenceList ().size ();
198
- setVectorFieldsToDocument (
199
- dataForInference .getIngestDocumentWrapper ().getIngestDocument (),
200
- dataForInference .getProcessMap (),
201
- inferenceResults
202
- );
203
- }
204
- handler .accept (ingestDocumentWrappers );
205
- }, exception -> {
206
- for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
207
- // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
208
- // set exception to IngestDocumentWrapper which doesn't have exception before.
209
- if (ingestDocumentWrapper .getException () == null ) {
210
- ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), exception );
241
+ }, exception -> {
242
+ try {
243
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
244
+ // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
245
+ // set exception to IngestDocumentWrapper which doesn't have exception before.
246
+ if (ingestDocumentWrapper .getException () == null ) {
247
+ ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), exception );
248
+ }
249
+ }
250
+ handler .accept (ingestDocumentWrappers );
251
+ } catch (Exception e ) {
252
+ updateWithExceptions (ingestDocumentWrappers , e );
253
+ handler .accept (ingestDocumentWrappers );
211
254
}
212
- }
255
+
256
+ });
257
+ } catch (Exception e ) {
258
+ updateWithExceptions (ingestDocumentWrappers , e );
213
259
handler .accept (ingestDocumentWrappers );
214
- });
260
+ }
215
261
}
216
262
217
263
private Tuple <List <String >, Map <Integer , Integer >> sortByLengthAndReturnOriginalOrder (List <String > inferenceList ) {
@@ -238,7 +284,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
238
284
return sortedResults ;
239
285
}
240
286
241
- private List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
287
+ protected List <String > constructInferenceTexts (List <DataForInference > dataForInferences ) {
242
288
List <String > inferenceTexts = new ArrayList <>();
243
289
for (DataForInference dataForInference : dataForInferences ) {
244
290
if (dataForInference .getIngestDocumentWrapper ().getException () != null
@@ -250,7 +296,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
250
296
return inferenceTexts ;
251
297
}
252
298
253
- private List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
299
+ protected List <DataForInference > getDataForInference (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
254
300
List <DataForInference > dataForInferences = new ArrayList <>();
255
301
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
256
302
Map <String , Object > processMap = null ;
@@ -272,7 +318,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
272
318
273
319
@ Getter
274
320
@ AllArgsConstructor
275
- private static class DataForInference {
321
+ protected static class DataForInference {
276
322
private final IngestDocumentWrapper ingestDocumentWrapper ;
277
323
private final Map <String , Object > processMap ;
278
324
private final List <String > inferenceList ;
@@ -415,6 +461,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
415
461
nlpResult .forEach (ingestDocument ::setFieldValue );
416
462
}
417
463
464
+ /**
465
+ * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
466
+ * @param ingestDocumentWrappers, list of ingest documents
467
+ * */
468
+ protected MultiGetRequest buildMultiGetRequest (List <IngestDocumentWrapper > ingestDocumentWrappers ) {
469
+ MultiGetRequest multiGetRequest = new MultiGetRequest ();
470
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
471
+ Object index = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ().get (INDEX_FIELD );
472
+ Object id = ingestDocumentWrapper .getIngestDocument ().getSourceAndMetadata ().get (ID_FIELD );
473
+ if (Objects .nonNull (index ) && Objects .nonNull (id )) {
474
+ multiGetRequest .add (index .toString (), id .toString ());
475
+ }
476
+ }
477
+ return multiGetRequest ;
478
+ }
479
+
480
+ /**
481
+ * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
482
+ * @param multiGetItemResponses, array of responses from Multi Get Request
483
+ * */
484
+ protected Map <String , Map <String , Object >> createDocumentMap (MultiGetItemResponse [] multiGetItemResponses ) {
485
+ Map <String , Map <String , Object >> existingDocuments = new HashMap <>();
486
+ for (MultiGetItemResponse item : multiGetItemResponses ) {
487
+ String id = item .getId ();
488
+ Map <String , Object > existingDocument = item .getResponse ().getSourceAsMap ();
489
+ existingDocuments .put (id , existingDocument );
490
+ }
491
+ return existingDocuments ;
492
+ }
493
+
418
494
@ SuppressWarnings ({ "unchecked" })
419
495
@ VisibleForTesting
420
496
Map <String , Object > buildNLPResult (Map <String , Object > processorMap , List <?> results , Map <String , Object > sourceAndMetadataMap ) {
@@ -504,6 +580,13 @@ private void processMapEntryValue(
504
580
}
505
581
}
506
582
583
+ // This method updates each ingestDocument with exceptions
584
+ protected void updateWithExceptions (List <IngestDocumentWrapper > ingestDocumentWrappers , Exception e ) {
585
+ for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers ) {
586
+ ingestDocumentWrapper .update (ingestDocumentWrapper .getIngestDocument (), e );
587
+ }
588
+ }
589
+
507
590
private void processMapEntryValue (
508
591
List <?> results ,
509
592
IndexWrapper indexWrapper ,
@@ -582,7 +665,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
582
665
List <Map <String , Object >> keyToResult = new ArrayList <>();
583
666
sourceValue .stream ()
584
667
.filter (Objects ::nonNull ) // explicit null check is required since sourceValue can contain null values in cases where
585
- // sourceValue has been filtered
668
+ // sourceValue has been filtered
586
669
.forEachOrdered (x -> keyToResult .add (ImmutableMap .of (listTypeNestedMapKey , results .get (indexWrapper .index ++))));
587
670
return keyToResult ;
588
671
}
0 commit comments