Skip to content

Commit 71ff4e2

Browse files
Support one_to_one in ML Inference Search Response Processor (opensearch-project#2801) (opensearch-project#2843)
* add one document to one prediction support Signed-off-by: Mingshi Liu <mingshl@amazon.com> * rephrase javadoc Signed-off-by: Mingshi Liu <mingshl@amazon.com> * use OpenSearchStatusException in error handling Signed-off-by: Mingshi Liu <mingshl@amazon.com> * fix message Signed-off-by: Mingshi Liu <mingshl@amazon.com> * add more tests Signed-off-by: Mingshi Liu <mingshl@amazon.com> * handle different exceptions properly Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com> (cherry picked from commit 2a33c65) Co-authored-by: Mingshi Liu <mingshl@amazon.com>
1 parent 45a0412 commit 71ff4e2

File tree

3 files changed

+2211
-105
lines changed

3 files changed

+2211
-105
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java

+178-52
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.Set;
24+
import java.util.concurrent.atomic.AtomicBoolean;
2425

2526
import org.apache.logging.log4j.LogManager;
2627
import org.apache.logging.log4j.Logger;
28+
import org.opensearch.OpenSearchStatusException;
2729
import org.opensearch.action.ActionRequest;
2830
import org.opensearch.action.search.SearchRequest;
2931
import org.opensearch.action.search.SearchResponse;
@@ -33,16 +35,19 @@
3335
import org.opensearch.common.xcontent.XContentHelper;
3436
import org.opensearch.core.action.ActionListener;
3537
import org.opensearch.core.common.bytes.BytesReference;
38+
import org.opensearch.core.rest.RestStatus;
3639
import org.opensearch.core.xcontent.MediaType;
3740
import org.opensearch.core.xcontent.NamedXContentRegistry;
3841
import org.opensearch.core.xcontent.XContentBuilder;
3942
import org.opensearch.ingest.ConfigurationUtils;
4043
import org.opensearch.ml.common.FunctionName;
44+
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
4145
import org.opensearch.ml.common.output.MLOutput;
4246
import org.opensearch.ml.common.transport.MLTaskResponse;
4347
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
4448
import org.opensearch.ml.common.utils.StringUtils;
4549
import org.opensearch.ml.utils.MapUtils;
50+
import org.opensearch.ml.utils.SearchResponseUtil;
4651
import org.opensearch.search.SearchHit;
4752
import org.opensearch.search.pipeline.AbstractProcessor;
4853
import org.opensearch.search.pipeline.PipelineProcessingContext;
@@ -125,9 +130,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
125130
/**
126131
* Processes the search response asynchronously by rewriting the documents with the inference results.
127132
*
128-
* @param request the search request
129-
* @param response the search response
130-
* @param responseContext the pipeline processing context
133+
* By default, it processes multiple documents in a single prediction through the rewriteResponseDocuments method.
134+
* However, when processing one document per inference, it separates the N-hits search response into N one-hit search responses,
135+
* executes the same rewriteResponseDocuments method for each one-hit search response,
136+
* and after receiving N one-hit search responses with inference results,
137+
* it combines them back into a single N-hits search response.
138+
*
139+
* @param request the search request
140+
* @param response the search response
141+
* @param responseContext the pipeline processing context
131142
* @param responseListener the listener to be notified when the response is processed
132143
*/
133144
@Override
@@ -144,20 +155,130 @@ public void processResponseAsync(
144155
responseListener.onResponse(response);
145156
return;
146157
}
147-
rewriteResponseDocuments(response, responseListener);
158+
159+
// if many to one, run rewriteResponseDocuments
160+
if (!oneToOne) {
161+
rewriteResponseDocuments(response, responseListener);
162+
} else {
163+
// if one to one, make one hit search response and run rewriteResponseDocuments
164+
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
165+
response,
166+
responseListener,
167+
hits
168+
);
169+
AtomicBoolean isOneHitListenerFailed = new AtomicBoolean(false);
170+
;
171+
for (SearchHit hit : hits) {
172+
SearchHit[] newHits = new SearchHit[1];
173+
newHits[0] = hit;
174+
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
175+
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
176+
rewriteResponseDocuments(oneHitResponse, oneHitListener);
177+
// if any one-hit listener has already failed, stop dispatching the remaining predictions
178+
if (isOneHitListenerFailed.get()) {
179+
break;
180+
}
181+
}
182+
}
183+
148184
} catch (Exception e) {
149185
if (ignoreFailure) {
150186
responseListener.onResponse(response);
151187
} else {
152188
responseListener.onFailure(e);
189+
if (e instanceof OpenSearchStatusException) {
190+
responseListener
191+
.onFailure(
192+
new OpenSearchStatusException(
193+
"Failed to process response: " + e.getMessage(),
194+
RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus())
195+
)
196+
);
197+
} else if (e instanceof MLResourceNotFoundException) {
198+
responseListener
199+
.onFailure(new OpenSearchStatusException("Failed to process response: " + e.getMessage(), RestStatus.NOT_FOUND));
200+
} else {
201+
responseListener.onFailure(e);
202+
}
153203
}
154204
}
155205
}
156206

207+
/**
208+
* Creates an ActionListener for a single SearchResponse that delegates its
209+
* onResponse and onFailure callbacks to a GroupedActionListener.
210+
*
211+
* @param combineResponseListener The GroupedActionListener to which the
212+
* onResponse and onFailure callbacks will be
213+
* delegated.
214+
* @param isOneHitListenerFailed
215+
* @return An ActionListener that delegates its callbacks to the provided
216+
* GroupedActionListener.
217+
*/
218+
private static ActionListener<SearchResponse> getOneHitListener(
219+
GroupedActionListener<SearchResponse> combineResponseListener,
220+
AtomicBoolean isOneHitListenerFailed
221+
) {
222+
ActionListener<SearchResponse> oneHitListener = new ActionListener<>() {
223+
@Override
224+
public void onResponse(SearchResponse response) {
225+
combineResponseListener.onResponse(response);
226+
}
227+
228+
@Override
229+
public void onFailure(Exception e) {
230+
// if any OneHitListener failure, try stop the rest of the predictions and return
231+
isOneHitListenerFailed.compareAndSet(false, true);
232+
combineResponseListener.onFailure(e);
233+
}
234+
};
235+
return oneHitListener;
236+
}
237+
238+
/**
239+
* Creates a GroupedActionListener that combines the SearchResponses from individual hits
240+
* and constructs a new SearchResponse with the combined hits.
241+
*
242+
* @param response The original SearchResponse containing the hits to be processed.
243+
* @param responseListener The ActionListener to be notified with the combined SearchResponse.
244+
* @param hits The array of SearchHits to be processed.
245+
* @return A GroupedActionListener that combines the SearchResponses and constructs a new SearchResponse.
246+
*/
247+
private GroupedActionListener<SearchResponse> getCombineResponseGroupedActionListener(
248+
SearchResponse response,
249+
ActionListener<SearchResponse> responseListener,
250+
SearchHit[] hits
251+
) {
252+
GroupedActionListener<SearchResponse> combineResponseListener = new GroupedActionListener<>(new ActionListener<>() {
253+
@Override
254+
public void onResponse(Collection<SearchResponse> responseMapCollection) {
255+
SearchHit[] combinedHits = new SearchHit[hits.length];
256+
int i = 0;
257+
for (SearchResponse OneHitResponseAfterInference : responseMapCollection) {
258+
SearchHit[] hitsAfterInference = OneHitResponseAfterInference.getHits().getHits();
259+
combinedHits[i] = hitsAfterInference[0];
260+
i++;
261+
}
262+
SearchResponse oneToOneInferenceSearchResponse = SearchResponseUtil.replaceHits(combinedHits, response);
263+
responseListener.onResponse(oneToOneInferenceSearchResponse);
264+
}
265+
266+
@Override
267+
public void onFailure(Exception e) {
268+
if (ignoreFailure) {
269+
responseListener.onResponse(response);
270+
} else {
271+
responseListener.onFailure(e);
272+
}
273+
}
274+
}, hits.length);
275+
return combineResponseListener;
276+
}
277+
157278
/**
158279
* Rewrite the documents in the search response with the inference results.
159280
*
160-
* @param response the search response
281+
* @param response the search response
161282
* @param responseListener the listener to be notified when the response is processed
162283
* @throws IOException if an I/O error occurs during the rewriting process
163284
*/
@@ -168,27 +289,23 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
168289

169290
// hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction
170291
Map<Integer, Integer> hitCountInPredictions = new HashMap<>();
171-
if (!oneToOne) {
172-
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListenerManyToOne(
173-
response,
174-
responseListener,
175-
processInputMap,
176-
processOutputMap,
177-
hitCountInPredictions
178-
);
179292

180-
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListenerManyToOne(
181-
rewriteResponseListener,
182-
inputMapSize
183-
);
184-
SearchHit[] hits = response.getHits().getHits();
185-
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
186-
processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
187-
}
188-
} else {
189-
responseListener.onFailure(new IllegalArgumentException("one to one prediction is not supported yet."));
190-
}
293+
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListener(
294+
response,
295+
responseListener,
296+
processInputMap,
297+
processOutputMap,
298+
hitCountInPredictions
299+
);
191300

301+
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
302+
rewriteResponseListener,
303+
inputMapSize
304+
);
305+
SearchHit[] hits = response.getHits().getHits();
306+
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
307+
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
308+
}
192309
}
193310

194311
/**
@@ -201,7 +318,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
201318
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
202319
* @throws IOException if an I/O error occurs during the prediction process
203320
*/
204-
private void processPredictionsManyToOne(
321+
private void processPredictions(
205322
SearchHit[] hits,
206323
List<Map<String, String>> processInputMap,
207324
int inputMapIndex,
@@ -242,7 +359,7 @@ private void processPredictionsManyToOne(
242359
Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName);
243360
if (documentValue != null) {
244361
// when not existed in the map, add into the modelInputParameters map
245-
updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue);
362+
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
246363
}
247364
}
248365
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
@@ -263,8 +380,7 @@ private void processPredictionsManyToOne(
263380
Object documentValue = entry.getValue();
264381

265382
// when not existed in the map, add into the modelInputParameters map
266-
updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue);
267-
383+
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
268384
}
269385
}
270386
}
@@ -306,18 +422,28 @@ public void onFailure(Exception e) {
306422
});
307423
}
308424

309-
private void updateModelInputParametersManyToOne(
310-
Map<String, Object> modelInputParameters,
311-
String modelInputFieldName,
312-
Object documentValue
313-
) {
314-
if (!modelInputParameters.containsKey(modelInputFieldName)) {
315-
List<Object> documentValueList = new ArrayList<>();
316-
documentValueList.add(documentValue);
317-
modelInputParameters.put(modelInputFieldName, documentValueList);
425+
/**
426+
* Updates the model input parameters map with the given document value.
427+
* If the setting is one-to-one,
428+
* simply put the document value in the map
429+
* If the setting is many-to-one,
430+
* create a new list and add the document value
431+
* @param modelInputParameters The map containing the model input parameters.
432+
* @param modelInputFieldName The name of the model input field.
433+
* @param documentValue The value from the document that needs to be added to the model input parameters.
434+
*/
435+
private void updateModelInputParameters(Map<String, Object> modelInputParameters, String modelInputFieldName, Object documentValue) {
436+
if (!this.oneToOne) {
437+
if (!modelInputParameters.containsKey(modelInputFieldName)) {
438+
List<Object> documentValueList = new ArrayList<>();
439+
documentValueList.add(documentValue);
440+
modelInputParameters.put(modelInputFieldName, documentValueList);
441+
} else {
442+
List<Object> valueList = ((List) modelInputParameters.get(modelInputFieldName));
443+
valueList.add(documentValue);
444+
}
318445
} else {
319-
List<Object> valueList = ((List) modelInputParameters.get(modelInputFieldName));
320-
valueList.add(documentValue);
446+
modelInputParameters.put(modelInputFieldName, documentValue);
321447
}
322448
}
323449

@@ -328,7 +454,7 @@ private void updateModelInputParametersManyToOne(
328454
* @param inputMapSize the size of the input map
329455
* @return a grouped action listener for batch predictions
330456
*/
331-
private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListenerManyToOne(
457+
private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListener(
332458
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener,
333459
int inputMapSize
334460
) {
@@ -353,14 +479,14 @@ public void onFailure(Exception e) {
353479
/**
354480
* Creates an action listener for rewriting the response with the inference results.
355481
*
356-
* @param response the search response
357-
* @param responseListener the listener to be notified when the response is processed
358-
* @param processInputMap the list of input mappings
359-
* @param processOutputMap the list of output mappings
360-
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
482+
* @param response the search response
483+
* @param responseListener the listener to be notified when the response is processed
484+
* @param processInputMap the list of input mappings
485+
* @param processOutputMap the list of output mappings
486+
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
361487
* @return an action listener for rewriting the response with the inference results
362488
*/
363-
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListenerManyToOne(
489+
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListener(
364490
SearchResponse response,
365491
ActionListener<SearchResponse> responseListener,
366492
List<Map<String, String>> processInputMap,
@@ -392,7 +518,7 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
392518
Map<String, String> outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap);
393519

394520
boolean isModelInputMissing = false;
395-
if (processInputMap != null) {
521+
if (processInputMap != null && !processInputMap.isEmpty()) {
396522
isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
397523
}
398524
if (!isModelInputMissing) {
@@ -499,10 +625,10 @@ private boolean checkIsModelInputMissing(Map<String, Object> document, Map<Strin
499625
* <p>If the processOutputMap is not null and not empty, the mapping at the specified mappingIndex
500626
* is returned.
501627
*
502-
* @param mappingIndex the index of the mapping to retrieve from the processOutputMap
628+
* @param mappingIndex the index of the mapping to retrieve from the processOutputMap
503629
* @param processOutputMap the list of output mappings, can be null or empty
504630
* @return a Map containing the output mapping, either the default mapping or the mapping at the
505-
* specified index
631+
* specified index
506632
*/
507633
private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex, List<Map<String, String>> processOutputMap) {
508634
Map<String, String> outputMapping;
@@ -524,11 +650,11 @@ private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex,
524650
* <p>If the processInputMap is not null and not empty, the mapping at the specified mappingIndex
525651
* is returned.
526652
*
527-
* @param sourceAsMap the source map containing the input data
528-
* @param mappingIndex the index of the mapping to retrieve from the processInputMap
653+
* @param sourceAsMap the source map containing the input data
654+
* @param mappingIndex the index of the mapping to retrieve from the processInputMap
529655
* @param processInputMap the list of input mappings, can be null or empty
530656
* @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or
531-
* the mapping at the specified index
657+
* the mapping at the specified index
532658
*/
533659
private static Map<String, String> getDefaultInputMapping(
534660
Map<String, Object> sourceAsMap,

0 commit comments

Comments
 (0)