forked from opensearch-project/neural-search
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHybridQueryWeight.java
207 lines (191 loc) · 7.74 KB
/
HybridQueryWeight.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Matches;
import org.apache.lucene.search.MatchesUtils;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
import org.opensearch.neuralsearch.executors.HybridQueryExecutorCollector;
import org.opensearch.neuralsearch.executors.HybridQueryScoreSupplierCollectorManager;
import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;
/**
 * Calculates query weights and builds query scorers for a hybrid query.
 */
public final class HybridQueryWeight extends Weight {
    // The Weights for our sub-queries, one per entry of hybridQuery.getSubQueries(), created in iteration order
    @Getter(AccessLevel.PACKAGE)
    private final List<Weight> weights;

    private final ScoreMode scoreMode;

    /**
     * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights.
     *
     * @param hybridQuery the hybrid query; one {@link Weight} is created for each of its sub-queries
     * @param searcher    searcher used to create the sub-query weights
     * @param scoreMode   score mode passed to every sub-query weight and kept for building scorers
     * @param boost       boost passed to every sub-query weight
     * @throws IOException if creating any sub-query weight fails; propagated as-is instead of being wrapped
     */
    public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        super(hybridQuery);
        // A plain loop (rather than a stream with a try/catch lambda) lets the checked IOException declared by
        // this constructor reach the caller directly instead of being hidden inside a RuntimeException.
        List<Weight> subQueryWeights = new ArrayList<>();
        for (Query subQuery : hybridQuery.getSubQueries()) {
            subQueryWeights.add(searcher.createWeight(subQuery, scoreMode, boost));
        }
        this.weights = subQueryWeights;
        this.scoreMode = scoreMode;
    }

    /**
     * Returns Matches for a specific document, or null if the document does not match the parent query.
     *
     * @param context the reader's context to create the {@link Matches} for
     * @param doc the document's id relative to the given context's reader
     * @return the non-null sub-query matches combined with {@link MatchesUtils#fromSubMatches(List)}
     * @throws IOException if any sub-query weight fails to compute its matches; propagated as-is
     */
    @Override
    public Matches matches(LeafReaderContext context, int doc) throws IOException {
        List<Matches> subMatches = new ArrayList<>(weights.size());
        for (Weight weight : weights) {
            Matches subMatch = weight.matches(context, doc);
            // a sub-query that reports null for this document contributes nothing
            if (Objects.nonNull(subMatch)) {
                subMatches.add(subMatch);
            }
        }
        return MatchesUtils.fromSubMatches(subMatches);
    }

    /**
     * Returns {@link HybridScorerSupplier} which contains list of {@link ScorerSupplier} from its
     * sub queries. Obtaining the score supplier of each sub query is submitted as a separate task to the
     * hybrid query executor (whether the tasks actually run in parallel depends on that executor), and the
     * collected results are merged into a single {@link HybridScorerSupplier}.
     *
     * @param context the leaf reader context to build the scorer supplier for
     * @return the hybrid scorer supplier, or null when no sub query provides a scorer supplier for this leaf
     * @throws IOException propagated from the hybrid query executor
     */
    @Override
    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
        HybridQueryScoreSupplierCollectorManager manager = new HybridQueryScoreSupplierCollectorManager(context);
        List<Callable<Void>> scoreSupplierTasks = new ArrayList<>(weights.size());
        List<HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier>> collectors = new ArrayList<>(weights.size());
        for (Weight weight : weights) {
            HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier> collector = manager.newCollector();
            collectors.add(collector);
            scoreSupplierTasks.add(() -> addScoreSupplier(weight, collector));
        }
        HybridQueryExecutor.getExecutor().invokeAll(scoreSupplierTasks);
        final List<ScorerSupplier> scorerSuppliers = manager.mergeScoreSuppliers(collectors);
        // A null return tells Lucene that nothing in this leaf can match. Besides the empty case, that also holds
        // when every merged entry is null: HybridScorerSupplier skips null entries, so it would match nothing.
        if (scorerSuppliers.isEmpty() || scorerSuppliers.stream().allMatch(Objects::isNull)) {
            return null;
        }
        return new HybridScorerSupplier(scorerSuppliers, this, scoreMode);
    }

    /**
     * Task body used by {@link #scorerSupplier(LeafReaderContext)}: asks {@code weight} for its
     * {@link ScorerSupplier} for the leaf context handed in by {@code collector} and lets the collector store it.
     *
     * @param weight    the sub query weight to obtain a scorer supplier from
     * @param collector the collector that receives the result
     * @return always null (the {@code Void} result required by {@link Callable})
     */
    private Void addScoreSupplier(Weight weight, HybridQueryExecutorCollector<LeafReaderContext, ScorerSupplier> collector) {
        collector.collect(leafReaderContext -> {
            try {
                return weight.scorerSupplier(leafReaderContext);
            } catch (IOException e) {
                // the collect callback cannot throw the checked IOException; wrap it while keeping the cause and type
                throw new UncheckedIOException(e);
            }
        });
        return null;
    }

    /**
     * Check if weight object can be cached.
     *
     * @param ctx the leaf reader context to check
     * @return false if there are more than MAX_NUMBER_OF_SUB_QUERIES sub query weights; otherwise true only
     *         when every sub query weight is cacheable for {@code ctx}
     */
    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
        if (weights.size() > MAX_NUMBER_OF_SUB_QUERIES) {
            // this situation should never happen, but if it does, such a query will not be cached
            return false;
        }
        return weights.stream().allMatch(w -> w.isCacheable(ctx));
    }

    /**
     * Returns a shard level {@link Explanation} that describes how the weight and scoring are calculated.
     * @param context the readers context to create the {@link Explanation} for.
     * @param doc the document's id relative to the given context's reader
     * @return shard level {@link Explanation}, each sub-query explanation is a single nested element. When at
     *         least one sub-query matches, the value is the maximum score among matching sub-queries (never below 0)
     * @throws IOException if a sub-query weight fails to explain the document
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        boolean match = false;
        double max = 0;
        // every sub-query explanation is nested, matching or not, in sub-query weight order
        List<Explanation> subExplanations = new ArrayList<>(weights.size());
        for (Weight weight : weights) {
            Explanation subExplanation = weight.explain(context, doc);
            if (subExplanation.isMatch()) {
                match = true;
                max = Math.max(max, subExplanation.getValue().doubleValue());
            }
            subExplanations.add(subExplanation);
        }
        if (match) {
            return Explanation.match(max, "combined score of:", subExplanations);
        }
        return Explanation.noMatch("no matching clause", subExplanations);
    }

    /**
     * {@link ScorerSupplier} combining the suppliers of the individual sub queries. Entries of
     * {@code scorerSuppliers} may be null; every method below checks for that.
     */
    @RequiredArgsConstructor
    static class HybridScorerSupplier extends ScorerSupplier {
        // lazily computed sum of the non-null sub supplier costs; -1 means "not computed yet"
        private long cost = -1;
        // NOTE: Lombok builds the constructor from these final fields in declaration order
        // (scorerSuppliers, weight, scoreMode); reordering them changes the constructor signature.
        private final List<ScorerSupplier> scorerSuppliers;
        private final Weight weight;
        private final ScoreMode scoreMode;

        /**
         * Builds a {@link HybridQueryScorer} with one sub scorer per entry of {@code scorerSuppliers}; a null
         * supplier yields a null sub scorer at the same position so positions stay aligned.
         */
        @Override
        public Scorer get(long leadCost) throws IOException {
            List<Scorer> subScorers = new ArrayList<>(scorerSuppliers.size());
            for (ScorerSupplier ss : scorerSuppliers) {
                subScorers.add(Objects.nonNull(ss) ? ss.get(leadCost) : null);
            }
            return new HybridQueryScorer(weight, subScorers, scoreMode);
        }

        /**
         * Returns the sum of the costs of the non-null sub suppliers, computed on the first call and cached.
         */
        @Override
        public long cost() {
            if (cost == -1) {
                long total = 0;
                for (ScorerSupplier ss : scorerSuppliers) {
                    if (Objects.nonNull(ss)) {
                        total += ss.cost();
                    }
                }
                cost = total;
            }
            return cost;
        }

        @Override
        public void setTopLevelScoringClause() throws IOException {
            for (ScorerSupplier ss : scorerSuppliers) {
                // sub scorers need to be able to skip too as calls to setMinCompetitiveScore get
                // propagated
                if (Objects.nonNull(ss)) {
                    ss.setTopLevelScoringClause();
                }
            }
        }
    }
}