Skip to content

Commit a24646c

Browse files
Unify precomputation of aggregations behind a common API (#16733) (#17197)
We've had a series of aggregation speedups that use the same strategy: instead of iterating through documents that match the query one-by-one, we can look at a Lucene segment and compute the aggregation directly (if some particular conditions are met). In every case, we've hooked that into custom logic hijacks the getLeafCollector method and throws CollectionTerminatedException. This creates the illusion that we're implementing a custom LeafCollector, when really we're not collecting at all (which is the whole point). With this refactoring, the mechanism (hijacking getLeafCollector) is moved into AggregatorBase. Aggregators that have a strategy to precompute their answer can override tryPrecomputeAggregationForLeaf, which is expected to return true if they managed to precompute. This should also make it easier to keep track of which aggregations have precomputation approaches (since they override this method). --------- (cherry picked from commit 2847695) Signed-off-by: Michael Froh <froh@amazon.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ffed717 commit a24646c

11 files changed

+168
-132
lines changed

server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
package org.opensearch.search.aggregations;
3333

3434
import org.apache.lucene.index.LeafReaderContext;
35+
import org.apache.lucene.search.CollectionTerminatedException;
3536
import org.apache.lucene.search.MatchAllDocsQuery;
3637
import org.apache.lucene.search.ScoreMode;
3738
import org.opensearch.core.common.breaker.CircuitBreaker;
@@ -200,6 +201,9 @@ public Map<String, Object> metadata() {
200201

201202
@Override
202203
public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
204+
if (tryPrecomputeAggregationForLeaf(ctx)) {
205+
throw new CollectionTerminatedException();
206+
}
203207
preGetSubLeafCollectors(ctx);
204208
final LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(ctx);
205209
return getLeafCollector(ctx, sub);
@@ -216,6 +220,21 @@ protected void preGetSubLeafCollectors(LeafReaderContext ctx) throws IOException
216220
*/
217221
protected void doPreCollection() throws IOException {}
218222

223+
/**
224+
* Subclasses may override this method if they have an efficient way of computing their aggregation for the given
225+
* segment (versus collecting matching documents). If this method returns true, collection for the given segment
226+
* will be terminated, rather than executing normally.
227+
* <p>
228+
* If this method returns true, the aggregator's state should be identical to what it would be if matching
229+
* documents from the segment were fully collected. If this method returns false, the aggregator's state should
230+
* be unchanged from before this method is called.
231+
* @param ctx the context for the given segment
232+
* @return true if and only if results for this segment have been precomputed
233+
*/
234+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
235+
return false;
236+
}
237+
219238
@Override
220239
public final void preCollection() throws IOException {
221240
List<BucketCollector> collectors = Arrays.asList(subAggregators);
@@ -251,8 +270,8 @@ public Aggregator[] subAggregators() {
251270
public Aggregator subAggregator(String aggName) {
252271
if (subAggregatorbyName == null) {
253272
subAggregatorbyName = new HashMap<>(subAggregators.length);
254-
for (int i = 0; i < subAggregators.length; i++) {
255-
subAggregatorbyName.put(subAggregators[i].name(), subAggregators[i]);
273+
for (Aggregator subAggregator : subAggregators) {
274+
subAggregatorbyName.put(subAggregator.name(), subAggregator);
256275
}
257276
}
258277
return subAggregatorbyName.get(aggName);

server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -556,10 +556,13 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
556556
}
557557

558558
@Override
559-
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
560-
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
561-
if (optimized) throw new CollectionTerminatedException();
559+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
560+
finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
561+
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
562+
}
562563

564+
@Override
565+
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
563566
finishLeaf();
564567

565568
boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;

server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java

+11-11
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
import org.apache.lucene.index.LeafReaderContext;
3535
import org.apache.lucene.index.SortedNumericDocValues;
36-
import org.apache.lucene.search.CollectionTerminatedException;
3736
import org.apache.lucene.search.DocIdSetIterator;
3837
import org.apache.lucene.search.ScoreMode;
3938
import org.apache.lucene.util.CollectionUtil;
@@ -187,22 +186,23 @@ public ScoreMode scoreMode() {
187186
}
188187

189188
@Override
190-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
191-
if (valuesSource == null) {
192-
return LeafBucketCollector.NO_OP_COLLECTOR;
193-
}
194-
195-
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
196-
if (optimized) throw new CollectionTerminatedException();
197-
198-
SortedNumericDocValues values = valuesSource.longValues(ctx);
189+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
199190
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
200191
if (supportedStarTree != null) {
201192
if (preComputeWithStarTree(ctx, supportedStarTree) == true) {
202-
throw new CollectionTerminatedException();
193+
return true;
203194
}
204195
}
196+
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
197+
}
205198

199+
@Override
200+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
201+
if (valuesSource == null) {
202+
return LeafBucketCollector.NO_OP_COLLECTOR;
203+
}
204+
205+
SortedNumericDocValues values = valuesSource.longValues(ctx);
206206
return new LeafBucketCollectorBase(sub, values) {
207207
@Override
208208
public void collect(int doc, long owningBucketOrd) throws IOException {

server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java

+8-4
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
package org.opensearch.search.aggregations.bucket.range;
3333

3434
import org.apache.lucene.index.LeafReaderContext;
35-
import org.apache.lucene.search.CollectionTerminatedException;
3635
import org.apache.lucene.search.ScoreMode;
3736
import org.opensearch.core.ParseField;
3837
import org.opensearch.core.common.io.stream.StreamInput;
@@ -310,10 +309,15 @@ public ScoreMode scoreMode() {
310309
}
311310

312311
@Override
313-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
314-
if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
315-
throw new CollectionTerminatedException();
312+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
313+
if (segmentMatchAll(context, ctx)) {
314+
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
316315
}
316+
return false;
317+
}
318+
319+
@Override
320+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
317321

318322
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
319323
return new LeafBucketCollectorBase(sub, values) {

server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java

+36-36
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.apache.lucene.index.SortedSetDocValues;
4141
import org.apache.lucene.index.Terms;
4242
import org.apache.lucene.index.TermsEnum;
43-
import org.apache.lucene.search.CollectionTerminatedException;
4443
import org.apache.lucene.search.Weight;
4544
import org.apache.lucene.util.ArrayUtil;
4645
import org.apache.lucene.util.BytesRef;
@@ -166,35 +165,32 @@ public void setWeight(Weight weight) {
166165
@return A LeafBucketCollector implementation with collection termination, since collection is complete
167166
@throws IOException If an I/O error occurs during reading
168167
*/
169-
LeafBucketCollector termDocFreqCollector(
170-
LeafReaderContext ctx,
171-
SortedSetDocValues globalOrds,
172-
BiConsumer<Long, Integer> ordCountConsumer
173-
) throws IOException {
168+
boolean tryCollectFromTermFrequencies(LeafReaderContext ctx, SortedSetDocValues globalOrds, BiConsumer<Long, Integer> ordCountConsumer)
169+
throws IOException {
174170
if (weight == null) {
175171
// Weight not assigned - cannot use this optimization
176-
return null;
172+
return false;
177173
} else {
178174
if (weight.count(ctx) == 0) {
179175
// No documents matches top level query on this segment, we can skip the segment entirely
180-
return LeafBucketCollector.NO_OP_COLLECTOR;
176+
return true;
181177
} else if (weight.count(ctx) != ctx.reader().maxDoc()) {
182178
// weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
183179
// top-level query matches all docs in the segment
184-
return null;
180+
return false;
185181
}
186182
}
187183

188184
Terms segmentTerms = ctx.reader().terms(this.fieldName);
189185
if (segmentTerms == null) {
190186
// Field is not indexed.
191-
return null;
187+
return false;
192188
}
193189

194190
NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
195191
if (docCountValues.nextDoc() != NO_MORE_DOCS) {
196192
// This segment has at least one document with the _doc_count field.
197-
return null;
193+
return false;
198194
}
199195

200196
TermsEnum indexTermsEnum = segmentTerms.iterator();
@@ -218,31 +214,28 @@ LeafBucketCollector termDocFreqCollector(
218214
ordinalTerm = globalOrdinalTermsEnum.next();
219215
}
220216
}
221-
return new LeafBucketCollector() {
222-
@Override
223-
public void collect(int doc, long owningBucketOrd) throws IOException {
224-
throw new CollectionTerminatedException();
225-
}
226-
};
217+
return true;
227218
}
228219

229220
@Override
230-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
221+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
231222
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
232-
collectionStrategy.globalOrdsReady(globalOrds);
233-
234223
if (collectionStrategy instanceof DenseGlobalOrds
235224
&& this.resultStrategy instanceof StandardTermsResults
236-
&& sub == LeafBucketCollector.NO_OP_COLLECTOR) {
237-
LeafBucketCollector termDocFreqCollector = termDocFreqCollector(
225+
&& subAggregators.length == 0) {
226+
return tryCollectFromTermFrequencies(
238227
ctx,
239228
globalOrds,
240229
(ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount)
241230
);
242-
if (termDocFreqCollector != null) {
243-
return termDocFreqCollector;
244-
}
245231
}
232+
return false;
233+
}
234+
235+
@Override
236+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
237+
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
238+
collectionStrategy.globalOrdsReady(globalOrds);
246239

247240
SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds);
248241
if (singleValues != null) {
@@ -433,6 +426,24 @@ static class LowCardinality extends GlobalOrdinalsStringTermsAggregator {
433426
this.segmentDocCounts = context.bigArrays().newLongArray(1, true);
434427
}
435428

429+
@Override
430+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
431+
if (subAggregators.length == 0) {
432+
if (mapping != null) {
433+
mapSegmentCountsToGlobalCounts(mapping);
434+
}
435+
final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx);
436+
segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount());
437+
mapping = valuesSource.globalOrdinalsMapping(ctx);
438+
return tryCollectFromTermFrequencies(
439+
ctx,
440+
segmentOrds,
441+
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
442+
);
443+
}
444+
return false;
445+
}
446+
436447
@Override
437448
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
438449
if (mapping != null) {
@@ -443,17 +454,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
443454
assert sub == LeafBucketCollector.NO_OP_COLLECTOR;
444455
mapping = valuesSource.globalOrdinalsMapping(ctx);
445456

446-
if (this.resultStrategy instanceof StandardTermsResults) {
447-
LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector(
448-
ctx,
449-
segmentOrds,
450-
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
451-
);
452-
if (termDocFreqCollector != null) {
453-
return termDocFreqCollector;
454-
}
455-
}
456-
457457
final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds);
458458
if (singleValues != null) {
459459
segmentsWithSingleValuedOrds++;

server/src/main/java/org/opensearch/search/aggregations/metrics/AvgAggregator.java

+13-15
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
package org.opensearch.search.aggregations.metrics;
3333

3434
import org.apache.lucene.index.LeafReaderContext;
35-
import org.apache.lucene.search.CollectionTerminatedException;
3635
import org.apache.lucene.search.DocIdSetIterator;
3736
import org.apache.lucene.search.ScoreMode;
3837
import org.apache.lucene.util.FixedBitSet;
@@ -104,23 +103,29 @@ public ScoreMode scoreMode() {
104103
}
105104

106105
@Override
107-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
106+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
108107
if (valuesSource == null) {
109-
return LeafBucketCollector.NO_OP_COLLECTOR;
108+
return false;
110109
}
111110
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
112111
if (supportedStarTree != null) {
113112
if (parent != null && subAggregators.length == 0) {
114113
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
115114
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
116-
return LeafBucketCollector.NO_OP_COLLECTOR;
115+
return true;
117116
}
118-
return getStarTreeLeafCollector(ctx, sub, supportedStarTree);
117+
precomputeLeafUsingStarTree(ctx, supportedStarTree);
118+
return true;
119119
}
120-
return getDefaultLeafCollector(ctx, sub);
120+
return false;
121121
}
122122

123-
private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
123+
@Override
124+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
125+
if (valuesSource == null) {
126+
return LeafBucketCollector.NO_OP_COLLECTOR;
127+
}
128+
124129
final BigArrays bigArrays = context.bigArrays();
125130
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
126131
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
@@ -154,8 +159,7 @@ public void collect(int doc, long bucket) throws IOException {
154159
};
155160
}
156161

157-
public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
158-
throws IOException {
162+
private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
159163
StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
160164
assert starTreeValues != null;
161165

@@ -200,12 +204,6 @@ public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafB
200204

201205
sums.set(0, kahanSummation.value());
202206
compensations.set(0, kahanSummation.delta());
203-
return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
204-
@Override
205-
public void collect(int doc, long bucket) {
206-
throw new CollectionTerminatedException();
207-
}
208-
};
209207
}
210208

211209
@Override

server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java

+20-16
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,24 @@ public ScoreMode scoreMode() {
104104
return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
105105
}
106106

107+
@Override
108+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
109+
if (valuesSource == null) {
110+
return false;
111+
}
112+
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
113+
if (supportedStarTree != null) {
114+
if (parent != null && subAggregators.length == 0) {
115+
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
116+
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
117+
return true;
118+
}
119+
precomputeLeafUsingStarTree(ctx, supportedStarTree);
120+
return true;
121+
}
122+
return false;
123+
}
124+
107125
@Override
108126
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
109127
if (valuesSource == null) {
@@ -130,20 +148,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
130148
}
131149
}
132150

133-
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
134-
if (supportedStarTree != null) {
135-
if (parent != null && subAggregators.length == 0) {
136-
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
137-
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
138-
return LeafBucketCollector.NO_OP_COLLECTOR;
139-
}
140-
getStarTreeCollector(ctx, sub, supportedStarTree);
141-
}
142-
return getDefaultLeafCollector(ctx, sub);
143-
}
144-
145-
private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
146-
147151
final BigArrays bigArrays = context.bigArrays();
148152
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
149153
final NumericDoubleValues values = MultiValueMode.MAX.select(allValues);
@@ -167,9 +171,9 @@ public void collect(int doc, long bucket) throws IOException {
167171
};
168172
}
169173

170-
public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException {
174+
private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
171175
AtomicReference<Double> max = new AtomicReference<>(maxes.get(0));
172-
StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MAX.getTypeName(), value -> {
176+
StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MAX.getTypeName(), value -> {
173177
max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
174178
}, () -> maxes.set(0, max.get()));
175179
}

0 commit comments

Comments
 (0)