Skip to content

Commit fcbfef1

Browse files
Fix KNNScorer to apply boost (#1403)
* apply boost
  Signed-off-by: panguixin <panguixin@bytedance.com>
* add change log
  Signed-off-by: panguixin <panguixin@bytedance.com>
---------
Signed-off-by: panguixin <panguixin@bytedance.com>
1 parent 6abec19 commit fcbfef1

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2626
* Allow nested knn field mapping when training a model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
2727
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
2828
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
29+
* Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403)
2930
### Infrastructure
3031
* Upgrade gradle to 8.4 [#1289](https://github.com/opensearch-project/k-NN/pull/1289)
3132
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)

src/main/java/org/opensearch/knn/index/query/KNNScorer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public float score() {
4949
assert docID() != DocIdSetIterator.NO_MORE_DOCS;
5050
Float score = scores.get(docID());
5151
if (score == null) throw new RuntimeException("Null score for the docID: " + docID());
52-
return score;
52+
return score * boost;
5353
}
5454

5555
@Override

src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java

+21-14
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ public void testQueryScoreForFaissWithModel() {
169169
when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata);
170170

171171
KNNWeight.initialize(modelDao);
172-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);
172+
final float boost = (float) randomDoubleBetween(0, 10, true);
173+
final KNNWeight knnWeight = new KNNWeight(query, boost);
173174

174175
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
175176
final SegmentReader reader = mock(SegmentReader.class);
@@ -214,7 +215,7 @@ public void testQueryScoreForFaissWithModel() {
214215
final Map<Integer, Float> translatedScores = getTranslatedScores(scoreTranslator);
215216
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
216217
actualDocIds.add(docId);
217-
assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f);
218+
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
218219
}
219220
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
220221
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
@@ -364,7 +365,8 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
364365
// Just to make sure that we are not hitting the exact search condition
365366
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));
366367

367-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
368+
final float boost = (float) randomDoubleBetween(0, 10, true);
369+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
368370

369371
final FSDirectory directory = mock(FSDirectory.class);
370372
when(reader.directory()).thenReturn(directory);
@@ -408,7 +410,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
408410
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
409411
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
410412
actualDocIds.add(docId);
411-
assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f);
413+
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
412414
}
413415
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
414416
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
@@ -433,7 +435,8 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
433435
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
434436
when(liveDocsBits.get(filterDocId)).thenReturn(true);
435437

436-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
438+
final float boost = (float) randomDoubleBetween(0, 10, true);
439+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
437440
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
438441
final FieldInfos fieldInfos = mock(FieldInfos.class);
439442
final FieldInfo fieldInfo = mock(FieldInfo.class);
@@ -457,7 +460,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
457460
final List<Integer> actualDocIds = new ArrayList<>();
458461
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
459462
actualDocIds.add(docId);
460-
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f);
463+
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
461464
}
462465
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
463466
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
@@ -483,7 +486,8 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
483486
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
484487
when(liveDocsBits.get(filterDocId)).thenReturn(true);
485488

486-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
489+
final float boost = (float) randomDoubleBetween(0, 10, true);
490+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
487491
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
488492
final FieldInfos fieldInfos = mock(FieldInfos.class);
489493
final FieldInfo fieldInfo = mock(FieldInfo.class);
@@ -507,7 +511,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
507511
final List<Integer> actualDocIds = new ArrayList<>();
508512
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
509513
actualDocIds.add(docId);
510-
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f);
514+
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
511515
}
512516
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
513517
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
@@ -543,7 +547,8 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
543547

544548
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null);
545549

546-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
550+
final float boost = (float) randomDoubleBetween(0, 10, true);
551+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
547552
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
548553
final FieldInfos fieldInfos = mock(FieldInfos.class);
549554
final FieldInfo fieldInfo = mock(FieldInfo.class);
@@ -567,7 +572,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
567572
final List<Integer> actualDocIds = new ArrayList<>();
568573
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
569574
actualDocIds.add(docId);
570-
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f);
575+
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
571576
}
572577
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
573578
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
@@ -631,7 +636,8 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
631636
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
632637

633638
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, parentFilter);
634-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
639+
final float boost = (float) randomDoubleBetween(0, 10, true);
640+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
635641

636642
// Execute
637643
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
@@ -642,7 +648,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
642648
.collect(Collectors.toList());
643649
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
644650
assertEquals(1, docIdSetIterator.nextDoc());
645-
assertEquals(expectedScores.get(1), knnScorer.score(), 0.01f);
651+
assertEquals(expectedScores.get(1) * boost, knnScorer.score(), 0.01f);
646652
assertEquals(NO_MORE_DOCS, docIdSetIterator.nextDoc());
647653
}
648654

@@ -733,7 +739,8 @@ private void testQueryScore(
733739
.thenReturn(getKNNQueryResults());
734740

735741
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);
736-
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);
742+
final float boost = (float) randomDoubleBetween(0, 10, true);
743+
final KNNWeight knnWeight = new KNNWeight(query, boost);
737744

738745
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
739746
final SegmentReader reader = mock(SegmentReader.class);
@@ -777,7 +784,7 @@ private void testQueryScore(
777784
final Map<Integer, Float> translatedScores = getTranslatedScores(scoreTranslator);
778785
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
779786
actualDocIds.add(docId);
780-
assertEquals(translatedScores.get(docId), knnScorer.score(), 0.01f);
787+
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
781788
}
782789
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
783790
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));

0 commit comments

Comments
 (0)