
Commit 8c743ec

Fix explain exception in hybrid queries with partial subquery matches (#1123)
* Fixed exception for explain in hybrid query when partial match in subqueries

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
1 parent 3dbdcba commit 8c743ec

14 files changed: +893 -88 lines

CHANGELOG.md (+1)

@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
 - Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
 - Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
+- Fix explain exception in hybrid queries with partial subquery matches ([#1123](https://github.com/opensearch-project/neural-search/pull/1123))
 - Handle pagination_depth when from =0 and removes default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
 ### Infrastructure
 ### Documentation

src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java (+37 -9)

@@ -6,6 +6,8 @@

 import lombok.AllArgsConstructor;
 import lombok.Getter;
+import lombok.extern.log4j.Log4j2;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.lucene.search.Explanation;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
@@ -21,6 +23,7 @@
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;

@@ -32,6 +35,7 @@
  */
 @Getter
 @AllArgsConstructor
+@Log4j2
 public class ExplanationResponseProcessor implements SearchResponseProcessor {

     public static final String TYPE = "hybrid_score_explanation";
@@ -99,16 +103,40 @@ public SearchResponse processResponse(
             ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
             ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
             // Create normalized explanations for each detail
-            Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
-            for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
-                normalizedExplanation[i] = Explanation.match(
-                    // normalized score
-                    normalizationExplanation.getScoreDetails().get(i).getKey(),
-                    // description of normalized score
-                    normalizationExplanation.getScoreDetails().get(i).getValue(),
-                    // shard level details
-                    queryLevelExplanation.getDetails()[i]
+            if (normalizationExplanation.getScoreDetails().size() != queryLevelExplanation.getDetails().length) {
+                log.error(
+                    String.format(
+                        Locale.ROOT,
+                        "length of query level explanations %d must match length of explanations after normalization %d",
+                        queryLevelExplanation.getDetails().length,
+                        normalizationExplanation.getScoreDetails().size()
+                    )
                 );
+                throw new IllegalStateException("mismatch in number of query level explanations and normalization explanations");
+            }
+            List<Explanation> normalizedExplanation = new ArrayList<>(queryLevelExplanation.getDetails().length);
+            int normalizationExplanationIndex = 0;
+            for (Explanation queryExplanation : queryLevelExplanation.getDetails()) {
+                // adding only explanations where this hit has matched
+                if (Float.compare(queryExplanation.getValue().floatValue(), 0.0f) > 0) {
+                    Pair<Float, String> normalizedScoreDetails = normalizationExplanation.getScoreDetails()
+                        .get(normalizationExplanationIndex);
+                    if (Objects.isNull(normalizedScoreDetails)) {
+                        throw new IllegalStateException("normalized score details must not be null");
+                    }
+                    normalizedExplanation.add(
+                        Explanation.match(
+                            // normalized score
+                            normalizedScoreDetails.getKey(),
+                            // description of normalized score
+                            normalizedScoreDetails.getValue(),
+                            // shard level details
+                            queryExplanation
+                        )
+                    );
+                }
+                // we increment index in all cases, scores in query explanation can be 0.0
+                normalizationExplanationIndex++;
             }
             // Create and set final explanation combining all components
             Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
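
The core of this fix is index bookkeeping: query-level explanations keep one slot per subquery (zero-valued when the hit missed that subquery), and the processor skips zero-score slots while still advancing the shared index. Below is a minimal, self-contained sketch of that pairing rule — not the plugin code; the `NormalizedDetail` record is a hypothetical stand-in for the `Pair<Float, String>` entries in `ExplanationDetails`:

```java
import java.util.ArrayList;
import java.util.List;

public class ExplanationAlignmentSketch {
    // hypothetical stand-in for the Pair<Float, String> entries in ExplanationDetails
    record NormalizedDetail(float score, String description) {}

    static List<String> pairMatchedExplanations(float[] queryLevelScores, List<NormalizedDetail> normalized) {
        if (normalized.size() != queryLevelScores.length) {
            // the fixed processor logs this mismatch and throws IllegalStateException
            throw new IllegalStateException("mismatch in number of explanations");
        }
        List<String> result = new ArrayList<>();
        for (int i = 0; i < queryLevelScores.length; i++) {
            // keep only subqueries this hit actually matched; the index advances either way
            if (Float.compare(queryLevelScores[i], 0.0f) > 0) {
                result.add(normalized.get(i).score() + " (" + normalized.get(i).description() + ")");
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // hit matched subqueries 0 and 2 but not 1 -- the case that used to throw
        float[] queryLevelScores = { 0.8f, 0.0f, 0.5f };
        List<NormalizedDetail> normalized = List.of(
            new NormalizedDetail(1.0f, "min-max normalized"),
            new NormalizedDetail(0.0f, "min-max normalized"),
            new NormalizedDetail(0.6f, "min-max normalized")
        );
        pairMatchedExplanations(queryLevelScores, normalized).forEach(System.out::println);
    }
}
```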

src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java (+11 -4)

@@ -75,12 +75,19 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs>
                 continue;
             }
             List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
-            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
-                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
+            int numberOfSubQueries = topDocsPerSubQuery.size();
+            for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
+                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                     DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
-                    float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
-                    normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
+                    float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(subQueryIndex));
+                    ScoreNormalizationUtil.setNormalizedScore(
+                        normalizedScores,
+                        docIdAtSearchShard,
+                        subQueryIndex,
+                        numberOfSubQueries,
+                        normalizedScore
+                    );
                     scoreDoc.score = normalizedScore;
                 }
             }
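
For context on the math being fed into `setNormalizedScore` here: L2 normalization divides each raw score by the L2 norm of all scores the subquery produced. A rough sketch, assuming the standard formula `score / sqrt(Σ score²)` per subquery; the plugin's actual `normalizeSingleScore` and norm computation may treat edge cases differently:

```java
// A rough sketch of L2 normalization, assuming the standard formula;
// not the plugin's exact implementation.
public class L2Sketch {
    static float l2Norm(float[] subQueryScores) {
        double sumOfSquares = 0.0;
        for (float score : subQueryScores) {
            sumOfSquares += (double) score * score; // accumulate in double to limit rounding error
        }
        return (float) Math.sqrt(sumOfSquares);
    }

    static float normalizeSingleScore(float score, float norm) {
        return norm == 0.0f ? 0.0f : score / norm; // an all-zero subquery keeps zero scores
    }

    public static void main(String[] args) {
        float[] scores = { 3.0f, 4.0f };
        float norm = l2Norm(scores); // 5.0
        System.out.println(normalizeSingleScore(scores[0], norm)); // 0.6
    }
}
```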

src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java (+12 -6)

@@ -4,7 +4,6 @@
  */
 package org.opensearch.neuralsearch.processor.normalization;

-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -92,16 +91,23 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
                 continue;
             }
             List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
-            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
-                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
+            int numberOfSubQueries = topDocsPerSubQuery.size();
+            for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
+                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                     DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                     float normalizedScore = normalizeSingleScore(
                         scoreDoc.score,
-                        minMaxScores.getMinScoresPerSubquery()[j],
-                        minMaxScores.getMaxScoresPerSubquery()[j]
+                        minMaxScores.getMinScoresPerSubquery()[subQueryIndex],
+                        minMaxScores.getMaxScoresPerSubquery()[subQueryIndex]
+                    );
+                    ScoreNormalizationUtil.setNormalizedScore(
+                        normalizedScores,
+                        docIdAtSearchShard,
+                        subQueryIndex,
+                        numberOfSubQueries,
+                        normalizedScore
                     );
-                    normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
                     scoreDoc.score = normalizedScore;
                 }
             }
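
As with L2, only the bookkeeping changes here; the scoring itself is untouched. A rough sketch of min-max normalization, assuming the standard formula `(score - min) / (max - min)`; the plugin's handling of the degenerate `max == min` case may differ from the fallback shown:

```java
// A rough sketch assuming the standard min-max formula; the degenerate-range
// fallback below is an assumption, not read from this diff.
public class MinMaxSketch {
    static float normalizeSingleScore(float score, float minScore, float maxScore) {
        if (Float.compare(maxScore, minScore) == 0) {
            return 1.0f; // fallback assumption for a single-value score range
        }
        return (score - minScore) / (maxScore - minScore);
    }

    public static void main(String[] args) {
        System.out.println(normalizeSingleScore(7.0f, 2.0f, 12.0f)); // 0.5
    }
}
```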

src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java (+39 -17)

@@ -6,22 +6,24 @@

 import java.math.BigDecimal;
 import java.math.RoundingMode;
-import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Locale;
 import java.util.Set;
-import java.util.function.BiConsumer;
-import java.util.stream.IntStream;

 import org.apache.commons.lang3.Range;
 import org.apache.commons.lang3.math.NumberUtils;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.opensearch.common.TriConsumer;
 import org.opensearch.neuralsearch.processor.CompoundTopDocs;

 import lombok.ToString;
 import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
+import org.opensearch.neuralsearch.processor.SearchShard;
 import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
 import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
 import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
@@ -65,7 +67,7 @@ public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNo
     public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
         final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
-            processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
+            processTopDocs(compoundQueryTopDocs, (docId, score, subQueryIndex) -> {});
        }
     }

@@ -79,31 +81,51 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs>
         Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            int numberOfSubQueries = topDocsPerSubQuery.size();
             processTopDocs(
                 compoundQueryTopDocs,
-                (docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score)
+                (docId, score, subQueryIndex) -> ScoreNormalizationUtil.setNormalizedScore(
+                    normalizedScores,
+                    docId,
+                    subQueryIndex,
+                    numberOfSubQueries,
+                    score
+                )
             );
         }

         return getDocIdAtQueryForNormalization(normalizedScores, this);
     }

-    private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
+    private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor) {
         if (Objects.isNull(compoundQueryTopDocs)) {
             return;
         }

-        compoundQueryTopDocs.getTopDocs().forEach(topDocs -> {
-            IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> {
-                float normalizedScore = calculateNormalizedScore(position);
-                DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
-                    topDocs.scoreDocs[position].doc,
-                    compoundQueryTopDocs.getSearchShard()
-                );
-                scoreProcessor.accept(docIdAtSearchShard, normalizedScore);
-                topDocs.scoreDocs[position].score = normalizedScore;
-            });
-        });
+        List<TopDocs> topDocsList = compoundQueryTopDocs.getTopDocs();
+        SearchShard searchShard = compoundQueryTopDocs.getSearchShard();
+
+        for (int topDocsIndex = 0; topDocsIndex < topDocsList.size(); topDocsIndex++) {
+            processTopDocsEntry(topDocsList.get(topDocsIndex), searchShard, topDocsIndex, scoreProcessor);
+        }
+    }
+
+    private void processTopDocsEntry(
+        TopDocs topDocs,
+        SearchShard searchShard,
+        int topDocsIndex,
+        TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor
+    ) {
+        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+            float normalizedScore = calculateNormalizedScore(Arrays.asList(topDocs.scoreDocs).indexOf(scoreDoc));
+            DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, searchShard);
+            scoreProcessor.apply(docIdAtSearchShard, normalizedScore, topDocsIndex);
+            scoreDoc.score = normalizedScore;
+        }
     }

     private float calculateNormalizedScore(int position) {
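
`calculateNormalizedScore` itself is unchanged by this commit; for context, reciprocal rank fusion assigns `1 / (rankConstant + rank)` to the document at a given rank. A sketch under two stated assumptions: the rank constant of 60 is the common RRF default rather than a value read from this diff, and the plugin performs the division with `BigDecimal` plus a rounding mode where plain `float` arithmetic is used here for brevity:

```java
// A sketch of reciprocal rank fusion scoring; rank constant 60 is an assumption.
public class RrfSketch {
    static float rrfScore(int position, int rankConstant) {
        // position is 0-based, so the 1-based rank is position + 1
        return 1.0f / (rankConstant + position + 1);
    }

    public static void main(String[] args) {
        System.out.println(rrfScore(0, 60)); // ~0.0164 for the top-ranked doc
        System.out.println(rrfScore(1, 60)); // ~0.0161 -- scores decay slowly with rank
    }
}
```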

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java (+28)

@@ -5,7 +5,9 @@
 package org.opensearch.neuralsearch.processor.normalization;

 import lombok.extern.log4j.Log4j2;
+import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;

+import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -54,4 +56,30 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
             }
         }
     }
+
+    /**
+     * Sets a normalized score for a specific document at a specific subquery index
+     *
+     * @param normalizedScores map of document IDs to their list of scores
+     * @param docIdAtSearchShard document ID
+     * @param subQueryIndex index of the subquery
+     * @param normalizedScore normalized score to set
+     */
+    public static void setNormalizedScore(
+        Map<DocIdAtSearchShard, List<Float>> normalizedScores,
+        DocIdAtSearchShard docIdAtSearchShard,
+        int subQueryIndex,
+        int numberOfSubQueries,
+        float normalizedScore
+    ) {
+        List<Float> scores = normalizedScores.get(docIdAtSearchShard);
+        if (Objects.isNull(scores)) {
+            scores = new ArrayList<>(numberOfSubQueries);
+            for (int i = 0; i < numberOfSubQueries; i++) {
+                scores.add(0.0f);
+            }
+            normalizedScores.put(docIdAtSearchShard, scores);
+        }
+        scores.set(subQueryIndex, normalizedScore);
+    }
 }
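
This new helper is what makes partial matches safe: instead of appending scores (the old `computeIfAbsent(...).add(...)` pattern, which silently shifted scores left whenever a subquery produced no hit for a document), it writes into a fixed-size, zero-filled list keyed by subquery index. A usage sketch — the `SearchShard("my-index", 0, "node-1")` constructor arguments and doc id 42 are assumptions for illustration:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;

public class SetNormalizedScoreSketch {
    public static void main(String[] args) {
        Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
        // hypothetical doc 42 on shard 0 that matched subqueries 0 and 2 of 3
        DocIdAtSearchShard doc = new DocIdAtSearchShard(42, new SearchShard("my-index", 0, "node-1"));
        ScoreNormalizationUtil.setNormalizedScore(normalizedScores, doc, 0, 3, 0.8f);
        ScoreNormalizationUtil.setNormalizedScore(normalizedScores, doc, 2, 3, 0.5f);
        System.out.println(normalizedScores.get(doc)); // [0.8, 0.0, 0.5]
        // the old computeIfAbsent(...).add(...) pattern would have yielded [0.8, 0.5],
        // shifting subquery 2's score into subquery 1's slot
    }
}
```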

src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java (+9 -3)

@@ -11,6 +11,8 @@
 import java.util.concurrent.Callable;
 import java.util.stream.Collectors;

+import lombok.AccessLevel;
+import lombok.Getter;
 import lombok.RequiredArgsConstructor;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Explanation;
@@ -33,6 +35,7 @@
 public final class HybridQueryWeight extends Weight {

     // The Weights for our subqueries, in 1-1 correspondence
+    @Getter(AccessLevel.PACKAGE)
     private final List<Weight> weights;

     private final ScoreMode scoreMode;
@@ -157,10 +160,13 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
             if (e.isMatch()) {
                 match = true;
                 double score = e.getValue().doubleValue();
-                subsOnMatch.add(e);
                 max = Math.max(max, score);
-            } else if (!match) {
-                subsOnNoMatch.add(e);
+                subsOnMatch.add(e);
+            } else {
+                if (!match) {
+                    subsOnNoMatch.add(e);
+                }
+                subsOnMatch.add(e);
             }
         }
         if (match) {
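
The `explain()` change makes the top-level hybrid explanation carry one sub-explanation per subquery even for subqueries that did not match, which is what lets `ExplanationResponseProcessor` pair details with normalized scores by index. A minimal sketch using Lucene's `Explanation` API; the description strings are hypothetical, not the plugin's actual wording:

```java
import org.apache.lucene.search.Explanation;

public class HybridExplainSketch {
    public static void main(String[] args) {
        // one subquery matched, one did not (descriptions are illustrative only)
        Explanation keywordPart = Explanation.match(2.3f, "weight(text:hello in 0)");
        Explanation knnPart = Explanation.noMatch("no matching clause");

        // after the fix, non-matching sub-explanations are included too, so
        // getDetails().length always equals the number of subqueries
        Explanation hybrid = Explanation.match(2.3f, "max of:", keywordPart, knnPart);
        System.out.println(hybrid.getDetails().length);      // 2
        System.out.println(hybrid.getDetails()[1].getValue()); // 0, flagging the unmatched subquery
    }
}
```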
