@@ -169,7 +169,8 @@ public void testQueryScoreForFaissWithModel() {
169
169
when (modelDao .getMetadata (eq ("modelId" ))).thenReturn (modelMetadata );
170
170
171
171
KNNWeight .initialize (modelDao );
172
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f );
172
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
173
+ final KNNWeight knnWeight = new KNNWeight (query , boost );
173
174
174
175
final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
175
176
final SegmentReader reader = mock (SegmentReader .class );
@@ -214,7 +215,7 @@ public void testQueryScoreForFaissWithModel() {
214
215
final Map <Integer , Float > translatedScores = getTranslatedScores (scoreTranslator );
215
216
for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
216
217
actualDocIds .add (docId );
217
- assertEquals (translatedScores .get (docId ), knnScorer .score (), 0.01f );
218
+ assertEquals (translatedScores .get (docId ) * boost , knnScorer .score (), 0.01f );
218
219
}
219
220
assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
220
221
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
@@ -364,7 +365,8 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
364
365
// Just to make sure that we are not hitting the exact search condition
365
366
when (filterScorer .iterator ()).thenReturn (DocIdSetIterator .all (filterDocIds .length + 1 ));
366
367
367
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f , filterQueryWeight );
368
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
369
+ final KNNWeight knnWeight = new KNNWeight (query , boost , filterQueryWeight );
368
370
369
371
final FSDirectory directory = mock (FSDirectory .class );
370
372
when (reader .directory ()).thenReturn (directory );
@@ -408,7 +410,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
408
410
final Map <Integer , Float > translatedScores = getTranslatedScores (SpaceType .L2 ::scoreTranslation );
409
411
for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
410
412
actualDocIds .add (docId );
411
- assertEquals (translatedScores .get (docId ), knnScorer .score (), 0.01f );
413
+ assertEquals (translatedScores .get (docId ) * boost , knnScorer .score (), 0.01f );
412
414
}
413
415
assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
414
416
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
@@ -433,7 +435,8 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
433
435
when (reader .getLiveDocs ()).thenReturn (liveDocsBits );
434
436
when (liveDocsBits .get (filterDocId )).thenReturn (true );
435
437
436
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f , filterQueryWeight );
438
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
439
+ final KNNWeight knnWeight = new KNNWeight (query , boost , filterQueryWeight );
437
440
final Map <String , String > attributesMap = ImmutableMap .of (KNN_ENGINE , KNNEngine .FAISS .getName (), SPACE_TYPE , SpaceType .L2 .name ());
438
441
final FieldInfos fieldInfos = mock (FieldInfos .class );
439
442
final FieldInfo fieldInfo = mock (FieldInfo .class );
@@ -457,7 +460,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
457
460
final List <Integer > actualDocIds = new ArrayList <>();
458
461
for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
459
462
actualDocIds .add (docId );
460
- assertEquals (EXACT_SEARCH_DOC_ID_TO_SCORES .get (docId ), knnScorer .score (), 0.01f );
463
+ assertEquals (EXACT_SEARCH_DOC_ID_TO_SCORES .get (docId ) * boost , knnScorer .score (), 0.01f );
461
464
}
462
465
assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
463
466
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
@@ -483,7 +486,8 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
483
486
when (reader .getLiveDocs ()).thenReturn (liveDocsBits );
484
487
when (liveDocsBits .get (filterDocId )).thenReturn (true );
485
488
486
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f , filterQueryWeight );
489
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
490
+ final KNNWeight knnWeight = new KNNWeight (query , boost , filterQueryWeight );
487
491
final Map <String , String > attributesMap = ImmutableMap .of (KNN_ENGINE , KNNEngine .FAISS .getName (), SPACE_TYPE , SpaceType .L2 .name ());
488
492
final FieldInfos fieldInfos = mock (FieldInfos .class );
489
493
final FieldInfo fieldInfo = mock (FieldInfo .class );
@@ -507,7 +511,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
507
511
final List <Integer > actualDocIds = new ArrayList <>();
508
512
for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
509
513
actualDocIds .add (docId );
510
- assertEquals (EXACT_SEARCH_DOC_ID_TO_SCORES .get (docId ), knnScorer .score (), 0.01f );
514
+ assertEquals (EXACT_SEARCH_DOC_ID_TO_SCORES .get (docId ) * boost , knnScorer .score (), 0.01f );
511
515
}
512
516
assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
513
517
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
@@ -543,7 +547,8 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
543
547
544
548
final KNNQuery query = new KNNQuery (FIELD_NAME , QUERY_VECTOR , k , INDEX_NAME , FILTER_QUERY , null );
545
549
546
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f , filterQueryWeight );
550
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
551
+ final KNNWeight knnWeight = new KNNWeight (query , boost , filterQueryWeight );
547
552
final Map <String , String > attributesMap = ImmutableMap .of (KNN_ENGINE , KNNEngine .FAISS .getName (), SPACE_TYPE , SpaceType .L2 .name ());
548
553
final FieldInfos fieldInfos = mock (FieldInfos .class );
549
554
final FieldInfo fieldInfo = mock (FieldInfo .class );
@@ -567,7 +572,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
567
572
final List <Integer > actualDocIds = new ArrayList <>();
568
573
for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
569
574
actualDocIds .add (docId );
570
- assertEquals (EXACT_SEARCH_DOC_ID_TO_SCORES .get (docId ), knnScorer .score (), 0.01f );
575
+ assertEquals (EXACT_SEARCH_DOC_ID_TO_SCORES .get (docId ) * boost , knnScorer .score (), 0.01f );
571
576
}
572
577
assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
573
578
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
@@ -631,7 +636,8 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
631
636
when (filterQueryWeight .scorer (leafReaderContext )).thenReturn (filterScorer );
632
637
633
638
final KNNQuery query = new KNNQuery (FIELD_NAME , QUERY_VECTOR , K , INDEX_NAME , FILTER_QUERY , parentFilter );
634
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f , filterQueryWeight );
639
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
640
+ final KNNWeight knnWeight = new KNNWeight (query , boost , filterQueryWeight );
635
641
636
642
// Execute
637
643
final KNNScorer knnScorer = (KNNScorer ) knnWeight .scorer (leafReaderContext );
@@ -642,7 +648,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
642
648
.collect (Collectors .toList ());
643
649
final DocIdSetIterator docIdSetIterator = knnScorer .iterator ();
644
650
assertEquals (1 , docIdSetIterator .nextDoc ());
645
- assertEquals (expectedScores .get (1 ), knnScorer .score (), 0.01f );
651
+ assertEquals (expectedScores .get (1 ) * boost , knnScorer .score (), 0.01f );
646
652
assertEquals (NO_MORE_DOCS , docIdSetIterator .nextDoc ());
647
653
}
648
654
@@ -733,7 +739,8 @@ private void testQueryScore(
733
739
.thenReturn (getKNNQueryResults ());
734
740
735
741
final KNNQuery query = new KNNQuery (FIELD_NAME , QUERY_VECTOR , K , INDEX_NAME , (BitSetProducer ) null );
736
- final KNNWeight knnWeight = new KNNWeight (query , 0.0f );
742
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
743
+ final KNNWeight knnWeight = new KNNWeight (query , boost );
737
744
738
745
final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
739
746
final SegmentReader reader = mock (SegmentReader .class );
@@ -777,7 +784,7 @@ private void testQueryScore(
777
784
final Map <Integer , Float > translatedScores = getTranslatedScores (scoreTranslator );
778
785
for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
779
786
actualDocIds .add (docId );
780
- assertEquals (translatedScores .get (docId ), knnScorer .score (), 0.01f );
787
+ assertEquals (translatedScores .get (docId ) * boost , knnScorer .score (), 0.01f );
781
788
}
782
789
assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
783
790
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
0 commit comments