Skip to content

Commit 806042c

Browse files
authored
Optimize the max score tracking in the Query Phase of Hybrid Search (#765)
1 parent afd1215 commit 806042c

File tree

4 files changed

+35
-14
lines changed

4 files changed

+35
-14
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
2020
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
2121
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
22+
- Optimize max score calculation in the Query Phase of Hybrid Search ([#765](https://github.com/opensearch-project/neural-search/pull/765))
2223
### Bug Fixes
2324
- Total hit count fix in Hybrid Query ([#756](https://github.com/opensearch-project/neural-search/pull/756))
2425
### Infrastructure

src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ public class HybridTopScoreDocCollector implements Collector {
3939
private int[] collectedHitsPerSubQuery;
4040
private final int numOfHits;
4141
private PriorityQueue<ScoreDoc>[] compoundScores;
42+
@Getter
43+
private float maxScore = 0.0f;
4244

4345
public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
4446
numOfHits = numHits;
@@ -115,6 +117,7 @@ public void collect(int doc) throws IOException {
115117
collectedHitsPerSubQuery[i]++;
116118
PriorityQueue<ScoreDoc> pq = compoundScores[i];
117119
ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
120+
maxScore = Math.max(currentDoc.score, maxScore);
118121
// this way we're inserting into heap and do nothing else unless we reach the capacity
119122
// after that we pull out the lowest score element on each insert
120123
pq.insertWithOverflow(currentDoc);

src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java

+1-14
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
146146
getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
147147
topDocs
148148
);
149-
float maxScore = getMaxScore(topDocs);
150-
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
149+
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
151150
return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
152151
}
153152
throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
@@ -212,18 +211,6 @@ private TotalHits getTotalHits(
212211
return new TotalHits(maxTotalHits, relation);
213212
}
214213

215-
private float getMaxScore(final List<TopDocs> topDocs) {
216-
if (topDocs.isEmpty()) {
217-
return 0.0f;
218-
} else {
219-
return topDocs.stream()
220-
.map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
221-
.map(scoreDoc -> scoreDoc.score)
222-
.max(Float::compare)
223-
.get();
224-
}
225-
}
226-
227214
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
228215
return sortAndFormats == null ? null : sortAndFormats.formats;
229216
}

src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java

+30
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,36 @@ public void testTotalHits_whenResultSizeIsLessThenDefaultSize_thenSuccessful() {
206206
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
207207
}
208208

209+
@SneakyThrows
210+
public void testMaxScoreCalculation_whenMaxScoreIsTrackedAtCollectorLevel_thenSuccessful() {
211+
initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
212+
TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
213+
TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
214+
TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
215+
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
216+
boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);
217+
218+
HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
219+
hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1);
220+
hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
221+
Map<String, Object> searchResponseAsMap = search(
222+
TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
223+
hybridQueryBuilderNeuralThenTerm,
224+
null,
225+
10,
226+
null
227+
);
228+
229+
double maxScore = getMaxScore(searchResponseAsMap).get();
230+
List<Map<String, Object>> hits = getNestedHits(searchResponseAsMap);
231+
double maxScoreExpected = 0.0;
232+
for (Map<String, Object> hit : hits) {
233+
double score = (double) hit.get("_score");
234+
maxScoreExpected = Math.max(score, maxScoreExpected);
235+
}
236+
assertEquals(maxScoreExpected, maxScore, 0.0000001);
237+
}
238+
209239
/**
210240
* Tests complex query with multiple nested sub-queries, where some sub-queries are same
211241
* {

0 commit comments

Comments
 (0)