Skip to content

Commit 62c8c1f

Browse files
committed
Support hybrid query with neural highlighter
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent de2d9a4 commit 62c8c1f

File tree

7 files changed

+208
-20
lines changed

7 files changed

+208
-20
lines changed

src/main/java/org/opensearch/neuralsearch/highlight/extractor/BooleanQueryTextExtractor.java

+5-7
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ public BooleanQueryTextExtractor(QueryTextExtractorRegistry registry) {
2424

2525
@Override
2626
public String extractQueryText(Query query, String fieldName) {
27-
if (!(query instanceof BooleanQuery booleanQuery)) {
28-
return "";
29-
}
27+
validateQueryType(query, BooleanQuery.class);
28+
BooleanQuery booleanQuery = (BooleanQuery) query;
3029

3130
StringBuilder sb = new StringBuilder();
3231

@@ -38,15 +37,14 @@ public String extractQueryText(Query query, String fieldName) {
3837

3938
try {
4039
String clauseText = registry.extractQueryText(clause.query(), fieldName);
41-
if (!clauseText.isEmpty()) {
42-
if (sb.length() > 0) {
40+
if (clauseText.isEmpty() == false) {
41+
if (sb.isEmpty() == false) {
4342
sb.append(" ");
4443
}
4544
sb.append(clauseText);
4645
}
4746
} catch (IllegalArgumentException e) {
48-
// If a clause has empty query text, just skip it
49-
log.debug("Skipping clause with empty query text: {}", clause.query());
47+
log.warn(String.format("Failed to extract text from clause %s: %s", clause, e.getMessage()), e);
5048
}
5149
}
5250

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.highlight.extractor;
6+
7+
import org.apache.lucene.search.Query;
8+
import org.opensearch.neuralsearch.query.HybridQuery;
9+
10+
import java.util.LinkedHashSet;
11+
import java.util.Set;
12+
13+
/**
14+
* Extractor for hybrid queries that combines text from all sub-queries
15+
*/
16+
public class HybridQueryTextExtractor implements QueryTextExtractor {
17+
18+
private final QueryTextExtractorRegistry registry;
19+
20+
public HybridQueryTextExtractor(QueryTextExtractorRegistry registry) {
21+
this.registry = registry;
22+
}
23+
24+
@Override
25+
public String extractQueryText(Query query, String fieldName) {
26+
validateQueryType(query, HybridQuery.class);
27+
HybridQuery hybridQuery = (HybridQuery) query;
28+
29+
// Create a set to avoid duplicates while maintaining order
30+
Set<String> queryTexts = new LinkedHashSet<>();
31+
32+
// Extract text from each sub-query
33+
for (Query subQuery : hybridQuery.getSubQueries()) {
34+
String extractedText = registry.extractQueryText(subQuery, fieldName);
35+
if (!extractedText.isEmpty()) {
36+
queryTexts.add(extractedText);
37+
}
38+
}
39+
40+
// Join with spaces
41+
return String.join(" ", queryTexts).trim();
42+
}
43+
}

src/main/java/org/opensearch/neuralsearch/highlight/extractor/NeuralQueryTextExtractor.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ public class NeuralQueryTextExtractor implements QueryTextExtractor {
1414

1515
@Override
1616
public String extractQueryText(Query query, String fieldName) {
17-
if (query instanceof NeuralKNNQuery neuralQuery) {
18-
return neuralQuery.getOriginalQueryText();
19-
}
20-
return "";
17+
validateQueryType(query, NeuralKNNQuery.class);
18+
NeuralKNNQuery neuralQuery = (NeuralKNNQuery) query;
19+
return neuralQuery.getOriginalQueryText();
2120
}
2221
}

src/main/java/org/opensearch/neuralsearch/highlight/extractor/QueryTextExtractor.java

+15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,21 @@
1010
* Interface for extracting query text from different query types
1111
*/
1212
public interface QueryTextExtractor {
13+
/**
14+
* Validates if the query is of the expected type
15+
*
16+
* @param query The query to validate
17+
* @param expectedType The expected query type
18+
* @throws IllegalArgumentException if the query is not of the expected type
19+
*/
20+
default void validateQueryType(Query query, Class<? extends Query> expectedType) {
21+
if (expectedType.isInstance(query) == false) {
22+
throw new IllegalArgumentException(
23+
String.format("Expected %s but got %s", expectedType.getSimpleName(), query.getClass().getSimpleName())
24+
);
25+
}
26+
}
27+
1328
/**
1429
* Extracts text from a query for highlighting
1530
*

src/main/java/org/opensearch/neuralsearch/highlight/extractor/QueryTextExtractorRegistry.java

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.apache.lucene.search.Query;
99
import org.apache.lucene.search.TermQuery;
1010
import org.opensearch.neuralsearch.query.NeuralKNNQuery;
11+
import org.opensearch.neuralsearch.query.HybridQuery;
1112

1213
import lombok.extern.log4j.Log4j2;
1314

@@ -35,6 +36,7 @@ public QueryTextExtractorRegistry() {
3536
private void initialize() {
3637
register(NeuralKNNQuery.class, new NeuralQueryTextExtractor());
3738
register(TermQuery.class, new TermQueryTextExtractor());
39+
register(HybridQuery.class, new HybridQueryTextExtractor(this));
3840

3941
// BooleanQueryTextExtractor needs a reference to this registry
4042
// so we need to register it after creating the registry instance

src/main/java/org/opensearch/neuralsearch/highlight/extractor/TermQueryTextExtractor.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ public class TermQueryTextExtractor implements QueryTextExtractor {
1515

1616
@Override
1717
public String extractQueryText(Query query, String fieldName) {
18-
if (!(query instanceof TermQuery termQuery)) {
19-
return "";
20-
}
18+
validateQueryType(query, TermQuery.class);
19+
TermQuery termQuery = (TermQuery) query;
2120

2221
Term term = termQuery.getTerm();
2322
// Only include terms from the field we're highlighting

src/test/java/org/opensearch/neuralsearch/highlight/QueryTextExtractorTests.java

+138-6
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,21 @@
77
import org.apache.lucene.index.Term;
88
import org.apache.lucene.search.BooleanClause;
99
import org.apache.lucene.search.BooleanQuery;
10+
import org.apache.lucene.search.Query;
1011
import org.apache.lucene.search.TermQuery;
1112
import org.opensearch.neuralsearch.highlight.extractor.BooleanQueryTextExtractor;
1213
import org.opensearch.neuralsearch.highlight.extractor.NeuralQueryTextExtractor;
1314
import org.opensearch.neuralsearch.highlight.extractor.QueryTextExtractorRegistry;
1415
import org.opensearch.neuralsearch.highlight.extractor.TermQueryTextExtractor;
16+
import org.opensearch.neuralsearch.highlight.extractor.HybridQueryTextExtractor;
1517
import org.opensearch.neuralsearch.query.NeuralKNNQuery;
18+
import org.opensearch.neuralsearch.query.HybridQuery;
19+
import org.opensearch.neuralsearch.query.HybridQueryContext;
1620
import org.opensearch.test.OpenSearchTestCase;
1721

22+
import java.util.ArrayList;
23+
import java.util.List;
24+
1825
import static org.mockito.Mockito.mock;
1926
import static org.mockito.Mockito.when;
2027

@@ -48,8 +55,14 @@ public void testTermQueryExtractor() {
4855

4956
// Test with non-TermQuery
5057
BooleanQuery booleanQuery = new BooleanQuery.Builder().build();
51-
result = extractor.extractQueryText(booleanQuery, "content");
52-
assertEquals("Should return empty string for non-TermQuery", "", result);
58+
IllegalArgumentException exception = expectThrows(
59+
IllegalArgumentException.class,
60+
() -> extractor.extractQueryText(booleanQuery, "content")
61+
);
62+
assertTrue(
63+
"Should throw IllegalArgumentException with correct message",
64+
exception.getMessage().contains("Expected TermQuery but got BooleanQuery")
65+
);
5366
}
5467

5568
/**
@@ -67,8 +80,14 @@ public void testNeuralQueryExtractor() {
6780

6881
// Test with non-NeuralKNNQuery
6982
TermQuery termQuery = new TermQuery(new Term("content", "term"));
70-
result = extractor.extractQueryText(termQuery, "content");
71-
assertEquals("Should return empty string for non-NeuralKNNQuery", "", result);
83+
IllegalArgumentException exception = expectThrows(
84+
IllegalArgumentException.class,
85+
() -> extractor.extractQueryText(termQuery, "content")
86+
);
87+
assertTrue(
88+
"Should throw IllegalArgumentException with correct message",
89+
exception.getMessage().contains("Expected NeuralKNNQuery but got TermQuery")
90+
);
7291
}
7392

7493
/**
@@ -91,8 +110,14 @@ public void testBooleanQueryExtractor() {
91110

92111
// Test with non-BooleanQuery
93112
TermQuery termQuery = new TermQuery(new Term("content", "term"));
94-
result = extractor.extractQueryText(termQuery, "content");
95-
assertEquals("Should return empty string for non-BooleanQuery", "", result);
113+
IllegalArgumentException exception = expectThrows(
114+
IllegalArgumentException.class,
115+
() -> extractor.extractQueryText(termQuery, "content")
116+
);
117+
assertTrue(
118+
"Should throw IllegalArgumentException with correct message",
119+
exception.getMessage().contains("Expected BooleanQuery but got TermQuery")
120+
);
96121

97122
// Test with empty clauses
98123
BooleanQuery emptyQuery = new BooleanQuery.Builder().build();
@@ -160,4 +185,111 @@ public void visit(org.apache.lucene.search.QueryVisitor visitor) {
160185
result = registry.extractQueryText(customQuery, "content");
161186
assertEquals("Should use custom extractor", "custom-extracted", result);
162187
}
188+
189+
/**
190+
* Tests the HybridQueryTextExtractor
191+
*/
192+
public void testHybridQueryExtractor() {
193+
// Create a hybrid query with multiple sub-queries
194+
List<Query> subQueries = new ArrayList<>();
195+
196+
// Add a term query
197+
TermQuery termQuery = new TermQuery(new Term("content", "machine"));
198+
subQueries.add(termQuery);
199+
200+
// Add a boolean query (match query)
201+
BooleanQuery.Builder boolBuilder = new BooleanQuery.Builder();
202+
boolBuilder.add(new TermQuery(new Term("content", "learning")), BooleanClause.Occur.MUST);
203+
subQueries.add(boolBuilder.build());
204+
205+
// Add a neural query
206+
NeuralKNNQuery neuralQuery = mock(NeuralKNNQuery.class);
207+
when(neuralQuery.getOriginalQueryText()).thenReturn("AI systems that can learn");
208+
subQueries.add(neuralQuery);
209+
210+
// Create the hybrid query
211+
HybridQuery hybridQuery = new HybridQuery(subQueries, HybridQueryContext.builder().build());
212+
213+
// Test extraction
214+
String result = registry.extractQueryText(hybridQuery, "content");
215+
assertEquals("Should combine all query texts correctly", "machine learning AI systems that can learn", result);
216+
217+
// Test with non-HybridQuery
218+
TermQuery nonHybridQuery = new TermQuery(new Term("content", "term"));
219+
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> {
220+
HybridQueryTextExtractor extractor = new HybridQueryTextExtractor(registry);
221+
extractor.extractQueryText(nonHybridQuery, "content");
222+
});
223+
assertTrue(
224+
"Should throw IllegalArgumentException with correct message",
225+
exception.getMessage().contains("Expected HybridQuery but got TermQuery")
226+
);
227+
}
228+
229+
/**
230+
* Tests the HybridQueryTextExtractor with empty or invalid sub-queries
231+
*/
232+
public void testHybridQueryExtractorWithEmptyQueries() {
233+
// Create a hybrid query with no valid text
234+
List<Query> subQueries = new ArrayList<>();
235+
236+
// Add a term query with non-matching field
237+
TermQuery termQuery = new TermQuery(new Term("title", "machine"));
238+
subQueries.add(termQuery);
239+
240+
// Create the hybrid query
241+
HybridQuery hybridQuery = new HybridQuery(subQueries, HybridQueryContext.builder().build());
242+
243+
// Test extraction
244+
String result = registry.extractQueryText(hybridQuery, "content");
245+
assertEquals("Should return empty string for no valid text", "", result);
246+
}
247+
248+
/**
249+
* Tests the HybridQueryTextExtractor with duplicate texts
250+
*/
251+
public void testHybridQueryExtractorWithDuplicates() {
252+
List<Query> subQueries = new ArrayList<>();
253+
254+
// Add two term queries with the same text
255+
subQueries.add(new TermQuery(new Term("content", "duplicate")));
256+
subQueries.add(new TermQuery(new Term("content", "duplicate")));
257+
258+
// Add a neural query with overlapping text
259+
NeuralKNNQuery neuralQuery = mock(NeuralKNNQuery.class);
260+
when(neuralQuery.getOriginalQueryText()).thenReturn("duplicate text");
261+
subQueries.add(neuralQuery);
262+
263+
// Create the hybrid query
264+
HybridQuery hybridQuery = new HybridQuery(subQueries, HybridQueryContext.builder().build());
265+
266+
// Test extraction
267+
String result = registry.extractQueryText(hybridQuery, "content");
268+
assertEquals("Should deduplicate terms", "duplicate duplicate text", result);
269+
}
270+
271+
/**
272+
* Tests the HybridQueryTextExtractor with nested queries
273+
*/
274+
public void testHybridQueryExtractorWithNestedQueries() {
275+
List<Query> subQueries = new ArrayList<>();
276+
277+
// Create a boolean query with multiple terms
278+
BooleanQuery.Builder boolBuilder = new BooleanQuery.Builder();
279+
boolBuilder.add(new TermQuery(new Term("content", "nested")), BooleanClause.Occur.MUST);
280+
boolBuilder.add(new TermQuery(new Term("content", "terms")), BooleanClause.Occur.MUST);
281+
subQueries.add(boolBuilder.build());
282+
283+
// Add a neural query
284+
NeuralKNNQuery neuralQuery = mock(NeuralKNNQuery.class);
285+
when(neuralQuery.getOriginalQueryText()).thenReturn("neural text");
286+
subQueries.add(neuralQuery);
287+
288+
// Create the hybrid query
289+
HybridQuery hybridQuery = new HybridQuery(subQueries, HybridQueryContext.builder().build());
290+
291+
// Test extraction
292+
String result = registry.extractQueryText(hybridQuery, "content");
293+
assertEquals("Should handle nested queries correctly", "nested terms neural text", result);
294+
}
163295
}

0 commit comments

Comments
 (0)