Skip to content

Commit 12398c4

Browse files
weiwang118Wei Wang
and
Wei Wang
authored
Add check to directly use ANN Search when filters match all docs. (opensearch-project#2320) (opensearch-project#2367)
* Add check to directly use ANN Search when filters match all docs. * Fix failed tests and rebase on main branch * pass filterbitset as null and add integ tests. --------- (cherry picked from commit 6f5313f) Signed-off-by: Wei Wang <weiwsde@gmail.com> Co-authored-by: Wei Wang <weiwangv@amazon.com>
1 parent 72c6a1e commit 12398c4

File tree

5 files changed

+159
-6
lines changed

5 files changed

+159
-6
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2222
- Introduced a writing layer in native engines that relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
2323
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
2424
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
25+
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
2526
### Bug Fixes
2627
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
2728
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]

src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ public enum FilterIdsSelectorType {
7878
public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException {
7979
long[] filterIds;
8080
FilterIdsSelector.FilterIdsSelectorType filterType;
81-
if (filterIdsBitSet instanceof FixedBitSet) {
81+
if (filterIdsBitSet == null) {
82+
filterIds = null;
83+
filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP;
84+
} else if (filterIdsBitSet instanceof FixedBitSet) {
8285
/**
8386
* When filterIds is dense filter, using fixed bitset
8487
*/

src/main/java/org/opensearch/knn/index/query/KNNWeight.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
129129
*/
130130
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
131131
final BitSet filterBitSet = getFilteredDocsBitSet(context);
132+
final int maxDoc = context.reader().maxDoc();
132133
int cardinality = filterBitSet.cardinality();
133134
// We don't need to go to JNI layer if no documents are found which satisfy the filters
134135
// We should take a deeper look at where this condition should be placed. For now I feel this is a good
@@ -145,7 +146,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
145146
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
146147
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
147148
}
148-
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
149+
150+
/*
151+
* If filters match all docs in this segment, then null should be passed as filterBitSet
152+
* so that it will not do a bitset lookup in the bottom search layer.
153+
*/
154+
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
155+
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);
156+
149157
// See whether we have to perform exact search based on approx search results
150158
// This is required if there are no native engine files or if approximate search returned
151159
// results less than K, though we have more than k filtered docs

src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java

+88-4
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
671671
when(liveDocsBits.length()).thenReturn(1000);
672672

673673
final SegmentReader reader = mockSegmentReader();
674-
when(reader.maxDoc()).thenReturn(filterDocIds.length);
674+
when(reader.maxDoc()).thenReturn(filterDocIds.length + 1);
675675
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
676676

677677
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
@@ -758,6 +758,88 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
758758
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
759759
}
760760

761+
@SneakyThrows
762+
public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() {
763+
// Given
764+
int k = 3;
765+
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
766+
FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length);
767+
for (int docId : filterDocIds) {
768+
filterBitSet.set(docId);
769+
}
770+
771+
jniServiceMockedStatic.when(
772+
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any())
773+
).thenReturn(getFilteredKNNQueryResults());
774+
775+
final Bits liveDocsBits = mock(Bits.class);
776+
for (int filterDocId : filterDocIds) {
777+
when(liveDocsBits.get(filterDocId)).thenReturn(true);
778+
}
779+
when(liveDocsBits.length()).thenReturn(1000);
780+
781+
final SegmentReader reader = mockSegmentReader();
782+
when(reader.maxDoc()).thenReturn(filterDocIds.length);
783+
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
784+
785+
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
786+
when(leafReaderContext.reader()).thenReturn(reader);
787+
788+
final KNNQuery query = KNNQuery.builder()
789+
.field(FIELD_NAME)
790+
.queryVector(QUERY_VECTOR)
791+
.k(k)
792+
.indexName(INDEX_NAME)
793+
.filterQuery(FILTER_QUERY)
794+
.methodParameters(HNSW_METHOD_PARAMETERS)
795+
.build();
796+
797+
final Weight filterQueryWeight = mock(Weight.class);
798+
final Scorer filterScorer = mock(Scorer.class);
799+
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
800+
// Just to make sure that we are not hitting the exact search condition
801+
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));
802+
803+
final float boost = (float) randomDoubleBetween(0, 10, true);
804+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
805+
806+
final FieldInfos fieldInfos = mock(FieldInfos.class);
807+
final FieldInfo fieldInfo = mock(FieldInfo.class);
808+
final Map<String, String> attributesMap = ImmutableMap.of(
809+
KNN_ENGINE,
810+
KNNEngine.FAISS.getName(),
811+
SPACE_TYPE,
812+
SpaceType.L2.getValue()
813+
);
814+
815+
when(reader.getFieldInfos()).thenReturn(fieldInfos);
816+
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
817+
when(fieldInfo.attributes()).thenReturn(attributesMap);
818+
819+
// When
820+
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
821+
822+
// Then
823+
assertNotNull(knnScorer);
824+
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
825+
assertNotNull(docIdSetIterator);
826+
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());
827+
828+
jniServiceMockedStatic.verify(
829+
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()),
830+
times(1)
831+
);
832+
833+
final List<Integer> actualDocIds = new ArrayList<>();
834+
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
835+
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
836+
actualDocIds.add(docId);
837+
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
838+
}
839+
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
840+
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
841+
}
842+
761843
private SegmentReader mockSegmentReader() {
762844
Path path = mock(Path.class);
763845

@@ -815,7 +897,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
815897
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
816898
// scorer will return 1 document
817899
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
818-
when(reader.maxDoc()).thenReturn(1);
900+
when(reader.maxDoc()).thenReturn(2);
819901
final Bits liveDocsBits = mock(Bits.class);
820902
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
821903
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -891,6 +973,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
891973
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
892974
final SegmentReader reader = mock(SegmentReader.class);
893975
when(leafReaderContext.reader()).thenReturn(reader);
976+
when(reader.maxDoc()).thenReturn(1);
894977

895978
final FSDirectory directory = mock(FSDirectory.class);
896979
when(reader.directory()).thenReturn(directory);
@@ -968,7 +1051,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
9681051
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
9691052
// scorer will return 1 document
9701053
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
971-
when(reader.maxDoc()).thenReturn(1);
1054+
when(reader.maxDoc()).thenReturn(2);
9721055
final Bits liveDocsBits = mock(Bits.class);
9731056
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
9741057
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -1168,6 +1251,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
11681251
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
11691252
final SegmentReader reader = mock(SegmentReader.class);
11701253
when(leafReaderContext.reader()).thenReturn(reader);
1254+
when(reader.maxDoc()).thenReturn(1);
11711255

11721256
final Weight filterQueryWeight = mock(Weight.class);
11731257
final Scorer filterScorer = mock(Scorer.class);
@@ -1202,7 +1286,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
12021286
// We will have 0, 1 for filteredIds and 2 will be the parent id for both of them
12031287
final Scorer filterScorer = mock(Scorer.class);
12041288
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2));
1205-
when(reader.maxDoc()).thenReturn(2);
1289+
when(reader.maxDoc()).thenReturn(3);
12061290

12071291
// Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result
12081292
final List<float[]> vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f });
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.integ;
7+
8+
import com.google.common.collect.ImmutableMap;
9+
import lombok.SneakyThrows;
10+
import lombok.extern.log4j.Log4j2;
11+
import org.apache.http.util.EntityUtils;
12+
import org.opensearch.client.Response;
13+
import org.opensearch.common.settings.Settings;
14+
import org.opensearch.knn.KNNJsonQueryBuilder;
15+
import org.opensearch.knn.KNNRestTestCase;
16+
import org.opensearch.knn.index.KNNSettings;
17+
import java.util.List;
18+
19+
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
20+
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
21+
22+
@Log4j2
23+
public class FilteredSearchANNSearchIT extends KNNRestTestCase {
24+
@SneakyThrows
25+
public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() {
26+
String filterFieldName = "color";
27+
final int expectResultSize = randomIntBetween(1, 3);
28+
final String filterValue = "red";
29+
createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), createKnnIndexMapping(FIELD_NAME, 3, METHOD_HNSW, FAISS_NAME));
30+
31+
// ingest 4 vector docs into the index with the same field {"color": "red"}
32+
for (int i = 0; i < 4; i++) {
33+
addKnnDocWithAttributes(String.valueOf(i), new float[] { i, i, i }, ImmutableMap.of(filterFieldName, filterValue));
34+
}
35+
36+
refreshIndex(INDEX_NAME);
37+
forceMergeKnnIndex(INDEX_NAME);
38+
39+
updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0));
40+
41+
Float[] queryVector = { 3f, 3f, 3f };
42+
// All docs in one segment will match the filter value
43+
String query = KNNJsonQueryBuilder.builder()
44+
.fieldName(FIELD_NAME)
45+
.vector(queryVector)
46+
.k(expectResultSize)
47+
.filterFieldName(filterFieldName)
48+
.filterValue(filterValue)
49+
.build()
50+
.getQueryString();
51+
Response response = searchKNNIndex(INDEX_NAME, query, expectResultSize);
52+
String entity = EntityUtils.toString(response.getEntity());
53+
List<String> docIds = parseIds(entity);
54+
assertEquals(expectResultSize, docIds.size());
55+
assertEquals(expectResultSize, parseTotalSearchHits(entity));
56+
}
57+
}

0 commit comments

Comments
 (0)