Commit ac037de

implement batch document update scenario in text embedding processor

Signed-off-by: Will Hwang <sang7239@gmail.com>
1 parent: 1a6e58e

10 files changed: +485 −83 lines

CHANGELOG.md (+1 −1)

```diff
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
 ### Features
-- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
+- Optimizing embedding generation in text embedding processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
 ### Enhancements
 - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
 - Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
```

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java (+122 −39)

```diff
@@ -26,6 +26,8 @@
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.action.get.MultiGetItemResponse;
+import org.opensearch.action.get.MultiGetRequest;
 import org.opensearch.common.collect.Tuple;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.util.CollectionUtils;
@@ -54,6 +56,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
 
     public static final String MODEL_ID_FIELD = "model_id";
     public static final String FIELD_MAP_FIELD = "field_map";
+    public static final String INDEX_FIELD = "_index";
+    public static final String ID_FIELD = "_id";
     private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
         if (v1 instanceof Collection && v2 instanceof Collection) {
             ((Collection) v1).addAll((Collection) v2);
@@ -169,49 +173,91 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
      */
     abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
 
+    /**
+     * This is the function which does actual inference work for subBatchExecute interface.
+     * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     */
     @Override
     public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
-        if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
-            handler.accept(Collections.emptyList());
-            return;
-        }
+        try {
+            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
+                handler.accept(Collections.emptyList());
+                return;
+            }
 
-        List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
-        List<String> inferenceList = constructInferenceTexts(dataForInferences);
-        if (inferenceList.isEmpty()) {
+            List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
+            List<String> inferenceList = constructInferenceTexts(dataForInferences);
+            if (inferenceList.isEmpty()) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+            doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
             handler.accept(ingestDocumentWrappers);
-            return;
         }
-        Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
-        inferenceList = sortedResult.v1();
-        Map<Integer, Integer> originalOrder = sortedResult.v2();
-        doBatchExecute(inferenceList, results -> {
-            int startIndex = 0;
-            results = restoreToOriginalOrder(results, originalOrder);
-            for (DataForInference dataForInference : dataForInferences) {
-                if (dataForInference.getIngestDocumentWrapper().getException() != null
-                    || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
-                    continue;
+    }
+
+    /**
+     * This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
+     * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
+     * @param inferenceList a list of String for inference.
+     * @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     */
+    protected void doSubBatchExecute(
+        List<IngestDocumentWrapper> ingestDocumentWrappers,
+        List<String> inferenceList,
+        List<DataForInference> dataForInferences,
+        Consumer<List<IngestDocumentWrapper>> handler
+    ) {
+        try {
+            Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
+            inferenceList = sortedResult.v1();
+            Map<Integer, Integer> originalOrder = sortedResult.v2();
+            doBatchExecute(inferenceList, results -> {
+                try {
+                    int startIndex = 0;
+                    results = restoreToOriginalOrder(results, originalOrder);
+                    for (DataForInference dataForInference : dataForInferences) {
+                        if (dataForInference.getIngestDocumentWrapper().getException() != null
+                            || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
+                            continue;
+                        }
+                        List<?> inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size());
+                        startIndex += dataForInference.getInferenceList().size();
+                        setVectorFieldsToDocument(
+                            dataForInference.getIngestDocumentWrapper().getIngestDocument(),
+                            dataForInference.getProcessMap(),
+                            inferenceResults
+                        );
+                    }
+                    handler.accept(ingestDocumentWrappers);
+                } catch (Exception e) {
+                    updateWithExceptions(ingestDocumentWrappers, e);
+                    handler.accept(ingestDocumentWrappers);
                 }
-                List<?> inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size());
-                startIndex += dataForInference.getInferenceList().size();
-                setVectorFieldsToDocument(
-                    dataForInference.getIngestDocumentWrapper().getIngestDocument(),
-                    dataForInference.getProcessMap(),
-                    inferenceResults
-                );
-            }
-            handler.accept(ingestDocumentWrappers);
-        }, exception -> {
-            for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
-                // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
-                // set exception to IngestDocumentWrapper which doesn't have exception before.
-                if (ingestDocumentWrapper.getException() == null) {
-                    ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
+            }, exception -> {
+                try {
+                    for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+                        // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
+                        // set exception to IngestDocumentWrapper which doesn't have exception before.
+                        if (ingestDocumentWrapper.getException() == null) {
+                            ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
+                        }
+                    }
+                    handler.accept(ingestDocumentWrappers);
+                } catch (Exception e) {
+                    updateWithExceptions(ingestDocumentWrappers, e);
+                    handler.accept(ingestDocumentWrappers);
                 }
-            }
+
+            });
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
             handler.accept(ingestDocumentWrappers);
-        });
+        }
     }
 
     private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
@@ -238,7 +284,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
         return sortedResults;
     }
 
-    private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
+    protected List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
         List<String> inferenceTexts = new ArrayList<>();
         for (DataForInference dataForInference : dataForInferences) {
             if (dataForInference.getIngestDocumentWrapper().getException() != null
@@ -250,7 +296,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
         return inferenceTexts;
     }
 
-    private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+    protected List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
        List<DataForInference> dataForInferences = new ArrayList<>();
         for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
             Map<String, Object> processMap = null;
@@ -272,7 +318,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
 
     @Getter
     @AllArgsConstructor
-    private static class DataForInference {
+    protected static class DataForInference {
         private final IngestDocumentWrapper ingestDocumentWrapper;
         private final Map<String, Object> processMap;
         private final List<String> inferenceList;
@@ -415,6 +461,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
         nlpResult.forEach(ingestDocument::setFieldValue);
     }
 
+    /**
+     * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
+     * @param ingestDocumentWrappers, list of ingest documents
+     * */
+    protected MultiGetRequest buildMultiGetRequest(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+        MultiGetRequest multiGetRequest = new MultiGetRequest();
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            Object index = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(INDEX_FIELD);
+            Object id = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(ID_FIELD);
+            if (Objects.nonNull(index) && Objects.nonNull(id)) {
+                multiGetRequest.add(index.toString(), id.toString());
+            }
+        }
+        return multiGetRequest;
+    }
+
+    /**
+     * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
+     * @param multiGetItemResponses, array of responses from Multi Get Request
+     * */
+    protected Map<String, Map<String, Object>> createDocumentMap(MultiGetItemResponse[] multiGetItemResponses) {
+        Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
+        for (MultiGetItemResponse item : multiGetItemResponses) {
+            String id = item.getId();
+            Map<String, Object> existingDocument = item.getResponse().getSourceAsMap();
+            existingDocuments.put(id, existingDocument);
+        }
+        return existingDocuments;
+    }
+
     @SuppressWarnings({ "unchecked" })
     @VisibleForTesting
     Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
@@ -504,6 +580,13 @@ private void processMapEntryValue(
         }
     }
 
+    // This method updates each ingestDocument with exceptions
+    protected void updateWithExceptions(List<IngestDocumentWrapper> ingestDocumentWrappers, Exception e) {
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
+        }
+    }
+
     private void processMapEntryValue(
         List<?> results,
         IndexWrapper indexWrapper,
@@ -582,7 +665,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
         List<Map<String, Object>> keyToResult = new ArrayList<>();
         sourceValue.stream()
             .filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
-            // sourceValue has been filtered
+                                      // sourceValue has been filtered
             .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
         return keyToResult;
     }
```
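For context, `doSubBatchExecute` keeps the pre-existing bookkeeping from `subBatchExecute`: inference texts are sorted by length before `doBatchExecute` (so similar-length texts batch together) and results are mapped back with `restoreToOriginalOrder`. The private implementations are not shown in this diff, so here is a minimal standalone sketch of that sort-and-restore pattern with hypothetical names, not the plugin's exact code:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class SortRestoreSketch {
    // Sort texts by length; record, for each original index, where its text landed.
    static Map<Integer, Integer> sortAndTrack(List<String> texts, List<String> sortedOut) {
        List<Integer> indices = IntStream.range(0, texts.size())
            .boxed()
            .sorted(Comparator.comparingInt(i -> texts.get(i).length()))
            .collect(Collectors.toList());
        Map<Integer, Integer> originalOrder = new HashMap<>();
        for (int sortedPos = 0; sortedPos < indices.size(); sortedPos++) {
            sortedOut.add(texts.get(indices.get(sortedPos)));
            originalOrder.put(indices.get(sortedPos), sortedPos); // original index -> sorted index
        }
        return originalOrder;
    }

    // Restore per-text results to the original document order.
    static <T> List<T> restore(List<T> results, Map<Integer, Integer> originalOrder) {
        List<T> restored = new ArrayList<>(results);
        originalOrder.forEach((original, sorted) -> restored.set(original, results.get(sorted)));
        return restored;
    }

    public static void main(String[] args) {
        List<String> texts = List.of("long long text", "hi", "medium one");
        List<String> sorted = new ArrayList<>();
        Map<Integer, Integer> order = sortAndTrack(texts, sorted);
        // Pretend the "model" returns each input's length as its result.
        List<Integer> results = sorted.stream().map(String::length).collect(Collectors.toList());
        System.out.println(restore(results, order)); // [14, 2, 10] -- back in original order
    }
}
```

Keeping the original-index map is what lets the new try/catch-wrapped result handler write each embedding back to the document that produced the text, no matter how the batch was reordered.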

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java (+92 −0)

```diff
@@ -4,6 +4,8 @@
  */
 package org.opensearch.neuralsearch.processor;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -13,10 +15,14 @@
 
 import org.opensearch.action.get.GetAction;
 import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.MultiGetAction;
+import org.opensearch.action.get.MultiGetItemResponse;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.env.Environment;
 import org.opensearch.ingest.IngestDocument;
+import org.opensearch.ingest.IngestDocumentWrapper;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 
 import lombok.extern.log4j.Log4j2;
@@ -106,4 +112,90 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
             ActionListener.wrap(handler::accept, onException)
         );
     }
+
+    @Override
+    public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
+        try {
+            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
+                handler.accept(Collections.emptyList());
+                return;
+            }
+            List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
+            if (dataForInferences.isEmpty()) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+            List<String> inferenceList = constructInferenceTexts(dataForInferences);
+            if (inferenceList.isEmpty()) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+            // skip existing flag is turned off. Call doSubBatchExecute without filtering
+            if (skipExisting == false) {
+                doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+                return;
+            }
+            // skipExisting flag is turned on, eligible inference texts in dataForInferences will be compared and filtered after embeddings
+            // are copied
+            openSearchClient.execute(
+                MultiGetAction.INSTANCE,
+                buildMultiGetRequest(ingestDocumentWrappers),
+                ActionListener.wrap(response -> {
+                    try {
+                        MultiGetItemResponse[] multiGetItemResponses = response.getResponses();
+                        if (multiGetItemResponses == null || multiGetItemResponses.length == 0) {
+                            doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+                            return;
+                        }
+                        // create a map of documents with key: doc_id and value: doc
+                        Map<String, Map<String, Object>> existingDocuments = createDocumentMap(multiGetItemResponses);
+                        List<DataForInference> filteredDataForInference = filterDataForInference(dataForInferences, existingDocuments);
+                        List<String> filteredInferenceList = constructInferenceTexts(filteredDataForInference);
+                        if (filteredInferenceList.isEmpty()) {
+                            handler.accept(ingestDocumentWrappers);
+                        } else {
+                            doSubBatchExecute(ingestDocumentWrappers, filteredInferenceList, filteredDataForInference, handler);
+                        }
+                    } catch (Exception e) {
+                        updateWithExceptions(ingestDocumentWrappers, e);
+                        handler.accept(ingestDocumentWrappers);
+                    }
+                }, e -> {
+                    // When exception is thrown for MultiGetAction, set exception to all ingestDocumentWrappers
+                    updateWithExceptions(ingestDocumentWrappers, e);
+                    handler.accept(ingestDocumentWrappers);
+                })
+            );
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
+            handler.accept(ingestDocumentWrappers);
+        }
+    }
+
+    // This is a helper method to filter the given list of dataForInferences by comparing its documents with existingDocuments with
+    // textEmbeddingInferenceFilter
+    private List<DataForInference> filterDataForInference(
+        List<DataForInference> dataForInferences,
+        Map<String, Map<String, Object>> existingDocuments
+    ) {
+        List<DataForInference> filteredDataForInference = new ArrayList<>();
+        for (DataForInference dataForInference : dataForInferences) {
+            IngestDocumentWrapper ingestDocumentWrapper = dataForInference.getIngestDocumentWrapper();
+            Map<String, Object> processMap = dataForInference.getProcessMap();
+            Map<String, Object> document = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata();
+            Object id = document.get(ID_FIELD);
+            // insert non-filtered dataForInference if existing document does not exist
+            if (Objects.isNull(id) || existingDocuments.containsKey(id.toString()) == false) {
+                filteredDataForInference.add(dataForInference);
+                continue;
+            }
+            // filter dataForInference when existing document exists
+            String docId = id.toString();
+            Map<String, Object> existingDocument = existingDocuments.get(docId);
+            Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(existingDocument, document, processMap);
+            List<String> filteredInferenceList = createInferenceList(filteredProcessMap);
+            filteredDataForInference.add(new DataForInference(ingestDocumentWrapper, filteredProcessMap, filteredInferenceList));
+        }
+        return filteredDataForInference;
+    }
 }
```
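The per-field comparison itself is delegated to `textEmbeddingInferenceFilter`, whose implementation is not part of this diff. As a rough, self-contained illustration of the skip-existing idea it supports (reuse an embedding when the source text is unchanged, re-infer otherwise), here is a sketch with hypothetical names, assuming a flat field map; the real filter presumably also handles nested maps and lists:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class SkipExistingSketch {
    // Keep a source->embedding field pair in the process map only if the source text
    // changed or the existing document has no embedding to copy; otherwise copy the
    // existing embedding into the incoming document and drop the pair from inference.
    static Map<String, Object> filter(Map<String, Object> existingDoc, Map<String, Object> newDoc, Map<String, Object> processMap) {
        Map<String, Object> stillNeedsInference = new HashMap<>();
        for (Map.Entry<String, Object> entry : processMap.entrySet()) {
            String sourceField = entry.getKey();
            String embeddingField = entry.getValue().toString();
            boolean textUnchanged = Objects.equals(existingDoc.get(sourceField), newDoc.get(sourceField));
            Object existingEmbedding = existingDoc.get(embeddingField);
            if (textUnchanged && existingEmbedding != null) {
                newDoc.put(embeddingField, existingEmbedding); // reuse: no model call needed
            } else {
                stillNeedsInference.put(sourceField, embeddingField);
            }
        }
        return stillNeedsInference;
    }

    public static void main(String[] args) {
        Map<String, Object> existing = new HashMap<>(Map.of("title", "hello", "title_embedding", new float[] { 0.1f, 0.2f }));
        Map<String, Object> incoming = new HashMap<>(Map.of("title", "hello"));
        Map<String, Object> remaining = filter(existing, incoming, Map.of("title", "title_embedding"));
        // prints: remaining=0, copied=true
        System.out.println("remaining=" + remaining.size() + ", copied=" + (incoming.get("title_embedding") != null));
    }
}
```

This is why `filterDataForInference` passes documents without an `_id`, or without a match in the multi-get results, straight through unfiltered: there is nothing to compare against, so every mapped field still needs inference.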
