@@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
671
671
when (liveDocsBits .length ()).thenReturn (1000 );
672
672
673
673
final SegmentReader reader = mockSegmentReader ();
674
- when (reader .maxDoc ()).thenReturn (filterDocIds .length );
674
+ when (reader .maxDoc ()).thenReturn (filterDocIds .length + 1 );
675
675
when (reader .getLiveDocs ()).thenReturn (liveDocsBits );
676
676
677
677
final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
@@ -758,6 +758,88 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
758
758
assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
759
759
}
760
760
761
+ @ SneakyThrows
762
+ public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess () {
763
+ // Given
764
+ int k = 3 ;
765
+ final int [] filterDocIds = new int [] { 0 , 1 , 2 , 3 , 4 , 5 };
766
+ FixedBitSet filterBitSet = new FixedBitSet (filterDocIds .length );
767
+ for (int docId : filterDocIds ) {
768
+ filterBitSet .set (docId );
769
+ }
770
+
771
+ jniServiceMockedStatic .when (
772
+ () -> JNIService .queryIndex (anyLong (), eq (QUERY_VECTOR ), eq (k ), eq (HNSW_METHOD_PARAMETERS ), any (), eq (null ), anyInt (), any ())
773
+ ).thenReturn (getFilteredKNNQueryResults ());
774
+
775
+ final Bits liveDocsBits = mock (Bits .class );
776
+ for (int filterDocId : filterDocIds ) {
777
+ when (liveDocsBits .get (filterDocId )).thenReturn (true );
778
+ }
779
+ when (liveDocsBits .length ()).thenReturn (1000 );
780
+
781
+ final SegmentReader reader = mockSegmentReader ();
782
+ when (reader .maxDoc ()).thenReturn (filterDocIds .length );
783
+ when (reader .getLiveDocs ()).thenReturn (liveDocsBits );
784
+
785
+ final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
786
+ when (leafReaderContext .reader ()).thenReturn (reader );
787
+
788
+ final KNNQuery query = KNNQuery .builder ()
789
+ .field (FIELD_NAME )
790
+ .queryVector (QUERY_VECTOR )
791
+ .k (k )
792
+ .indexName (INDEX_NAME )
793
+ .filterQuery (FILTER_QUERY )
794
+ .methodParameters (HNSW_METHOD_PARAMETERS )
795
+ .build ();
796
+
797
+ final Weight filterQueryWeight = mock (Weight .class );
798
+ final Scorer filterScorer = mock (Scorer .class );
799
+ when (filterQueryWeight .scorer (leafReaderContext )).thenReturn (filterScorer );
800
+ // Just to make sure that we are not hitting the exact search condition
801
+ when (filterScorer .iterator ()).thenReturn (DocIdSetIterator .all (filterDocIds .length + 1 ));
802
+
803
+ final float boost = (float ) randomDoubleBetween (0 , 10 , true );
804
+ final KNNWeight knnWeight = new KNNWeight (query , boost , filterQueryWeight );
805
+
806
+ final FieldInfos fieldInfos = mock (FieldInfos .class );
807
+ final FieldInfo fieldInfo = mock (FieldInfo .class );
808
+ final Map <String , String > attributesMap = ImmutableMap .of (
809
+ KNN_ENGINE ,
810
+ KNNEngine .FAISS .getName (),
811
+ SPACE_TYPE ,
812
+ SpaceType .L2 .getValue ()
813
+ );
814
+
815
+ when (reader .getFieldInfos ()).thenReturn (fieldInfos );
816
+ when (fieldInfos .fieldInfo (any ())).thenReturn (fieldInfo );
817
+ when (fieldInfo .attributes ()).thenReturn (attributesMap );
818
+
819
+ // When
820
+ final KNNScorer knnScorer = (KNNScorer ) knnWeight .scorer (leafReaderContext );
821
+
822
+ // Then
823
+ assertNotNull (knnScorer );
824
+ final DocIdSetIterator docIdSetIterator = knnScorer .iterator ();
825
+ assertNotNull (docIdSetIterator );
826
+ assertEquals (FILTERED_DOC_ID_TO_SCORES .size (), docIdSetIterator .cost ());
827
+
828
+ jniServiceMockedStatic .verify (
829
+ () -> JNIService .queryIndex (anyLong (), eq (QUERY_VECTOR ), eq (k ), eq (HNSW_METHOD_PARAMETERS ), any (), any (), anyInt (), any ()),
830
+ times (1 )
831
+ );
832
+
833
+ final List <Integer > actualDocIds = new ArrayList <>();
834
+ final Map <Integer , Float > translatedScores = getTranslatedScores (SpaceType .L2 ::scoreTranslation );
835
+ for (int docId = docIdSetIterator .nextDoc (); docId != NO_MORE_DOCS ; docId = docIdSetIterator .nextDoc ()) {
836
+ actualDocIds .add (docId );
837
+ assertEquals (translatedScores .get (docId ) * boost , knnScorer .score (), 0.01f );
838
+ }
839
+ assertEquals (docIdSetIterator .cost (), actualDocIds .size ());
840
+ assertTrue (Comparators .isInOrder (actualDocIds , Comparator .naturalOrder ()));
841
+ }
842
+
761
843
private SegmentReader mockSegmentReader () {
762
844
Path path = mock (Path .class );
763
845
@@ -815,7 +897,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
815
897
when (filterQueryWeight .scorer (leafReaderContext )).thenReturn (filterScorer );
816
898
// scorer will return 2 documents
817
899
when (filterScorer .iterator ()).thenReturn (DocIdSetIterator .all (1 ));
818
- when (reader .maxDoc ()).thenReturn (1 );
900
+ when (reader .maxDoc ()).thenReturn (2 );
819
901
final Bits liveDocsBits = mock (Bits .class );
820
902
when (reader .getLiveDocs ()).thenReturn (liveDocsBits );
821
903
when (liveDocsBits .get (filterDocId )).thenReturn (true );
@@ -891,6 +973,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
891
973
final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
892
974
final SegmentReader reader = mock (SegmentReader .class );
893
975
when (leafReaderContext .reader ()).thenReturn (reader );
976
+ when (reader .maxDoc ()).thenReturn (1 );
894
977
895
978
final FSDirectory directory = mock (FSDirectory .class );
896
979
when (reader .directory ()).thenReturn (directory );
@@ -968,7 +1051,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
968
1051
when (filterQueryWeight .scorer (leafReaderContext )).thenReturn (filterScorer );
969
1052
// scorer will return 2 documents
970
1053
when (filterScorer .iterator ()).thenReturn (DocIdSetIterator .all (1 ));
971
- when (reader .maxDoc ()).thenReturn (1 );
1054
+ when (reader .maxDoc ()).thenReturn (2 );
972
1055
final Bits liveDocsBits = mock (Bits .class );
973
1056
when (reader .getLiveDocs ()).thenReturn (liveDocsBits );
974
1057
when (liveDocsBits .get (filterDocId )).thenReturn (true );
@@ -1168,6 +1251,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
1168
1251
final LeafReaderContext leafReaderContext = mock (LeafReaderContext .class );
1169
1252
final SegmentReader reader = mock (SegmentReader .class );
1170
1253
when (leafReaderContext .reader ()).thenReturn (reader );
1254
+ when (reader .maxDoc ()).thenReturn (1 );
1171
1255
1172
1256
final Weight filterQueryWeight = mock (Weight .class );
1173
1257
final Scorer filterScorer = mock (Scorer .class );
@@ -1202,7 +1286,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
1202
1286
// We will have 0, 1 for filteredIds and 2 will be the parent id for both of them
1203
1287
final Scorer filterScorer = mock (Scorer .class );
1204
1288
when (filterScorer .iterator ()).thenReturn (DocIdSetIterator .all (2 ));
1205
- when (reader .maxDoc ()).thenReturn (2 );
1289
+ when (reader .maxDoc ()).thenReturn (3 );
1206
1290
1207
1291
// Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result
1208
1292
final List <float []> vectors = Arrays .asList (new float [] { 0.1f , 0.3f }, new float [] { 1.9f , 2.5f });
0 commit comments