
Commit afd1215

[FEATURE] Support batch ingestion in TextEmbeddingProcessor & SparseEncodingProcessor (#744)
* Support batch ingestion in TextEmbeddingProcessor & SparseEncodingProcessor
* Update Changelog
* Add UT and IT
* Add comments
* Sort texts by length before sending for inference
* Make consistent check for inferenceList

Signed-off-by: Liyun Xiu <xiliyun@amazon.com>
1 parent 0255bf0 commit afd1215

13 files changed (+642 -72 lines)

CHANGELOG.md (+1)
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

 ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x)
 ### Features
+- Support batchExecute in TextEmbeddingProcessor and SparseEncodingProcessor ([#743](https://github.com/opensearch-project/neural-search/issues/743))
 ### Enhancements
 - Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
 - Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java (+125)
@@ -5,20 +5,30 @@
 package org.opensearch.neuralsearch.processor;

 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.function.BiConsumer;
+import java.util.function.Consumer;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;

+import lombok.AllArgsConstructor;
+import lombok.Getter;
 import org.apache.commons.lang3.StringUtils;
+import org.opensearch.common.collect.Tuple;
+import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.env.Environment;
 import org.opensearch.index.mapper.MapperService;
 import org.opensearch.ingest.AbstractProcessor;
 import org.opensearch.ingest.IngestDocument;
+import org.opensearch.ingest.IngestDocumentWrapper;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

 import com.google.common.annotations.VisibleForTesting;
@@ -119,6 +129,121 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex
         }
     }

+    /**
+     * This is the function which does actual inference work for batchExecute interface.
+     * @param inferenceList a list of String for inference.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     * @param onException an exception callback to handle exception.
+     */
+    abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
+
+    @Override
+    public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
+        if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
+            handler.accept(Collections.emptyList());
+            return;
+        }
+
+        List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
+        List<String> inferenceList = constructInferenceTexts(dataForInferences);
+        if (inferenceList.isEmpty()) {
+            handler.accept(ingestDocumentWrappers);
+            return;
+        }
+        Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
+        inferenceList = sortedResult.v1();
+        Map<Integer, Integer> originalOrder = sortedResult.v2();
+        doBatchExecute(inferenceList, results -> {
+            int startIndex = 0;
+            results = restoreToOriginalOrder(results, originalOrder);
+            for (DataForInference dataForInference : dataForInferences) {
+                if (dataForInference.getIngestDocumentWrapper().getException() != null
+                    || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
+                    continue;
+                }
+                List<?> inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size());
+                startIndex += dataForInference.getInferenceList().size();
+                setVectorFieldsToDocument(
+                    dataForInference.getIngestDocumentWrapper().getIngestDocument(),
+                    dataForInference.getProcessMap(),
+                    inferenceResults
+                );
+            }
+            handler.accept(ingestDocumentWrappers);
+        }, exception -> {
+            for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+                // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
+                // set exception to IngestDocumentWrapper which doesn't have exception before.
+                if (ingestDocumentWrapper.getException() == null) {
+                    ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
+                }
+            }
+            handler.accept(ingestDocumentWrappers);
+        });
+    }
+
+    private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
+        List<Tuple<Integer, String>> docsWithIndex = new ArrayList<>();
+        for (int i = 0; i < inferenceList.size(); ++i) {
+            docsWithIndex.add(Tuple.tuple(i, inferenceList.get(i)));
+        }
+        docsWithIndex.sort(Comparator.comparingInt(t -> t.v2().length()));
+        List<String> sortedInferenceList = docsWithIndex.stream().map(Tuple::v2).collect(Collectors.toList());
+        Map<Integer, Integer> originalOrderMap = new HashMap<>();
+        for (int i = 0; i < docsWithIndex.size(); ++i) {
+            originalOrderMap.put(i, docsWithIndex.get(i).v1());
+        }
+        return Tuple.tuple(sortedInferenceList, originalOrderMap);
+    }
+
+    private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> originalOrder) {
+        List<Object> sortedResults = Arrays.asList(results.toArray());
+        for (int i = 0; i < results.size(); ++i) {
+            if (!originalOrder.containsKey(i)) continue;
+            int oldIndex = originalOrder.get(i);
+            sortedResults.set(oldIndex, results.get(i));
+        }
+        return sortedResults;
+    }
+
+    private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
+        List<String> inferenceTexts = new ArrayList<>();
+        for (DataForInference dataForInference : dataForInferences) {
+            if (dataForInference.getIngestDocumentWrapper().getException() != null
+                || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
+                continue;
+            }
+            inferenceTexts.addAll(dataForInference.getInferenceList());
+        }
+        return inferenceTexts;
+    }
+
+    private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+        List<DataForInference> dataForInferences = new ArrayList<>();
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            Map<String, Object> processMap = null;
+            List<String> inferenceList = null;
+            try {
+                validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
+                processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
+                inferenceList = createInferenceList(processMap);
+            } catch (Exception e) {
+                ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
+            } finally {
+                dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
+            }
+        }
+        return dataForInferences;
+    }
+
+    @Getter
+    @AllArgsConstructor
+    private static class DataForInference {
+        private final IngestDocumentWrapper ingestDocumentWrapper;
+        private final Map<String, Object> processMap;
+        private final List<String> inferenceList;
+    }
+
     @SuppressWarnings({ "unchecked" })
     private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
         List<String> texts = new ArrayList<>();
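Note: batchExecute above sorts the collected texts by length before inference and later maps each result back to its original position. Below is a minimal, self-contained sketch of that sort-then-restore round trip using only the JDK; the class name and sample texts are placeholders, not part of this commit.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SortRestoreDemo {
    public static void main(String[] args) {
        List<String> texts = Arrays.asList("a much longer text", "short", "mid-size");

        // Sort positions by text length (mirrors sortByLengthAndReturnOriginalOrder).
        Integer[] order = { 0, 1, 2 };
        Arrays.sort(order, Comparator.comparingInt(i -> texts.get(i).length()));

        // Record sorted position -> original position, and build the sorted list.
        Map<Integer, Integer> originalOrder = new HashMap<>();
        List<String> sorted = new ArrayList<>();
        for (int pos = 0; pos < order.length; ++pos) {
            originalOrder.put(pos, order[pos]);
            sorted.add(texts.get(order[pos]));
        }
        // sorted is now [short, mid-size, a much longer text]

        // Pretend the model returns one result per sorted input ...
        List<String> results = new ArrayList<>(sorted);

        // ... then place each result back at its original index (mirrors restoreToOriginalOrder).
        List<String> restored = new ArrayList<>(results);
        for (int i = 0; i < results.size(); ++i) {
            restored.set(originalOrder.get(i), results.get(i));
        }
        System.out.println(restored); // prints [a much longer text, short, mid-size], the original order
    }
}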

src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java (+10)
@@ -7,6 +7,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiConsumer;
+import java.util.function.Consumer;

 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
@@ -49,4 +50,13 @@ public void doExecute(
             handler.accept(ingestDocument, null);
         }, e -> { handler.accept(null, e); }));
     }
+
+    @Override
+    public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
+        mlCommonsClientAccessor.inferenceSentencesWithMapResult(
+            this.modelId,
+            inferenceList,
+            ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
+        );
+    }
 }
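Unlike the dense path in TextEmbeddingProcessor below, the sparse implementation converts the raw map results into per-text token-weight maps via TokenWeightUtil.fetchListOfTokenWeightMap before invoking the shared handler. As a rough sketch of the shape each element handed to the handler is assumed to have (the tokens and weights here are invented for illustration):

// one entry per inference text: token -> weight; values are illustrative only
Map<String, Float> tokenWeights = Map.of("hello", 1.3f, "world", 0.8f);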

src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java (+6)
@@ -7,6 +7,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiConsumer;
+import java.util.function.Consumer;

 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
@@ -48,4 +49,9 @@ public void doExecute(
             handler.accept(ingestDocument, null);
         }, e -> { handler.accept(null, e); }));
     }
+
+    @Override
+    public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
+        mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException));
+    }
 }
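For orientation, here is a minimal sketch of how a caller could drive the new batch path end to end. In production the ingest framework invokes batchExecute, so the processor instance and the "passage_text" field below are assumptions; the IngestDocumentWrapper constructor and the batchExecute signature come from the code in this commit.

Map<String, Object> source = new HashMap<>();
source.put("passage_text", "hello world"); // "passage_text" is a hypothetical mapped field
IngestDocument document = new IngestDocument(source, new HashMap<>());
List<IngestDocumentWrapper> wrappers = new ArrayList<>();
wrappers.add(new IngestDocumentWrapper(0, document, null)); // slot 0, no prior exception

// processor is an assumed, already-configured TextEmbeddingProcessor instance
processor.batchExecute(wrappers, processed -> {
    for (IngestDocumentWrapper wrapper : processed) {
        // per-document failures are reported on the wrapper rather than thrown
        if (wrapper.getException() != null) {
            // handle or log the failure for this document
        }
    }
});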
src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java (+59, new file)
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor;
+
+import com.google.common.collect.ImmutableList;
+import org.opensearch.ingest.IngestDocument;
+import org.opensearch.ingest.IngestDocumentWrapper;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class InferenceProcessorTestCase extends OpenSearchTestCase {
+
+    protected List<IngestDocumentWrapper> createIngestDocumentWrappers(int count) {
+        List<IngestDocumentWrapper> wrapperList = new ArrayList<>();
+        for (int i = 0; i < count; ++i) {
+            Map<String, Object> sourceAndMetadata = new HashMap<>();
+            sourceAndMetadata.put("key1", "value1");
+            wrapperList.add(new IngestDocumentWrapper(i, new IngestDocument(sourceAndMetadata, new HashMap<>()), null));
+        }
+        return wrapperList;
+    }
+
+    protected List<List<Float>> createMockVectorWithLength(int size) {
+        float suffix = .234f;
+        List<List<Float>> result = new ArrayList<>();
+        for (int i = 0; i < size * 2;) {
+            List<Float> number = new ArrayList<>();
+            number.add(i++ + suffix);
+            number.add(i++ + suffix);
+            result.add(number);
+        }
+        return result;
+    }
+
+    protected List<List<Float>> createMockVectorResult() {
+        List<List<Float>> modelTensorList = new ArrayList<>();
+        List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
+        List<Float> number2 = ImmutableList.of(3.234f, 4.354f);
+        List<Float> number3 = ImmutableList.of(5.234f, 6.354f);
+        List<Float> number4 = ImmutableList.of(7.234f, 8.354f);
+        List<Float> number5 = ImmutableList.of(9.234f, 10.354f);
+        List<Float> number6 = ImmutableList.of(11.234f, 12.354f);
+        List<Float> number7 = ImmutableList.of(13.234f, 14.354f);
+        modelTensorList.add(number1);
+        modelTensorList.add(number2);
+        modelTensorList.add(number3);
+        modelTensorList.add(number4);
+        modelTensorList.add(number5);
+        modelTensorList.add(number6);
+        modelTensorList.add(number7);
+        return modelTensorList;
+    }
+}
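A minimal sketch of how a subclass test might combine these helpers; the processor instance and the stubbing of its model client are assumptions not shown here.

List<IngestDocumentWrapper> wrappers = createIngestDocumentWrappers(2);
List<List<Float>> vectors = createMockVectorWithLength(2); // two 2-float vectors, one per document text
// stub the model client so the processor's doBatchExecute hands back `vectors`,
// then verify that batchExecute attaches them without per-document exceptions
processor.batchExecute(wrappers, processed -> processed.forEach(w -> assertNull(w.getException())));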
