32
32
package org .opensearch .search .aggregations .bucket .range ;
33
33
34
34
import org .apache .lucene .index .LeafReaderContext ;
35
+ import org .apache .lucene .search .DocIdSetIterator ;
35
36
import org .apache .lucene .search .ScoreMode ;
37
+ import org .apache .lucene .util .FixedBitSet ;
36
38
import org .opensearch .core .ParseField ;
37
39
import org .opensearch .core .common .io .stream .StreamInput ;
38
40
import org .opensearch .core .common .io .stream .StreamOutput ;
43
45
import org .opensearch .core .xcontent .ToXContentObject ;
44
46
import org .opensearch .core .xcontent .XContentBuilder ;
45
47
import org .opensearch .core .xcontent .XContentParser ;
48
+ import org .opensearch .index .codec .composite .CompositeIndexFieldInfo ;
49
+ import org .opensearch .index .compositeindex .datacube .MetricStat ;
50
+ import org .opensearch .index .compositeindex .datacube .startree .index .StarTreeValues ;
51
+ import org .opensearch .index .compositeindex .datacube .startree .utils .StarTreeUtils ;
52
+ import org .opensearch .index .compositeindex .datacube .startree .utils .iterator .SortedNumericStarTreeValuesIterator ;
46
53
import org .opensearch .index .fielddata .SortedNumericDoubleValues ;
54
+ import org .opensearch .index .mapper .NumberFieldMapper ;
47
55
import org .opensearch .search .DocValueFormat ;
48
56
import org .opensearch .search .aggregations .Aggregator ;
49
57
import org .opensearch .search .aggregations .AggregatorFactories ;
53
61
import org .opensearch .search .aggregations .LeafBucketCollector ;
54
62
import org .opensearch .search .aggregations .LeafBucketCollectorBase ;
55
63
import org .opensearch .search .aggregations .NonCollectingAggregator ;
64
+ import org .opensearch .search .aggregations .StarTreeBucketCollector ;
65
+ import org .opensearch .search .aggregations .StarTreePreComputeCollector ;
56
66
import org .opensearch .search .aggregations .bucket .BucketsAggregator ;
57
67
import org .opensearch .search .aggregations .bucket .filterrewrite .FilterRewriteOptimizationContext ;
58
68
import org .opensearch .search .aggregations .bucket .filterrewrite .RangeAggregatorBridge ;
59
69
import org .opensearch .search .aggregations .support .ValuesSource ;
60
70
import org .opensearch .search .aggregations .support .ValuesSourceConfig ;
61
71
import org .opensearch .search .internal .SearchContext ;
72
+ import org .opensearch .search .startree .StarTreeQueryHelper ;
73
+ import org .opensearch .search .startree .StarTreeTraversalUtil ;
74
+ import org .opensearch .search .startree .filter .DimensionFilter ;
62
75
63
76
import java .io .IOException ;
64
77
import java .util .ArrayList ;
70
83
71
84
import static org .opensearch .core .xcontent .ConstructingObjectParser .optionalConstructorArg ;
72
85
import static org .opensearch .search .aggregations .bucket .filterrewrite .AggregatorBridge .segmentMatchAll ;
86
+ import static org .opensearch .search .startree .StarTreeQueryHelper .getSupportedStarTree ;
73
87
74
88
/**
75
89
* Aggregate all docs that match given ranges.
76
90
*
77
91
* @opensearch.internal
78
92
*/
79
- public class RangeAggregator extends BucketsAggregator {
93
+ public class RangeAggregator extends BucketsAggregator implements StarTreePreComputeCollector {
80
94
81
95
public static final ParseField RANGES_FIELD = new ParseField ("ranges" );
82
96
public static final ParseField KEYED_FIELD = new ParseField ("keyed" );
97
+ public final String fieldName ;
83
98
84
99
/**
85
100
* Range for the range aggregator
@@ -298,6 +313,9 @@ protected Function<Object, Long> bucketOrdProducer() {
298
313
}
299
314
};
300
315
filterRewriteOptimizationContext = new FilterRewriteOptimizationContext (bridge , parent , subAggregators .length , context );
316
+ this .fieldName = (valuesSource instanceof ValuesSource .Numeric .FieldData )
317
+ ? ((ValuesSource .Numeric .FieldData ) valuesSource ).getIndexFieldName ()
318
+ : null ;
301
319
}
302
320
303
321
@ Override
@@ -313,6 +331,11 @@ protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws
313
331
if (segmentMatchAll (context , ctx )) {
314
332
return filterRewriteOptimizationContext .tryOptimize (ctx , this ::incrementBucketDocCount , false );
315
333
}
334
+ CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree (this .context .getQueryShardContext ());
335
+ if (supportedStarTree != null ) {
336
+ preComputeWithStarTree (ctx , supportedStarTree );
337
+ return true ;
338
+ }
316
339
return false ;
317
340
}
318
341
@@ -333,52 +356,141 @@ public void collect(int doc, long bucket) throws IOException {
333
356
}
334
357
335
358
private int collect (int doc , double value , long owningBucketOrdinal , int lowBound ) throws IOException {
336
- int lo = lowBound , hi = ranges .length - 1 ; // all candidates are between these indexes
337
- int mid = (lo + hi ) >>> 1 ;
338
- while (lo <= hi ) {
339
- if (value < ranges [mid ].from ) {
340
- hi = mid - 1 ;
341
- } else if (value >= maxTo [mid ]) {
342
- lo = mid + 1 ;
343
- } else {
344
- break ;
359
+ MatchedRange range = new MatchedRange (ranges , lowBound , value );
360
+ for (int i = range .startLo ; i <= range .endHi ; ++i ) {
361
+ if (ranges [i ].matches (value )) {
362
+ collectBucket (sub , doc , subBucketOrdinal (owningBucketOrdinal , i ));
345
363
}
346
- mid = (lo + hi ) >>> 1 ;
347
364
}
348
- if (lo > hi ) return lo ; // no potential candidate
349
-
350
- // binary search the lower bound
351
- int startLo = lo , startHi = mid ;
352
- while (startLo <= startHi ) {
353
- final int startMid = (startLo + startHi ) >>> 1 ;
354
- if (value >= maxTo [startMid ]) {
355
- startLo = startMid + 1 ;
356
- } else {
357
- startHi = startMid - 1 ;
358
- }
365
+ return range .endHi + 1 ;
366
+ }
367
+ };
368
+ }
369
+
370
+ private void preComputeWithStarTree (LeafReaderContext ctx , CompositeIndexFieldInfo starTree ) throws IOException {
371
+ StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector (ctx , starTree , null );
372
+ FixedBitSet matchingDocsBitSet = starTreeBucketCollector .getMatchingDocsBitSet ();
373
+
374
+ int numBits = matchingDocsBitSet .length ();
375
+
376
+ if (numBits > 0 ) {
377
+ for (int bit = matchingDocsBitSet .nextSetBit (0 ); bit != DocIdSetIterator .NO_MORE_DOCS ; bit = (bit + 1 < numBits )
378
+ ? matchingDocsBitSet .nextSetBit (bit + 1 )
379
+ : DocIdSetIterator .NO_MORE_DOCS ) {
380
+ starTreeBucketCollector .collectStarTreeEntry (bit , 0 );
381
+ }
382
+ }
383
+ }
384
+
385
+ @ Override
386
+ public StarTreeBucketCollector getStarTreeBucketCollector (
387
+ LeafReaderContext ctx ,
388
+ CompositeIndexFieldInfo starTree ,
389
+ StarTreeBucketCollector parentCollector
390
+ ) throws IOException {
391
+ assert parentCollector == null ;
392
+ StarTreeValues starTreeValues = StarTreeQueryHelper .getStarTreeValues (ctx , starTree );
393
+ return new StarTreeBucketCollector (
394
+ starTreeValues ,
395
+ StarTreeTraversalUtil .getStarTreeResult (
396
+ starTreeValues ,
397
+ StarTreeQueryHelper .mergeDimensionFilterIfNotExists (
398
+ context .getQueryShardContext ().getStarTreeQueryContext ().getBaseQueryStarTreeFilter (),
399
+ fieldName ,
400
+ List .of (DimensionFilter .MATCH_ALL_DEFAULT )
401
+ ),
402
+ context
403
+ )
404
+ ) {
405
+ @ Override
406
+ public void setSubCollectors () throws IOException {
407
+ for (Aggregator aggregator : subAggregators ) {
408
+ this .subCollectors .add (((StarTreePreComputeCollector ) aggregator ).getStarTreeBucketCollector (ctx , starTree , this ));
409
+ }
410
+ }
411
+
412
+ SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator ) starTreeValues
413
+ .getDimensionValuesIterator (fieldName );
414
+
415
+ String metricName = StarTreeUtils .fullyQualifiedFieldNameForStarTreeMetricsDocValues (
416
+ starTree .getField (),
417
+ "_doc_count" ,
418
+ MetricStat .DOC_COUNT .getTypeName ()
419
+ );
420
+
421
+ SortedNumericStarTreeValuesIterator docCountsIterator = (SortedNumericStarTreeValuesIterator ) starTreeValues
422
+ .getMetricValuesIterator (metricName );
423
+
424
+ @ Override
425
+ public void collectStarTreeEntry (int starTreeEntry , long owningBucketOrd ) throws IOException {
426
+ if (valuesIterator .advanceExact (starTreeEntry ) == false ) {
427
+ return ;
359
428
}
360
429
361
- // binary search the upper bound
362
- int endLo = mid , endHi = hi ;
363
- while (endLo <= endHi ) {
364
- final int endMid = (endLo + endHi ) >>> 1 ;
365
- if (value < ranges [endMid ].from ) {
366
- endHi = endMid - 1 ;
430
+ for (int i = 0 , count = valuesIterator .entryValueCount (); i < count ; i ++) {
431
+ long dimensionLongValue = valuesIterator .nextValue ();
432
+ double dimensionValue ;
433
+
434
+ // Only numeric & floating points are supported as of now in star-tree
435
+ // TODO: Add support for isBigInteger() when it gets supported in star-tree
436
+ if (valuesSource .isFloatingPoint ()) {
437
+ dimensionValue = ((NumberFieldMapper .NumberFieldType ) context .mapperService ().fieldType (fieldName )).toDoubleValue (
438
+ dimensionLongValue
439
+ );
367
440
} else {
368
- endLo = endMid + 1 ;
441
+ dimensionValue = dimensionLongValue ;
369
442
}
370
- }
371
443
372
- assert startLo == lowBound || value >= maxTo [startLo - 1 ];
373
- assert endHi == ranges .length - 1 || value < ranges [endHi + 1 ].from ;
444
+ // The core logic remains largely the same as the original collect method,
445
+ // but adapted for star-tree entry processing.
446
+ int lo = 0 , hi = ranges .length - 1 ;
447
+ int mid = (lo + hi ) >>> 1 ;
448
+
449
+ while (lo <= hi ) {
450
+ if (dimensionValue < ranges [mid ].from ) {
451
+ hi = mid - 1 ;
452
+ } else if (dimensionValue >= maxTo [mid ]) {
453
+ lo = mid + 1 ;
454
+ } else {
455
+ break ;
456
+ }
457
+ mid = (lo + hi ) >>> 1 ;
458
+ }
374
459
375
- for (int i = startLo ; i <= endHi ; ++i ) {
376
- if (ranges [i ].matches (value )) {
377
- collectBucket (sub , doc , subBucketOrdinal (owningBucketOrdinal , i ));
460
+ if (lo > hi ) continue ; // No matching range
461
+
462
+ // binary search the lower bound
463
+ int startLo = lo , startHi = mid ;
464
+ while (startLo <= startHi ) {
465
+ final int startMid = (startLo + startHi ) >>> 1 ;
466
+ if (dimensionValue >= maxTo [startMid ]) {
467
+ startLo = startMid + 1 ;
468
+ } else {
469
+ startHi = startMid - 1 ;
470
+ }
378
471
}
379
- }
380
472
381
- return endHi + 1 ;
473
+ // binary search the upper bound
474
+ int endLo = mid , endHi = hi ;
475
+ while (endLo <= endHi ) {
476
+ final int endMid = (endLo + endHi ) >>> 1 ;
477
+ if (dimensionValue < ranges [endMid ].from ) {
478
+ endHi = endMid - 1 ;
479
+ } else {
480
+ endLo = endMid + 1 ;
481
+ }
482
+ }
483
+
484
+ if (docCountsIterator .advanceExact (starTreeEntry )) {
485
+ long metricValue = docCountsIterator .nextValue ();
486
+ for (int j = startLo ; j <= endHi ; ++j ) {
487
+ if (ranges [j ].matches (dimensionValue )) {
488
+ long bucketOrd = subBucketOrdinal (owningBucketOrd , j );
489
+ collectStarTreeBucket (this , metricValue , bucketOrd , starTreeEntry );
490
+ }
491
+ }
492
+ }
493
+ }
382
494
}
383
495
};
384
496
}
@@ -421,6 +533,61 @@ public InternalAggregation buildEmptyAggregation() {
421
533
return rangeFactory .create (name , buckets , format , keyed , metadata ());
422
534
}
423
535
536
+ class MatchedRange {
537
+ int startLo , endHi ;
538
+
539
+ MatchedRange (RangeAggregator .Range [] ranges , int lowBound , double value ) {
540
+ computeMatchingRange (ranges , lowBound , value );
541
+ }
542
+
543
+ private MatchedRange computeMatchingRange (RangeAggregator .Range [] ranges , int lowBound , double value ) {
544
+ int lo = lowBound , hi = ranges .length - 1 ;
545
+ int mid = (lo + hi ) >>> 1 ;
546
+
547
+ while (lo <= hi ) {
548
+ if (value < ranges [mid ].from ) {
549
+ hi = mid - 1 ;
550
+ } else if (value >= maxTo [mid ]) {
551
+ lo = mid + 1 ;
552
+ } else {
553
+ break ;
554
+ }
555
+ mid = (lo + hi ) >>> 1 ;
556
+ }
557
+ if (lo > hi ) {
558
+ this .startLo = lo ;
559
+ this .endHi = lo - 1 ;
560
+ return this ;
561
+ }
562
+
563
+ // binary search the lower bound
564
+ int startLo = lo , startHi = mid ;
565
+ while (startLo <= startHi ) {
566
+ int startMid = (startLo + startHi ) >>> 1 ;
567
+ if (value >= maxTo [startMid ]) {
568
+ startLo = startMid + 1 ;
569
+ } else {
570
+ startHi = startMid - 1 ;
571
+ }
572
+ }
573
+
574
+ // binary search the upper bound
575
+ int endLo = mid , endHi = hi ;
576
+ while (endLo <= endHi ) {
577
+ int endMid = (endLo + endHi ) >>> 1 ;
578
+ if (value < ranges [endMid ].from ) {
579
+ endHi = endMid - 1 ;
580
+ } else {
581
+ endLo = endMid + 1 ;
582
+ }
583
+ }
584
+
585
+ this .startLo = startLo ;
586
+ this .endHi = endHi ;
587
+ return this ;
588
+ }
589
+ }
590
+
424
591
/**
425
592
* Unmapped range
426
593
*
@@ -456,7 +623,7 @@ public Unmapped(
456
623
public InternalAggregation buildEmptyAggregation () {
457
624
InternalAggregations subAggs = buildEmptySubAggregations ();
458
625
List <org .opensearch .search .aggregations .bucket .range .Range .Bucket > buckets = new ArrayList <>(ranges .length );
459
- for (RangeAggregator . Range range : ranges ) {
626
+ for (Range range : ranges ) {
460
627
buckets .add (factory .createBucket (range .key , range .from , range .to , 0 , subAggs , keyed , format ));
461
628
}
462
629
return factory .create (name , buckets , format , keyed , metadata ());
0 commit comments