package org.opensearch.search.aggregations.bucket.range;

import org.apache.lucene.index.LeafReaderContext;
+ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
+ import org.apache.lucene.util.FixedBitSet;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
+ import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
+ import org.opensearch.index.compositeindex.datacube.MetricStat;
+ import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
+ import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils;
+ import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
+ import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.NonCollectingAggregator;
+ import org.opensearch.search.aggregations.StarTreeBucketCollector;
+ import org.opensearch.search.aggregations.StarTreePreComputeCollector;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
import org.opensearch.search.aggregations.bucket.filterrewrite.RangeAggregatorBridge;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
+ import org.opensearch.search.startree.StarTreeQueryHelper;
+ import org.opensearch.search.startree.StarTreeTraversalUtil;
+ import org.opensearch.search.startree.filter.DimensionFilter;

import java.io.IOException;
import java.util.ArrayList;

import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.opensearch.search.aggregations.bucket.filterrewrite.AggregatorBridge.segmentMatchAll;
+ import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree;

/**
 * Aggregate all docs that match given ranges.
 *
 * @opensearch.internal
 */
- public class RangeAggregator extends BucketsAggregator {
+ public class RangeAggregator extends BucketsAggregator implements StarTreePreComputeCollector {

    public static final ParseField RANGES_FIELD = new ParseField("ranges");
    public static final ParseField KEYED_FIELD = new ParseField("keyed");
+     public final String fieldName;

    /**
     * Range for the range aggregator
@@ -298,6 +313,9 @@ protected Function<Object, Long> bucketOrdProducer() {
            }
        };
        filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
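+         // Capture the index field name (only available for numeric field-data values sources) so the
+         // star-tree pre-compute path below can resolve the matching dimension.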
+         this.fieldName = (valuesSource instanceof ValuesSource.Numeric.FieldData)
+             ? ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName()
+             : null;
    }

    @Override
@@ -310,8 +328,13 @@ public ScoreMode scoreMode() {

    @Override
    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
-         if (segmentMatchAll(context, ctx)) {
-             return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
+         if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
+             return true;
+         }
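+         // If the filter-rewrite optimization did not apply, try pre-computing the buckets from a
+         // supported star-tree index for this query, if one exists.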
+         CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
+         if (supportedStarTree != null) {
+             preComputeWithStarTree(ctx, supportedStarTree);
+             return true;
        }
        return false;
    }
@@ -333,52 +356,106 @@ public void collect(int doc, long bucket) throws IOException {
            }

            private int collect(int doc, double value, long owningBucketOrdinal, int lowBound) throws IOException {
-                 int lo = lowBound, hi = ranges.length - 1; // all candidates are between these indexes
-                 int mid = (lo + hi) >>> 1;
-                 while (lo <= hi) {
-                     if (value < ranges[mid].from) {
-                         hi = mid - 1;
-                     } else if (value >= maxTo[mid]) {
-                         lo = mid + 1;
-                     } else {
-                         break;
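+                 // Delegate the candidate search to MatchedRange, then bucket the doc into every range
+                 // in that window that actually matches the value.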
+                 MatchedRange range = new MatchedRange(ranges, lowBound, value);
+                 for (int i = range.startLo; i <= range.endHi; ++i) {
+                     if (ranges[i].matches(value)) {
+                         collectBucket(sub, doc, subBucketOrdinal(owningBucketOrdinal, i));
                    }
-                     mid = (lo + hi) >>> 1;
                }
-                 if (lo > hi) return lo; // no potential candidate
-
-                 // binary search the lower bound
-                 int startLo = lo, startHi = mid;
-                 while (startLo <= startHi) {
-                     final int startMid = (startLo + startHi) >>> 1;
-                     if (value >= maxTo[startMid]) {
-                         startLo = startMid + 1;
-                     } else {
-                         startHi = startMid - 1;
-                     }
+                 return range.endHi + 1;
+             }
+         };
+     }
+
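+     // Walks every star-tree entry flagged in the collector's matching-docs bit set and folds its
+     // pre-aggregated values into the range buckets.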
+     private void preComputeWithStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
+         StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector(ctx, starTree, null);
+         FixedBitSet matchingDocsBitSet = starTreeBucketCollector.getMatchingDocsBitSet();
+
+         int numBits = matchingDocsBitSet.length();
+
+         if (numBits > 0) {
+             for (int bit = matchingDocsBitSet.nextSetBit(0); bit != DocIdSetIterator.NO_MORE_DOCS; bit = (bit + 1 < numBits)
+                 ? matchingDocsBitSet.nextSetBit(bit + 1)
+                 : DocIdSetIterator.NO_MORE_DOCS) {
+                 starTreeBucketCollector.collectStarTreeEntry(bit, 0);
+             }
+         }
+     }
+
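+     // Builds a collector over the star-tree result that maps each entry's dimension values onto the
+     // configured ranges, weighting every matching bucket by the entry's pre-aggregated doc count.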
+     @Override
+     public StarTreeBucketCollector getStarTreeBucketCollector(
+         LeafReaderContext ctx,
+         CompositeIndexFieldInfo starTree,
+         StarTreeBucketCollector parentCollector
+     ) throws IOException {
+         assert parentCollector == null;
+         StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
+         return new StarTreeBucketCollector(
+             starTreeValues,
+             StarTreeTraversalUtil.getStarTreeResult(
+                 starTreeValues,
+                 StarTreeQueryHelper.mergeDimensionFilterIfNotExists(
+                     context.getQueryShardContext().getStarTreeQueryContext().getBaseQueryStarTreeFilter(),
+                     fieldName,
+                     List.of(DimensionFilter.MATCH_ALL_DEFAULT)
+                 ),
+                 context
+             )
+         ) {
+             @Override
+             public void setSubCollectors() throws IOException {
+                 for (Aggregator aggregator : subAggregators) {
+                     this.subCollectors.add(((StarTreePreComputeCollector) aggregator).getStarTreeBucketCollector(ctx, starTree, this));
                }
+             }
+
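+             // Iterators over the range field's dimension values and the star-tree's pre-aggregated
+             // _doc_count metric for each entry.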
+             SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
+                 .getDimensionValuesIterator(fieldName);
+
+             String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
+                 starTree.getField(),
+                 "_doc_count",
+                 MetricStat.DOC_COUNT.getTypeName()
+             );
+
+             SortedNumericStarTreeValuesIterator docCountsIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
+                 .getMetricValuesIterator(metricName);

-                 // binary search the upper bound
-                 int endLo = mid, endHi = hi;
-                 while (endLo <= endHi) {
-                     final int endMid = (endLo + endHi) >>> 1;
-                     if (value < ranges[endMid].from) {
-                         endHi = endMid - 1;
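+             // For each dimension value of this entry, find the candidate ranges and increment every
+             // matching bucket by the entry's aggregated doc count.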
+             @Override
+             public void collectStarTreeEntry(int starTreeEntry, long owningBucketOrd) throws IOException {
+                 if (!valuesIterator.advanceExact(starTreeEntry)) {
+                     return;
+                 }
+
+                 for (int i = 0, count = valuesIterator.entryValueCount(); i < count; i++) {
+                     long dimensionLongValue = valuesIterator.nextValue();
+                     double dimensionValue;
+
+                     // Only numeric & floating points are supported as of now in star-tree
+                     // TODO: Add support for isBigInteger() when it gets supported in star-tree
+                     if (valuesSource.isFloatingPoint()) {
+                         dimensionValue = ((NumberFieldMapper.NumberFieldType) context.mapperService().fieldType(fieldName)).toDoubleValue(
+                             dimensionLongValue
+                         );
                    } else {
-                         endLo = endMid + 1;
+                         dimensionValue = dimensionLongValue;
                    }
-                 }

-                 assert startLo == lowBound || value >= maxTo[startLo - 1];
-                 assert endHi == ranges.length - 1 || value < ranges[endHi + 1].from;
+                     MatchedRange matchedRange = new MatchedRange(ranges, 0, dimensionValue);
+                     if (matchedRange.startLo > matchedRange.endHi) {
+                         continue; // No matching range
+                     }

-                 for (int i = startLo; i <= endHi; ++i) {
-                     if (ranges[i].matches(value)) {
-                         collectBucket(sub, doc, subBucketOrdinal(owningBucketOrdinal, i));
+                     if (docCountsIterator.advanceExact(starTreeEntry)) {
+                         long metricValue = docCountsIterator.nextValue();
+                         for (int j = matchedRange.startLo; j <= matchedRange.endHi; ++j) {
+                             if (ranges[j].matches(dimensionValue)) {
+                                 long bucketOrd = subBucketOrdinal(owningBucketOrd, j);
+                                 collectStarTreeBucket(this, metricValue, bucketOrd, starTreeEntry);
+                             }
+                         }
                    }
                }
-
-                 return endHi + 1;
            }
        };
    }
@@ -421,6 +498,60 @@ public InternalAggregation buildEmptyAggregation() {
        return rangeFactory.create(name, buckets, format, keyed, metadata());
    }

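+     /**
+      * Binary-searches the sorted {@code ranges} (together with {@code maxTo}) for the contiguous
+      * window [startLo, endHi] of ranges that can possibly contain the given value.
+      */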
+     class MatchedRange {
+         int startLo, endHi;
+
+         MatchedRange(RangeAggregator.Range[] ranges, int lowBound, double value) {
+             computeMatchingRange(ranges, lowBound, value);
+         }
+
+         private void computeMatchingRange(RangeAggregator.Range[] ranges, int lowBound, double value) {
+             int lo = lowBound, hi = ranges.length - 1;
+             int mid = (lo + hi) >>> 1;
+
+             while (lo <= hi) {
+                 if (value < ranges[mid].from) {
+                     hi = mid - 1;
+                 } else if (value >= maxTo[mid]) {
+                     lo = mid + 1;
+                 } else {
+                     break;
+                 }
+                 mid = (lo + hi) >>> 1;
+             }
+             if (lo > hi) {
+                 this.startLo = lo;
+                 this.endHi = lo - 1;
+                 return;
+             }
+
+             // binary search the lower bound
+             int startLo = lo, startHi = mid;
+             while (startLo <= startHi) {
+                 int startMid = (startLo + startHi) >>> 1;
+                 if (value >= maxTo[startMid]) {
+                     startLo = startMid + 1;
+                 } else {
+                     startHi = startMid - 1;
+                 }
+             }
+
+             // binary search the upper bound
+             int endLo = mid, endHi = hi;
+             while (endLo <= endHi) {
+                 int endMid = (endLo + endHi) >>> 1;
+                 if (value < ranges[endMid].from) {
+                     endHi = endMid - 1;
+                 } else {
+                     endLo = endMid + 1;
+                 }
+             }
+
+             this.startLo = startLo;
+             this.endHi = endHi;
+         }
+     }
+
    /**
     * Unmapped range
     *
@@ -456,7 +587,7 @@ public Unmapped(
        public InternalAggregation buildEmptyAggregation() {
            InternalAggregations subAggs = buildEmptySubAggregations();
            List<org.opensearch.search.aggregations.bucket.range.Range.Bucket> buckets = new ArrayList<>(ranges.length);
-             for (RangeAggregator.Range range : ranges) {
+             for (Range range : ranges) {
                buckets.add(factory.createBucket(range.key, range.from, range.to, 0, subAggs, keyed, format));
            }
            return factory.create(name, buckets, format, keyed, metadata());