21
21
import java .util .List ;
22
22
import java .util .Map ;
23
23
import java .util .Set ;
24
+ import java .util .concurrent .atomic .AtomicBoolean ;
24
25
25
26
import org .apache .logging .log4j .LogManager ;
26
27
import org .apache .logging .log4j .Logger ;
28
+ import org .opensearch .OpenSearchStatusException ;
27
29
import org .opensearch .action .ActionRequest ;
28
30
import org .opensearch .action .search .SearchRequest ;
29
31
import org .opensearch .action .search .SearchResponse ;
33
35
import org .opensearch .common .xcontent .XContentHelper ;
34
36
import org .opensearch .core .action .ActionListener ;
35
37
import org .opensearch .core .common .bytes .BytesReference ;
38
+ import org .opensearch .core .rest .RestStatus ;
36
39
import org .opensearch .core .xcontent .MediaType ;
37
40
import org .opensearch .core .xcontent .NamedXContentRegistry ;
38
41
import org .opensearch .core .xcontent .XContentBuilder ;
39
42
import org .opensearch .ingest .ConfigurationUtils ;
40
43
import org .opensearch .ml .common .FunctionName ;
44
+ import org .opensearch .ml .common .exception .MLResourceNotFoundException ;
41
45
import org .opensearch .ml .common .output .MLOutput ;
42
46
import org .opensearch .ml .common .transport .MLTaskResponse ;
43
47
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
44
48
import org .opensearch .ml .common .utils .StringUtils ;
45
49
import org .opensearch .ml .utils .MapUtils ;
50
+ import org .opensearch .ml .utils .SearchResponseUtil ;
46
51
import org .opensearch .search .SearchHit ;
47
52
import org .opensearch .search .pipeline .AbstractProcessor ;
48
53
import org .opensearch .search .pipeline .PipelineProcessingContext ;
@@ -125,9 +130,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
125
130
/**
126
131
* Processes the search response asynchronously by rewriting the documents with the inference results.
127
132
*
128
- * @param request the search request
129
- * @param response the search response
130
- * @param responseContext the pipeline processing context
133
+ * By default, it processes multiple documents in a single prediction through the rewriteResponseDocuments method.
134
+ * However, when processing one document per inference, it separates the N-hits search response into N one-hit search responses,
135
+ * executes the same rewriteResponseDocument method for each one-hit search response,
136
+ * and after receiving N one-hit search responses with inference results,
137
+ * it combines them back into a single N-hits search response.
138
+ *
139
+ * @param request the search request
140
+ * @param response the search response
141
+ * @param responseContext the pipeline processing context
131
142
* @param responseListener the listener to be notified when the response is processed
132
143
*/
133
144
@ Override
@@ -144,20 +155,130 @@ public void processResponseAsync(
144
155
responseListener .onResponse (response );
145
156
return ;
146
157
}
147
- rewriteResponseDocuments (response , responseListener );
158
+
159
+ // if many to one, run rewriteResponseDocuments
160
+ if (!oneToOne ) {
161
+ rewriteResponseDocuments (response , responseListener );
162
+ } else {
163
+ // if one to one, make one hit search response and run rewriteResponseDocuments
164
+ GroupedActionListener <SearchResponse > combineResponseListener = getCombineResponseGroupedActionListener (
165
+ response ,
166
+ responseListener ,
167
+ hits
168
+ );
169
+ AtomicBoolean isOneHitListenerFailed = new AtomicBoolean (false );
170
+ ;
171
+ for (SearchHit hit : hits ) {
172
+ SearchHit [] newHits = new SearchHit [1 ];
173
+ newHits [0 ] = hit ;
174
+ SearchResponse oneHitResponse = SearchResponseUtil .replaceHits (newHits , response );
175
+ ActionListener <SearchResponse > oneHitListener = getOneHitListener (combineResponseListener , isOneHitListenerFailed );
176
+ rewriteResponseDocuments (oneHitResponse , oneHitListener );
177
+ // if any OneHitListener failure, try stop the rest of the predictions
178
+ if (isOneHitListenerFailed .get ()) {
179
+ break ;
180
+ }
181
+ }
182
+ }
183
+
148
184
} catch (Exception e ) {
149
185
if (ignoreFailure ) {
150
186
responseListener .onResponse (response );
151
187
} else {
152
188
responseListener .onFailure (e );
189
+ if (e instanceof OpenSearchStatusException ) {
190
+ responseListener
191
+ .onFailure (
192
+ new OpenSearchStatusException (
193
+ "Failed to process response: " + e .getMessage (),
194
+ RestStatus .fromCode (((OpenSearchStatusException ) e ).status ().getStatus ())
195
+ )
196
+ );
197
+ } else if (e instanceof MLResourceNotFoundException ) {
198
+ responseListener
199
+ .onFailure (new OpenSearchStatusException ("Failed to process response: " + e .getMessage (), RestStatus .NOT_FOUND ));
200
+ } else {
201
+ responseListener .onFailure (e );
202
+ }
153
203
}
154
204
}
155
205
}
156
206
207
+ /**
208
+ * Creates an ActionListener for a single SearchResponse that delegates its
209
+ * onResponse and onFailure callbacks to a GroupedActionListener.
210
+ *
211
+ * @param combineResponseListener The GroupedActionListener to which the
212
+ * onResponse and onFailure callbacks will be
213
+ * delegated.
214
+ * @param isOneHitListenerFailed
215
+ * @return An ActionListener that delegates its callbacks to the provided
216
+ * GroupedActionListener.
217
+ */
218
+ private static ActionListener <SearchResponse > getOneHitListener (
219
+ GroupedActionListener <SearchResponse > combineResponseListener ,
220
+ AtomicBoolean isOneHitListenerFailed
221
+ ) {
222
+ ActionListener <SearchResponse > oneHitListener = new ActionListener <>() {
223
+ @ Override
224
+ public void onResponse (SearchResponse response ) {
225
+ combineResponseListener .onResponse (response );
226
+ }
227
+
228
+ @ Override
229
+ public void onFailure (Exception e ) {
230
+ // if any OneHitListener failure, try stop the rest of the predictions and return
231
+ isOneHitListenerFailed .compareAndSet (false , true );
232
+ combineResponseListener .onFailure (e );
233
+ }
234
+ };
235
+ return oneHitListener ;
236
+ }
237
+
238
+ /**
239
+ * Creates a GroupedActionListener that combines the SearchResponses from individual hits
240
+ * and constructs a new SearchResponse with the combined hits.
241
+ *
242
+ * @param response The original SearchResponse containing the hits to be processed.
243
+ * @param responseListener The ActionListener to be notified with the combined SearchResponse.
244
+ * @param hits The array of SearchHits to be processed.
245
+ * @return A GroupedActionListener that combines the SearchResponses and constructs a new SearchResponse.
246
+ */
247
+ private GroupedActionListener <SearchResponse > getCombineResponseGroupedActionListener (
248
+ SearchResponse response ,
249
+ ActionListener <SearchResponse > responseListener ,
250
+ SearchHit [] hits
251
+ ) {
252
+ GroupedActionListener <SearchResponse > combineResponseListener = new GroupedActionListener <>(new ActionListener <>() {
253
+ @ Override
254
+ public void onResponse (Collection <SearchResponse > responseMapCollection ) {
255
+ SearchHit [] combinedHits = new SearchHit [hits .length ];
256
+ int i = 0 ;
257
+ for (SearchResponse OneHitResponseAfterInference : responseMapCollection ) {
258
+ SearchHit [] hitsAfterInference = OneHitResponseAfterInference .getHits ().getHits ();
259
+ combinedHits [i ] = hitsAfterInference [0 ];
260
+ i ++;
261
+ }
262
+ SearchResponse oneToOneInferenceSearchResponse = SearchResponseUtil .replaceHits (combinedHits , response );
263
+ responseListener .onResponse (oneToOneInferenceSearchResponse );
264
+ }
265
+
266
+ @ Override
267
+ public void onFailure (Exception e ) {
268
+ if (ignoreFailure ) {
269
+ responseListener .onResponse (response );
270
+ } else {
271
+ responseListener .onFailure (e );
272
+ }
273
+ }
274
+ }, hits .length );
275
+ return combineResponseListener ;
276
+ }
277
+
157
278
/**
158
279
* Rewrite the documents in the search response with the inference results.
159
280
*
160
- * @param response the search response
281
+ * @param response the search response
161
282
* @param responseListener the listener to be notified when the response is processed
162
283
* @throws IOException if an I/O error occurs during the rewriting process
163
284
*/
@@ -168,27 +289,23 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
168
289
169
290
// hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction
170
291
Map <Integer , Integer > hitCountInPredictions = new HashMap <>();
171
- if (!oneToOne ) {
172
- ActionListener <Map <Integer , MLOutput >> rewriteResponseListener = createRewriteResponseListenerManyToOne (
173
- response ,
174
- responseListener ,
175
- processInputMap ,
176
- processOutputMap ,
177
- hitCountInPredictions
178
- );
179
292
180
- GroupedActionListener <Map <Integer , MLOutput >> batchPredictionListener = createBatchPredictionListenerManyToOne (
181
- rewriteResponseListener ,
182
- inputMapSize
183
- );
184
- SearchHit [] hits = response .getHits ().getHits ();
185
- for (int inputMapIndex = 0 ; inputMapIndex < max (inputMapSize , 1 ); inputMapIndex ++) {
186
- processPredictionsManyToOne (hits , processInputMap , inputMapIndex , batchPredictionListener , hitCountInPredictions );
187
- }
188
- } else {
189
- responseListener .onFailure (new IllegalArgumentException ("one to one prediction is not supported yet." ));
190
- }
293
+ ActionListener <Map <Integer , MLOutput >> rewriteResponseListener = createRewriteResponseListener (
294
+ response ,
295
+ responseListener ,
296
+ processInputMap ,
297
+ processOutputMap ,
298
+ hitCountInPredictions
299
+ );
191
300
301
+ GroupedActionListener <Map <Integer , MLOutput >> batchPredictionListener = createBatchPredictionListener (
302
+ rewriteResponseListener ,
303
+ inputMapSize
304
+ );
305
+ SearchHit [] hits = response .getHits ().getHits ();
306
+ for (int inputMapIndex = 0 ; inputMapIndex < max (inputMapSize , 1 ); inputMapIndex ++) {
307
+ processPredictions (hits , processInputMap , inputMapIndex , batchPredictionListener , hitCountInPredictions );
308
+ }
192
309
}
193
310
194
311
/**
@@ -201,7 +318,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
201
318
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
202
319
* @throws IOException if an I/O error occurs during the prediction process
203
320
*/
204
- private void processPredictionsManyToOne (
321
+ private void processPredictions (
205
322
SearchHit [] hits ,
206
323
List <Map <String , String >> processInputMap ,
207
324
int inputMapIndex ,
@@ -242,7 +359,7 @@ private void processPredictionsManyToOne(
242
359
Object documentValue = JsonPath .using (configuration ).parse (documentJson ).read (documentFieldName );
243
360
if (documentValue != null ) {
244
361
// when not existed in the map, add into the modelInputParameters map
245
- updateModelInputParametersManyToOne (modelInputParameters , modelInputFieldName , documentValue );
362
+ updateModelInputParameters (modelInputParameters , modelInputFieldName , documentValue );
246
363
}
247
364
}
248
365
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
@@ -263,8 +380,7 @@ private void processPredictionsManyToOne(
263
380
Object documentValue = entry .getValue ();
264
381
265
382
// when not existed in the map, add into the modelInputParameters map
266
- updateModelInputParametersManyToOne (modelInputParameters , modelInputFieldName , documentValue );
267
-
383
+ updateModelInputParameters (modelInputParameters , modelInputFieldName , documentValue );
268
384
}
269
385
}
270
386
}
@@ -306,18 +422,28 @@ public void onFailure(Exception e) {
306
422
});
307
423
}
308
424
309
- private void updateModelInputParametersManyToOne (
310
- Map <String , Object > modelInputParameters ,
311
- String modelInputFieldName ,
312
- Object documentValue
313
- ) {
314
- if (!modelInputParameters .containsKey (modelInputFieldName )) {
315
- List <Object > documentValueList = new ArrayList <>();
316
- documentValueList .add (documentValue );
317
- modelInputParameters .put (modelInputFieldName , documentValueList );
425
+ /**
426
+ * Updates the model input parameters map with the given document value.
427
+ * If the setting is one-to-one,
428
+ * simply put the document value in the map
429
+ * If the setting is many-to-one,
430
+ * create a new list and add the document value
431
+ * @param modelInputParameters The map containing the model input parameters.
432
+ * @param modelInputFieldName The name of the model input field.
433
+ * @param documentValue The value from the document that needs to be added to the model input parameters.
434
+ */
435
+ private void updateModelInputParameters (Map <String , Object > modelInputParameters , String modelInputFieldName , Object documentValue ) {
436
+ if (!this .oneToOne ) {
437
+ if (!modelInputParameters .containsKey (modelInputFieldName )) {
438
+ List <Object > documentValueList = new ArrayList <>();
439
+ documentValueList .add (documentValue );
440
+ modelInputParameters .put (modelInputFieldName , documentValueList );
441
+ } else {
442
+ List <Object > valueList = ((List ) modelInputParameters .get (modelInputFieldName ));
443
+ valueList .add (documentValue );
444
+ }
318
445
} else {
319
- List <Object > valueList = ((List ) modelInputParameters .get (modelInputFieldName ));
320
- valueList .add (documentValue );
446
+ modelInputParameters .put (modelInputFieldName , documentValue );
321
447
}
322
448
}
323
449
@@ -328,7 +454,7 @@ private void updateModelInputParametersManyToOne(
328
454
* @param inputMapSize the size of the input map
329
455
* @return a grouped action listener for batch predictions
330
456
*/
331
- private GroupedActionListener <Map <Integer , MLOutput >> createBatchPredictionListenerManyToOne (
457
+ private GroupedActionListener <Map <Integer , MLOutput >> createBatchPredictionListener (
332
458
ActionListener <Map <Integer , MLOutput >> rewriteResponseListener ,
333
459
int inputMapSize
334
460
) {
@@ -353,14 +479,14 @@ public void onFailure(Exception e) {
353
479
/**
354
480
* Creates an action listener for rewriting the response with the inference results.
355
481
*
356
- * @param response the search response
357
- * @param responseListener the listener to be notified when the response is processed
358
- * @param processInputMap the list of input mappings
359
- * @param processOutputMap the list of output mappings
360
- * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
482
+ * @param response the search response
483
+ * @param responseListener the listener to be notified when the response is processed
484
+ * @param processInputMap the list of input mappings
485
+ * @param processOutputMap the list of output mappings
486
+ * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
361
487
* @return an action listener for rewriting the response with the inference results
362
488
*/
363
- private ActionListener <Map <Integer , MLOutput >> createRewriteResponseListenerManyToOne (
489
+ private ActionListener <Map <Integer , MLOutput >> createRewriteResponseListener (
364
490
SearchResponse response ,
365
491
ActionListener <SearchResponse > responseListener ,
366
492
List <Map <String , String >> processInputMap ,
@@ -392,7 +518,7 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
392
518
Map <String , String > outputMapping = getDefaultOutputMapping (mappingIndex , processOutputMap );
393
519
394
520
boolean isModelInputMissing = false ;
395
- if (processInputMap != null ) {
521
+ if (processInputMap != null && ! processInputMap . isEmpty () ) {
396
522
isModelInputMissing = checkIsModelInputMissing (document , inputMapping );
397
523
}
398
524
if (!isModelInputMissing ) {
@@ -499,10 +625,10 @@ private boolean checkIsModelInputMissing(Map<String, Object> document, Map<Strin
499
625
* <p>If the processOutputMap is not null and not empty, the mapping at the specified mappingIndex
500
626
* is returned.
501
627
*
502
- * @param mappingIndex the index of the mapping to retrieve from the processOutputMap
628
+ * @param mappingIndex the index of the mapping to retrieve from the processOutputMap
503
629
* @param processOutputMap the list of output mappings, can be null or empty
504
630
* @return a Map containing the output mapping, either the default mapping or the mapping at the
505
- * specified index
631
+ * specified index
506
632
*/
507
633
private static Map <String , String > getDefaultOutputMapping (Integer mappingIndex , List <Map <String , String >> processOutputMap ) {
508
634
Map <String , String > outputMapping ;
@@ -524,11 +650,11 @@ private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex,
524
650
* <p>If the processInputMap is not null and not empty, the mapping at the specified mappingIndex
525
651
* is returned.
526
652
*
527
- * @param sourceAsMap the source map containing the input data
528
- * @param mappingIndex the index of the mapping to retrieve from the processInputMap
653
+ * @param sourceAsMap the source map containing the input data
654
+ * @param mappingIndex the index of the mapping to retrieve from the processInputMap
529
655
* @param processInputMap the list of input mappings, can be null or empty
530
656
* @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or
531
- * the mapping at the specified index
657
+ * the mapping at the specified index
532
658
*/
533
659
private static Map <String , String > getDefaultInputMapping (
534
660
Map <String , Object > sourceAsMap ,
0 commit comments