Commit 4400154

implement batch document update scenario for text embedding processor
Signed-off-by: Will Hwang <sang7239@gmail.com>
1 parent 1a6e58e commit 4400154

11 files changed, +481 -95 lines changed

CHANGELOG.md

+1-1
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
 ### Features
-- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
+- Optimizing embedding generation in text embedding processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
 ### Enhancements
 - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
 - Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))

DEVELOPER_GUIDE.md

+12-12
@@ -351,9 +351,9 @@ through the same build issue.
 
 ### Class and package names
 
-Class names should use `CamelCase`.
+Class names should use `CamelCase`.
 
-Try to put new classes into existing packages if package name abstracts the purpose of the class.
+Try to put new classes into existing packages if package name abstracts the purpose of the class.
 
 Example of good class file name and package utilization:
 
@@ -371,7 +371,7 @@ methods rather than a long single one and does everything.
 ### Documentation
 
 Document you code. That includes purpose of new classes, every public method and code sections that have critical or non-trivial
-logic (check this example https://github.com/opensearch-project/neural-search/blob/main/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java#L238).
+logic (check this example https://github.com/opensearch-project/neural-search/blob/main/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java#L238).
 
 When you submit a feature PR, please submit a new
 [documentation issue](https://github.com/opensearch-project/documentation-website/issues/new/choose). This is a path for the documentation to be published as part of https://opensearch.org/docs/latest/ documentation site.
@@ -384,17 +384,17 @@ For the most part, we're using common conventions for Java projects. Here are a
 
 1. Use descriptive names for classes, methods, fields, and variables.
 2. Avoid abbreviations unless they are widely accepted
-3. Use `final` on all method arguments unless it's absolutely necessary
+3. Use `final` on all method arguments unless it's absolutely necessary
 4. Wildcard imports are not allowed.
 5. Static imports are preferred over qualified imports when using static methods
 6. Prefer creating non-static public methods whenever possible. Avoid static methods in general, as they can often serve as shortcuts.
 Static methods are acceptable if they are private and do not access class state.
-7. Use functional programming style inside methods unless it's a performance critical section.
+7. Use functional programming style inside methods unless it's a performance critical section.
 8. For parameters of lambda expression please use meaningful names instead of shorten cryptic ones.
 9. Use Optional for return values if the value may not be present. This should be preferred to returning null.
 10. Do not create checked exceptions, and do not throw checked exceptions from public methods whenever possible. In general, if you call a method with a checked exception, you should wrap that exception into an unchecked exception.
 11. Throwing checked exceptions from private methods is acceptable.
-12. Use String.format when a string includes parameters, and prefer this over direct string concatenation. Always specify a Locale with String.format;
+12. Use String.format when a string includes parameters, and prefer this over direct string concatenation. Always specify a Locale with String.format;
 as a rule of thumb, use Locale.ROOT.
 13. Prefer Lombok annotations to the manually written boilerplate code
 14. When throwing an exception, avoid including user-provided content in the exception message. For secure coding practices,
@@ -440,17 +440,17 @@ Fix any new warnings before submitting your PR to ensure proper code documentati
 
 ### Tests
 
-Write unit and integration tests for your new functionality.
+Write unit and integration tests for your new functionality.
 
 Unit tests are preferred as they are cheap and fast, try to use them to cover all possible
-combinations of parameters. Utilize mocks to mimic dependencies.
+combinations of parameters. Utilize mocks to mimic dependencies.
 
-Integration tests should be used sparingly, focusing primarily on the main (happy path) scenario or cases where extensive
-mocking is impractical. Include one or two unhappy paths to confirm that correct response codes are returned to the user.
-Whenever possible, favor scenarios that do not require model deployment. If model deployment is necessary, use an existing
+Integration tests should be used sparingly, focusing primarily on the main (happy path) scenario or cases where extensive
+mocking is impractical. Include one or two unhappy paths to confirm that correct response codes are returned to the user.
+Whenever possible, favor scenarios that do not require model deployment. If model deployment is necessary, use an existing
 model, as tests involving new model deployments are the most resource-intensive.
 
-If your changes could affect backward compatibility, please include relevant backward compatibility tests along with your
+If your changes could affect backward compatibility, please include relevant backward compatibility tests along with your
 PR. For guidance on adding these tests, refer to the [Backwards Compatibility Testing](#backwards-compatibility-testing) section in this guide.
 
 ### Outdated or irrelevant code

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+111-39
@@ -26,6 +26,8 @@
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.action.get.MultiGetItemResponse;
+import org.opensearch.action.get.MultiGetRequest;
 import org.opensearch.common.collect.Tuple;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.util.CollectionUtils;
@@ -54,6 +56,8 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
 
     public static final String MODEL_ID_FIELD = "model_id";
     public static final String FIELD_MAP_FIELD = "field_map";
+    public static final String INDEX_FIELD = "_index";
+    public static final String ID_FIELD = "_id";
     private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
         if (v1 instanceof Collection && v2 instanceof Collection) {
             ((Collection) v1).addAll((Collection) v2);
@@ -169,49 +173,80 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
      */
     abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
 
+    /**
+     * This is the function which does actual inference work for subBatchExecute interface.
+     * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     */
     @Override
     public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
-        if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
-            handler.accept(Collections.emptyList());
-            return;
-        }
+        try {
+            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
+                handler.accept(Collections.emptyList());
+                return;
+            }
 
-        List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
-        List<String> inferenceList = constructInferenceTexts(dataForInferences);
-        if (inferenceList.isEmpty()) {
+            List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
+            List<String> inferenceList = constructInferenceTexts(dataForInferences);
+            if (inferenceList.isEmpty()) {
+                handler.accept(ingestDocumentWrappers);
+                return;
+            }
+            doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
             handler.accept(ingestDocumentWrappers);
-            return;
         }
-        Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
-        inferenceList = sortedResult.v1();
-        Map<Integer, Integer> originalOrder = sortedResult.v2();
-        doBatchExecute(inferenceList, results -> {
-            int startIndex = 0;
-            results = restoreToOriginalOrder(results, originalOrder);
-            for (DataForInference dataForInference : dataForInferences) {
-                if (dataForInference.getIngestDocumentWrapper().getException() != null
-                    || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
-                    continue;
+    }
+
+    /**
+     * This is a helper function for subBatchExecute, which invokes doBatchExecute for given inference list.
+     * @param ingestDocumentWrappers a list of IngestDocuments in a batch.
+     * @param inferenceList a list of String for inference.
+     * @param dataForInferences a list of data for inference, which includes ingestDocumentWrapper, processMap, inferenceList.
+     * @param handler a callback handler to handle inference results which is a list of objects.
+     */
+    protected void doSubBatchExecute(
+        List<IngestDocumentWrapper> ingestDocumentWrappers,
+        List<String> inferenceList,
+        List<DataForInference> dataForInferences,
+        Consumer<List<IngestDocumentWrapper>> handler
+    ) {
+        try {
+            Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
+            inferenceList = sortedResult.v1();
+            Map<Integer, Integer> originalOrder = sortedResult.v2();
+            doBatchExecute(inferenceList, results -> {
+                int startIndex = 0;
+                results = restoreToOriginalOrder(results, originalOrder);
+                for (DataForInference dataForInference : dataForInferences) {
+                    if (dataForInference.getIngestDocumentWrapper().getException() != null
+                        || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
+                        continue;
+                    }
+                    List<?> inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size());
+                    startIndex += dataForInference.getInferenceList().size();
+                    setVectorFieldsToDocument(
+                        dataForInference.getIngestDocumentWrapper().getIngestDocument(),
+                        dataForInference.getProcessMap(),
+                        inferenceResults
+                    );
                 }
-                List<?> inferenceResults = results.subList(startIndex, startIndex + dataForInference.getInferenceList().size());
-                startIndex += dataForInference.getInferenceList().size();
-                setVectorFieldsToDocument(
-                    dataForInference.getIngestDocumentWrapper().getIngestDocument(),
-                    dataForInference.getProcessMap(),
-                    inferenceResults
-                );
-            }
-            handler.accept(ingestDocumentWrappers);
-        }, exception -> {
-            for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
-                // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
-                // set exception to IngestDocumentWrapper which doesn't have exception before.
-                if (ingestDocumentWrapper.getException() == null) {
-                    ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
+                handler.accept(ingestDocumentWrappers);
+            }, exception -> {
+                for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+                    // The IngestDocumentWrapper might already run into exception and not sent for inference. So here we only
+                    // set exception to IngestDocumentWrapper which doesn't have exception before.
+                    if (ingestDocumentWrapper.getException() == null) {
+                        ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
+                    }
                 }
-            }
+                handler.accept(ingestDocumentWrappers);
+            });
+        } catch (Exception e) {
+            updateWithExceptions(ingestDocumentWrappers, e);
             handler.accept(ingestDocumentWrappers);
-        });
+        }
     }
 
     private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
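
Reviewer note: the success callback in `doSubBatchExecute` receives one flat list of results for the whole batch and slices it per document with a running `startIndex`, skipping documents that already failed or had nothing to infer. The following is a minimal, self-contained Java sketch of that slicing step only. `Doc` and `slicePerDocument` are hypothetical stand-ins for illustration, not the plugin's `DataForInference` class or any method in this commit.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class ResultSlicingSketch {
    /** Hypothetical stand-in for DataForInference: the texts sent for inference and a failure flag. */
    static final class Doc {
        final List<String> texts;
        final boolean failed;

        Doc(List<String> texts, boolean failed) {
            this.texts = texts;
            this.failed = failed;
        }
    }

    /** Returns one slice of the flat result list per document; skipped documents get an empty slice. */
    static List<List<Object>> slicePerDocument(List<Doc> docs, List<Object> flatResults) {
        List<List<Object>> slices = new ArrayList<>();
        int startIndex = 0;
        for (Doc doc : docs) {
            if (doc.failed || doc.texts.isEmpty()) {
                // Skipped documents contributed no texts, so the cursor must not move.
                slices.add(Collections.emptyList());
                continue;
            }
            int end = startIndex + doc.texts.size();
            slices.add(new ArrayList<>(flatResults.subList(startIndex, end)));
            startIndex = end;
        }
        return slices;
    }

    public static void main(String[] args) {
        List<Doc> docs = Arrays.asList(
            new Doc(Arrays.asList("a", "b"), false),
            new Doc(Arrays.asList("c"), true), // failed earlier, never sent for inference
            new Doc(Arrays.asList("d"), false)
        );
        List<Object> flat = Arrays.<Object>asList("vec-a", "vec-b", "vec-d");
        System.out.println(slicePerDocument(docs, flat));
    }
}
```

The slicing is only correct if the skip condition here matches the one used when the flat text list was built, which is why the diff repeats the same exception-or-empty check in both places.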
@@ -238,7 +273,7 @@ private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> or
         return sortedResults;
     }
 
-    private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
+    protected List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
         List<String> inferenceTexts = new ArrayList<>();
         for (DataForInference dataForInference : dataForInferences) {
             if (dataForInference.getIngestDocumentWrapper().getException() != null
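
Reviewer note: the bodies of `sortByLengthAndReturnOriginalOrder` and `restoreToOriginalOrder` are not part of this diff; only their signatures and call sites appear. As a rough illustration of the sort-then-restore idea, and not the plugin's actual implementation, here is one way such a pair could work. The class `LengthSortSketch`, its `Sorted` holder, and the direction of the index map (sorted position to original position) are all assumptions made for this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class LengthSortSketch {
    /** Sorted texts plus a map from sorted position to original position (assumed direction). */
    static final class Sorted {
        final List<String> texts;
        final Map<Integer, Integer> sortedToOriginal;

        Sorted(List<String> texts, Map<Integer, Integer> sortedToOriginal) {
            this.texts = texts;
            this.sortedToOriginal = sortedToOriginal;
        }
    }

    static Sorted sortByLength(List<String> texts) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < texts.size(); i++) {
            order.add(i);
        }
        // List.sort is stable, so equal-length texts keep their relative order.
        order.sort(Comparator.comparingInt(i -> texts.get(i).length()));
        List<String> sorted = new ArrayList<>();
        Map<Integer, Integer> sortedToOriginal = new HashMap<>();
        for (int pos = 0; pos < order.size(); pos++) {
            sorted.add(texts.get(order.get(pos)));
            sortedToOriginal.put(pos, order.get(pos));
        }
        return new Sorted(sorted, sortedToOriginal);
    }

    /** Puts results that arrived in sorted order back into the callers' original order. */
    static <T> List<T> restore(List<T> results, Map<Integer, Integer> sortedToOriginal) {
        List<T> restored = new ArrayList<>(Collections.<T>nCopies(results.size(), null));
        for (int pos = 0; pos < results.size(); pos++) {
            restored.set(sortedToOriginal.get(pos), results.get(pos));
        }
        return restored;
    }

    public static void main(String[] args) {
        Sorted sorted = sortByLength(Arrays.asList("bbb", "a", "cc"));
        System.out.println(sorted.texts);
        // Pretend the model returned each text's length, in sorted order.
        System.out.println(restore(Arrays.asList(1, 2, 3), sorted.sortedToOriginal));
    }
}
```

Restoring the order before slicing is what lets the per-document `startIndex` walk in `doSubBatchExecute` line up with the order in which the texts were originally collected.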
@@ -250,7 +285,7 @@ private List<String> constructInferenceTexts(List<DataForInference> dataForInfer
         return inferenceTexts;
     }
 
-    private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+    protected List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
         List<DataForInference> dataForInferences = new ArrayList<>();
         for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
             Map<String, Object> processMap = null;
@@ -272,7 +307,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
 
     @Getter
     @AllArgsConstructor
-    private static class DataForInference {
+    protected static class DataForInference {
         private final IngestDocumentWrapper ingestDocumentWrapper;
         private final Map<String, Object> processMap;
         private final List<String> inferenceList;
@@ -415,6 +450,36 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<Stri
         nlpResult.forEach(ingestDocument::setFieldValue);
     }
 
+    /**
+     * This method creates a MultiGetRequest from a list of ingest documents to be fetched for comparison
+     * @param ingestDocumentWrappers, list of ingest documents
+     * */
+    protected MultiGetRequest buildMultiGetRequest(List<IngestDocumentWrapper> ingestDocumentWrappers) {
+        MultiGetRequest multiGetRequest = new MultiGetRequest();
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            Object index = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(INDEX_FIELD);
+            Object id = ingestDocumentWrapper.getIngestDocument().getSourceAndMetadata().get(ID_FIELD);
+            if (Objects.nonNull(index) && Objects.nonNull(id)) {
+                multiGetRequest.add(index.toString(), id.toString());
+            }
+        }
+        return multiGetRequest;
+    }
+
+    /**
+     * This method creates a map of documents from MultiGetItemResponse where the key is document ID and value is corresponding document
+     * @param multiGetItemResponses, array of responses from Multi Get Request
+     * */
+    protected Map<String, Map<String, Object>> createDocumentMap(MultiGetItemResponse[] multiGetItemResponses) {
+        Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
+        for (MultiGetItemResponse item : multiGetItemResponses) {
+            String id = item.getId();
+            Map<String, Object> existingDocument = item.getResponse().getSourceAsMap();
+            existingDocuments.put(id, existingDocument);
+        }
+        return existingDocuments;
+    }
+
     @SuppressWarnings({ "unchecked" })
     @VisibleForTesting
     Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
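
Reviewer note: `createDocumentMap` calls `item.getResponse().getSourceAsMap()` without a null check. In OpenSearch a failed multi-get item may carry no response, and a document that does not exist may have no source map, so a defensive variant could skip those entries. The sketch below shows the guard in isolation, using a hypothetical `Item` stand-in with plain fields instead of the real `MultiGetItemResponse` so that it runs without the OpenSearch dependency.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class DocumentMapSketch {
    /** Hypothetical stand-in for one multi-get item; not the OpenSearch class. */
    static final class Item {
        final String id;
        final Map<String, Object> source; // null when the lookup failed or the document is missing

        Item(String id, Map<String, Object> source) {
            this.id = id;
            this.source = source;
        }
    }

    /** Builds an ID-to-source map, skipping items that have no usable source. */
    static Map<String, Map<String, Object>> toDocumentMap(Item[] items) {
        Map<String, Map<String, Object>> existingDocuments = new HashMap<>();
        for (Item item : items) {
            if (item == null || item.id == null || item.source == null) {
                continue; // nothing to compare against for this document
            }
            existingDocuments.put(item.id, item.source);
        }
        return existingDocuments;
    }

    public static void main(String[] args) {
        Item[] items = {
            new Item("1", Collections.<String, Object>singletonMap("text", "hello")),
            new Item("2", null) // e.g. a document that has not been indexed yet
        };
        System.out.println(toDocumentMap(items).keySet());
        System.out.println(Arrays.asList(items).size());
    }
}
```

A second point worth checking in review: `buildMultiGetRequest` adds `(index, id)` pairs, while the map is keyed by ID alone, so two documents with the same ID in different indices within one batch would overwrite each other in the map.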
@@ -504,6 +569,13 @@ private void processMapEntryValue(
         }
     }
 
+    // This method updates each ingestDocument with exceptions
+    protected void updateWithExceptions(List<IngestDocumentWrapper> ingestDocumentWrappers, Exception e) {
+        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
+            ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
+        }
+    }
+
     private void processMapEntryValue(
         List<?> results,
         IndexWrapper indexWrapper,
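
Reviewer note: `updateWithExceptions` sets the exception on every wrapper in the batch, whereas the inference-failure callback in `doSubBatchExecute` only sets it on wrappers that do not already have one. If keeping an earlier per-document failure matters, the same guard could be applied here. The following sketch shows that guarded behaviour with a hypothetical `Wrapper` stand-in; it is a suggestion for discussion, not code from this commit.

```java
import java.util.Arrays;
import java.util.List;

class FailureMarkingSketch {
    /** Hypothetical stand-in for IngestDocumentWrapper: only tracks the recorded exception. */
    static final class Wrapper {
        Exception exception;

        Wrapper(Exception exception) {
            this.exception = exception;
        }
    }

    /** Records the batch-level failure only on wrappers that have no earlier failure. */
    static void markFailedIfUnset(List<Wrapper> wrappers, Exception batchFailure) {
        for (Wrapper wrapper : wrappers) {
            if (wrapper.exception == null) {
                wrapper.exception = batchFailure;
            }
        }
    }

    public static void main(String[] args) {
        Exception earlier = new IllegalArgumentException("bad field type");
        List<Wrapper> wrappers = Arrays.asList(new Wrapper(null), new Wrapper(earlier));
        markFailedIfUnset(wrappers, new RuntimeException("inference failed"));
        for (Wrapper wrapper : wrappers) {
            System.out.println(wrapper.exception.getMessage());
        }
    }
}
```

Whether overwriting is acceptable depends on which error the caller should see for a document that failed twice; the sketch keeps the first one.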
@@ -582,7 +654,7 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
         List<Map<String, Object>> keyToResult = new ArrayList<>();
         sourceValue.stream()
             .filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
-            // sourceValue has been filtered
+            // sourceValue has been filtered
             .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
         return keyToResult;
     }
