Skip to content

Commit d375b86

Browse files
committed
Support hybrid query with neural highlighter
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent de2d9a4 commit d375b86

8 files changed

+220
-32
lines changed

src/main/java/org/opensearch/neuralsearch/highlight/extractor/BooleanQueryTextExtractor.java

+5-7
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ public BooleanQueryTextExtractor(QueryTextExtractorRegistry registry) {
2424

2525
@Override
2626
public String extractQueryText(Query query, String fieldName) {
27-
if (!(query instanceof BooleanQuery booleanQuery)) {
28-
return "";
29-
}
27+
validateQueryType(query, BooleanQuery.class);
28+
BooleanQuery booleanQuery = (BooleanQuery) query;
3029

3130
StringBuilder sb = new StringBuilder();
3231

@@ -38,15 +37,14 @@ public String extractQueryText(Query query, String fieldName) {
3837

3938
try {
4039
String clauseText = registry.extractQueryText(clause.query(), fieldName);
41-
if (!clauseText.isEmpty()) {
42-
if (sb.length() > 0) {
40+
if (clauseText.isEmpty() == false) {
41+
if (sb.isEmpty() == false) {
4342
sb.append(" ");
4443
}
4544
sb.append(clauseText);
4645
}
4746
} catch (IllegalArgumentException e) {
48-
// If a clause has empty query text, just skip it
49-
log.debug("Skipping clause with empty query text: {}", clause.query());
47+
log.warn(String.format("Failed to extract text from clause %s: %s", clause, e.getMessage()), e);
5048
}
5149
}
5250

src/main/java/org/opensearch/neuralsearch/highlight/extractor/HybridQueryTextExtractor.java

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.highlight.extractor;
6+
7+
import org.apache.lucene.search.Query;
8+
import org.opensearch.neuralsearch.query.HybridQuery;
9+
10+
import java.util.LinkedHashSet;
11+
import java.util.Set;
12+
13+
/**
14+
* Extractor for hybrid queries that combines text from all sub-queries
15+
*/
16+
public class HybridQueryTextExtractor implements QueryTextExtractor {
17+
18+
private final QueryTextExtractorRegistry registry;
19+
20+
public HybridQueryTextExtractor(QueryTextExtractorRegistry registry) {
21+
this.registry = registry;
22+
}
23+
24+
@Override
25+
public String extractQueryText(Query query, String fieldName) {
26+
validateQueryType(query, HybridQuery.class);
27+
HybridQuery hybridQuery = (HybridQuery) query;
28+
29+
// Create a set to avoid duplicates while maintaining order
30+
Set<String> queryTexts = new LinkedHashSet<>();
31+
32+
// Extract text from each sub-query
33+
for (Query subQuery : hybridQuery.getSubQueries()) {
34+
String extractedText = registry.extractQueryText(subQuery, fieldName);
35+
if (!extractedText.isEmpty()) {
36+
queryTexts.add(extractedText);
37+
}
38+
}
39+
40+
// Join with spaces
41+
return String.join(" ", queryTexts).trim();
42+
}
43+
}

src/main/java/org/opensearch/neuralsearch/highlight/extractor/NeuralQueryTextExtractor.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ public class NeuralQueryTextExtractor implements QueryTextExtractor {
1414

1515
@Override
1616
public String extractQueryText(Query query, String fieldName) {
17-
if (query instanceof NeuralKNNQuery neuralQuery) {
18-
return neuralQuery.getOriginalQueryText();
19-
}
20-
return "";
17+
validateQueryType(query, NeuralKNNQuery.class);
18+
NeuralKNNQuery neuralQuery = (NeuralKNNQuery) query;
19+
return neuralQuery.getOriginalQueryText();
2120
}
2221
}

src/main/java/org/opensearch/neuralsearch/highlight/extractor/QueryTextExtractor.java

+15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,21 @@
1010
* Interface for extracting query text from different query types
1111
*/
1212
public interface QueryTextExtractor {
13+
/**
14+
* Validates if the query is of the expected type
15+
*
16+
* @param query The query to validate
17+
* @param expectedType The expected query type
18+
* @throws IllegalArgumentException if the query is not of the expected type
19+
*/
20+
default void validateQueryType(Query query, Class<? extends Query> expectedType) {
21+
if (expectedType.isInstance(query) == false) {
22+
throw new IllegalArgumentException(
23+
String.format("Expected %s but got %s", expectedType.getSimpleName(), query.getClass().getSimpleName())
24+
);
25+
}
26+
}
27+
1328
/**
1429
* Extracts text from a query for highlighting
1530
*

src/main/java/org/opensearch/neuralsearch/highlight/extractor/QueryTextExtractorRegistry.java

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.apache.lucene.search.Query;
99
import org.apache.lucene.search.TermQuery;
1010
import org.opensearch.neuralsearch.query.NeuralKNNQuery;
11+
import org.opensearch.neuralsearch.query.HybridQuery;
1112

1213
import lombok.extern.log4j.Log4j2;
1314

@@ -35,6 +36,7 @@ public QueryTextExtractorRegistry() {
3536
private void initialize() {
3637
register(NeuralKNNQuery.class, new NeuralQueryTextExtractor());
3738
register(TermQuery.class, new TermQueryTextExtractor());
39+
register(HybridQuery.class, new HybridQueryTextExtractor(this));
3840

3941
// BooleanQueryTextExtractor needs a reference to this registry
4042
// so we need to register it after creating the registry instance

src/main/java/org/opensearch/neuralsearch/highlight/extractor/TermQueryTextExtractor.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ public class TermQueryTextExtractor implements QueryTextExtractor {
1515

1616
@Override
1717
public String extractQueryText(Query query, String fieldName) {
18-
if (!(query instanceof TermQuery termQuery)) {
19-
return "";
20-
}
18+
validateQueryType(query, TermQuery.class);
19+
TermQuery termQuery = (TermQuery) query;
2120

2221
Term term = termQuery.getTerm();
2322
// Only include terms from the field we're highlighting

src/test/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterManagerTests.java → src/test/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterEngineTests.java

+12-12
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,21 @@
2727
/**
2828
* Tests for the NeuralHighlighterEngine class
2929
*/
30-
public class NeuralHighlighterManagerTests extends OpenSearchTestCase {
30+
public class NeuralHighlighterEngineTests extends OpenSearchTestCase {
3131

3232
private static final String TEST_FIELD = "test_field";
3333
private static final String MODEL_ID = "test_model_id";
3434
private static final String TEST_CONTENT = "This is a test content. For highlighting purposes. With multiple sentences.";
3535
private static final String TEST_QUERY = "test content";
3636

37-
private NeuralHighlighterEngine manager;
37+
private NeuralHighlighterEngine highlighterEngine;
3838
private MLCommonsClientAccessor mlCommonsClientAccessor;
3939

4040
@Override
4141
public void setUp() throws Exception {
4242
super.setUp();
4343
mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
44-
manager = new NeuralHighlighterEngine(mlCommonsClientAccessor);
44+
highlighterEngine = new NeuralHighlighterEngine(mlCommonsClientAccessor);
4545

4646
// Setup default mock behavior
4747
setupDefaultMockBehavior();
@@ -75,22 +75,22 @@ public void testGetModelId() {
7575
Map<String, Object> options = new HashMap<>();
7676
options.put("model_id", MODEL_ID);
7777

78-
String modelId = manager.getModelId(options);
78+
String modelId = highlighterEngine.getModelId(options);
7979
assertEquals("Should extract model ID correctly", MODEL_ID, modelId);
8080
}
8181

8282
public void testGetModelIdMissing() {
8383
Map<String, Object> options = new HashMap<>();
8484

85-
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> manager.getModelId(options));
85+
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> highlighterEngine.getModelId(options));
8686
assertNotNull(exception);
8787
assertTrue(exception.getMessage().contains("Missing required option: model_id"));
8888
}
8989

9090
public void testExtractOriginalQuery() {
9191
// Test with TermQuery
9292
TermQuery termQuery = new TermQuery(new Term(TEST_FIELD, "term"));
93-
String queryText = manager.extractOriginalQuery(termQuery, TEST_FIELD);
93+
String queryText = highlighterEngine.extractOriginalQuery(termQuery, TEST_FIELD);
9494
assertEquals("Should extract term text", "term", queryText);
9595

9696
// Test with BooleanQuery
@@ -99,19 +99,19 @@ public void testExtractOriginalQuery() {
9999
builder.add(new TermQuery(new Term(TEST_FIELD, "term2")), BooleanClause.Occur.MUST);
100100
BooleanQuery booleanQuery = builder.build();
101101

102-
queryText = manager.extractOriginalQuery(booleanQuery, TEST_FIELD);
102+
queryText = highlighterEngine.extractOriginalQuery(booleanQuery, TEST_FIELD);
103103
assertEquals("Should extract combined terms", "term1 term2", queryText);
104104

105105
// Test with NeuralKNNQuery
106106
NeuralKNNQuery neuralQuery = mock(NeuralKNNQuery.class);
107107
when(neuralQuery.getOriginalQueryText()).thenReturn("neural query");
108108

109-
queryText = manager.extractOriginalQuery(neuralQuery, TEST_FIELD);
109+
queryText = highlighterEngine.extractOriginalQuery(neuralQuery, TEST_FIELD);
110110
assertEquals("Should extract neural query text", "neural query", queryText);
111111
}
112112

113113
public void testGetHighlightedSentences() {
114-
String result = manager.getHighlightedSentences(MODEL_ID, TEST_QUERY, TEST_CONTENT);
114+
String result = highlighterEngine.getHighlightedSentences(MODEL_ID, TEST_QUERY, TEST_CONTENT);
115115

116116
assertNotNull("Should return highlighted text", result);
117117
assertTrue("Should contain highlighting tags", result.contains("<em>") && result.contains("</em>"));
@@ -139,7 +139,7 @@ public void testApplyHighlighting() {
139139
highlights.add(resultMap);
140140

141141
String text = "This is a test string";
142-
String result = manager.applyHighlighting(text, highlights);
142+
String result = highlighterEngine.applyHighlighting(text, highlights);
143143

144144
assertEquals("Should apply highlights correctly", "<em>This</em> is <em>a te</em>st string", result);
145145
}
@@ -165,7 +165,7 @@ public void testApplyHighlightingWithOverlaps() {
165165
highlights.add(resultMap);
166166

167167
String text = "This is a test string";
168-
String result = manager.applyHighlighting(text, highlights);
168+
String result = highlighterEngine.applyHighlighting(text, highlights);
169169

170170
// Should merge the overlapping highlights
171171
assertEquals("Should merge overlapping highlights", "<em>This is a </em>test string", result);
@@ -205,7 +205,7 @@ public void testApplyHighlightingWithInvalidPositions() {
205205
highlights.add(resultMap);
206206

207207
String text = "This is a test string";
208-
String result = manager.applyHighlighting(text, highlights);
208+
String result = highlighterEngine.applyHighlighting(text, highlights);
209209

210210
// Should only apply the valid highlight
211211
assertEquals("Should only apply valid highlights", "<em>This</em> is a test string", result);

0 commit comments

Comments
 (0)