Skip to content

Commit de2d9a4

Browse files
committed
Update to use PlainActionFuture to fetch model result and use NeuralHighlighterEngine class name
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent 3ea3562 commit de2d9a4

File tree

6 files changed

+33
-178
lines changed

6 files changed

+33
-178
lines changed

src/main/java/org/opensearch/neuralsearch/highlight/NeuralHighlighter.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
public class NeuralHighlighter implements Highlighter {
2020
public static final String NAME = "neural";
2121

22-
private NeuralHighlighterManager neuralHighlighterManager;
22+
private NeuralHighlighterEngine neuralHighlighterEngine;
2323

2424
public void initialize(MLCommonsClientAccessor mlClient) {
25-
if (neuralHighlighterManager != null) {
26-
throw new IllegalStateException("NeuralHighlighter has already been initialized. Multiple initializations are not permitted.");
25+
if (neuralHighlighterEngine != null) {
26+
throw new IllegalStateException(
27+
"NeuralHighlighterEngine has already been initialized. Multiple initializations are not permitted."
28+
);
2729
}
28-
this.neuralHighlighterManager = new NeuralHighlighterManager(mlClient);
30+
this.neuralHighlighterEngine = new NeuralHighlighterEngine(mlClient);
2931
}
3032

3133
@Override
@@ -41,26 +43,26 @@ public boolean canHighlight(MappedFieldType fieldType) {
4143
*/
4244
@Override
4345
public HighlightField highlight(FieldHighlightContext fieldContext) {
44-
if (neuralHighlighterManager == null) {
46+
if (neuralHighlighterEngine == null) {
4547
throw new IllegalStateException("NeuralHighlighter has not been initialized");
4648
}
4749

4850
// Extract field text
49-
String fieldText = neuralHighlighterManager.getFieldText(fieldContext);
51+
String fieldText = neuralHighlighterEngine.getFieldText(fieldContext);
5052

5153
// Get model ID
52-
String modelId = neuralHighlighterManager.getModelId(fieldContext.field.fieldOptions().options());
54+
String modelId = neuralHighlighterEngine.getModelId(fieldContext.field.fieldOptions().options());
5355

5456
// Try to extract query text
55-
String originalQueryText = neuralHighlighterManager.extractOriginalQuery(fieldContext.query, fieldContext.fieldName);
57+
String originalQueryText = neuralHighlighterEngine.extractOriginalQuery(fieldContext.query, fieldContext.fieldName);
5658

5759
if (originalQueryText == null || originalQueryText.isEmpty()) {
5860
log.warn("No query text found for field {}", fieldContext.fieldName);
5961
return null;
6062
}
6163

6264
// Get highlighted text - allow any exceptions from this call to propagate
63-
String highlightedText = neuralHighlighterManager.getHighlightedSentences(modelId, originalQueryText, fieldText);
65+
String highlightedText = neuralHighlighterEngine.getHighlightedSentences(modelId, originalQueryText, fieldText);
6466

6567
// Create highlight field
6668
Text[] fragments = new Text[] { new Text(highlightedText) };

src/main/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterManager.java src/main/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterEngine.java

+10-27
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,25 @@
88
import org.apache.commons.lang.StringUtils;
99
import org.apache.lucene.search.Query;
1010
import org.opensearch.OpenSearchException;
11-
import org.opensearch.core.action.ActionListener;
1211
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
1312
import org.opensearch.neuralsearch.processor.SentenceHighlightingRequest;
1413
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
1514
import org.opensearch.neuralsearch.highlight.extractor.QueryTextExtractorRegistry;
15+
import org.opensearch.action.support.PlainActionFuture;
16+
import lombok.NonNull;
1617

1718
import java.util.ArrayList;
1819
import java.util.Comparator;
1920
import java.util.List;
2021
import java.util.Locale;
2122
import java.util.Map;
2223
import java.util.Objects;
23-
import java.util.concurrent.CountDownLatch;
24-
import java.util.concurrent.atomic.AtomicReference;
2524

2625
/**
2726
* Manager class for neural highlighting operations that handles the core highlighting logic
2827
*/
2928
@Log4j2
30-
public class NeuralHighlighterManager {
29+
public class NeuralHighlighterEngine {
3130
private static final String MODEL_ID_FIELD = "model_id";
3231
private static final String DEFAULT_PRE_TAG = "<em>";
3332
private static final String DEFAULT_POST_TAG = "</em>";
@@ -38,8 +37,8 @@ public class NeuralHighlighterManager {
3837
private final MLCommonsClientAccessor mlCommonsClient;
3938
private final QueryTextExtractorRegistry queryTextExtractorRegistry;
4039

41-
public NeuralHighlighterManager(MLCommonsClientAccessor mlCommonsClient) {
42-
this.mlCommonsClient = Objects.requireNonNull(mlCommonsClient, "ML Commons client cannot be null");
40+
public NeuralHighlighterEngine(@NonNull MLCommonsClientAccessor mlCommonsClient) {
41+
this.mlCommonsClient = mlCommonsClient;
4342
this.queryTextExtractorRegistry = new QueryTextExtractorRegistry();
4443
}
4544

@@ -121,40 +120,24 @@ public String getHighlightedSentences(String modelId, String question, String co
121120
* @return The highlighting results
122121
*/
123122
public List<Map<String, Object>> fetchModelResults(String modelId, String question, String context) {
124-
125-
CountDownLatch latch = new CountDownLatch(1);
126-
AtomicReference<List<Map<String, Object>>> resultRef = new AtomicReference<>();
127-
AtomicReference<Exception> exceptionRef = new AtomicReference<>();
123+
PlainActionFuture<List<Map<String, Object>>> future = PlainActionFuture.newFuture();
128124

129125
SentenceHighlightingRequest request = SentenceHighlightingRequest.builder()
130126
.modelId(modelId)
131127
.question(question)
132128
.context(context)
133129
.build();
134130

135-
mlCommonsClient.inferenceSentenceHighlighting(request, ActionListener.wrap(result -> {
136-
resultRef.set(result);
137-
latch.countDown();
138-
}, exception -> {
139-
exceptionRef.set(exception);
140-
latch.countDown();
141-
}));
131+
mlCommonsClient.inferenceSentenceHighlighting(request, future);
142132

143133
try {
144-
latch.await();
145-
} catch (InterruptedException e) {
146-
Thread.currentThread().interrupt();
134+
return future.actionGet();
135+
} catch (Exception e) {
147136
throw new OpenSearchException(
148-
String.format(Locale.ROOT, "Interrupted while waiting for sentence highlighting inference from model [%s]", modelId),
137+
String.format(Locale.ROOT, "Error during sentence highlighting inference from model [%s]", modelId),
149138
e
150139
);
151140
}
152-
153-
if (exceptionRef.get() != null) {
154-
throw new OpenSearchException("Error during sentence highlighting inference", exceptionRef.get());
155-
}
156-
157-
return resultRef.get();
158141
}
159142

160143
/**

src/main/java/org/opensearch/neuralsearch/highlight/extractor/QueryTextExtractorRegistry.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public QueryTextExtractorRegistry() {
3232
/**
3333
* Initializes the registry with default extractors
3434
*/
35-
public void initialize() {
35+
private void initialize() {
3636
register(NeuralKNNQuery.class, new NeuralQueryTextExtractor());
3737
register(TermQuery.class, new TermQueryTextExtractor());
3838

src/test/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterManagerTests.java

+3-69
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import org.apache.lucene.search.BooleanClause;
99
import org.apache.lucene.search.BooleanQuery;
1010
import org.apache.lucene.search.TermQuery;
11-
import org.opensearch.OpenSearchException;
1211
import org.opensearch.core.action.ActionListener;
1312
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
1413
import org.opensearch.neuralsearch.processor.SentenceHighlightingRequest;
@@ -19,16 +18,14 @@
1918
import java.util.HashMap;
2019
import java.util.List;
2120
import java.util.Map;
22-
import java.util.concurrent.CountDownLatch;
23-
import java.util.concurrent.TimeUnit;
2421

2522
import static org.mockito.ArgumentMatchers.any;
2623
import static org.mockito.Mockito.doAnswer;
2724
import static org.mockito.Mockito.mock;
2825
import static org.mockito.Mockito.when;
2926

3027
/**
31-
* Tests for the NeuralHighlighterManager class
28+
* Tests for the NeuralHighlighterEngine class
3229
*/
3330
public class NeuralHighlighterManagerTests extends OpenSearchTestCase {
3431

@@ -37,14 +34,14 @@ public class NeuralHighlighterManagerTests extends OpenSearchTestCase {
3734
private static final String TEST_CONTENT = "This is a test content. For highlighting purposes. With multiple sentences.";
3835
private static final String TEST_QUERY = "test content";
3936

40-
private NeuralHighlighterManager manager;
37+
private NeuralHighlighterEngine manager;
4138
private MLCommonsClientAccessor mlCommonsClientAccessor;
4239

4340
@Override
4441
public void setUp() throws Exception {
4542
super.setUp();
4643
mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
47-
manager = new NeuralHighlighterManager(mlCommonsClientAccessor);
44+
manager = new NeuralHighlighterEngine(mlCommonsClientAccessor);
4845

4946
// Setup default mock behavior
5047
setupDefaultMockBehavior();
@@ -213,67 +210,4 @@ public void testApplyHighlightingWithInvalidPositions() {
213210
// Should only apply the valid highlight
214211
assertEquals("Should only apply valid highlights", "<em>This</em> is a test string", result);
215212
}
216-
217-
public void testFetchModelResultsWithTimeout() throws Exception {
218-
// Create a custom mock that delays the response
219-
MLCommonsClientAccessor delayedMlClient = mock(MLCommonsClientAccessor.class);
220-
NeuralHighlighterManager customManager = new NeuralHighlighterManager(delayedMlClient);
221-
222-
// Use a CountDownLatch to control the test timing
223-
CountDownLatch latch = new CountDownLatch(1);
224-
225-
// Mock response with delay
226-
doAnswer(invocation -> {
227-
ActionListener<List<Map<String, Object>>> listener = invocation.getArgument(1);
228-
229-
// Start a new thread to delay the response
230-
new Thread(() -> {
231-
try {
232-
// Simulate a delay longer than any reasonable timeout
233-
Thread.sleep(500);
234-
235-
// Create mock response
236-
List<Map<String, Object>> mockResponse = new ArrayList<>();
237-
Map<String, Object> resultMap = new HashMap<>();
238-
List<Map<String, Object>> highlights = new ArrayList<>();
239-
240-
Map<String, Object> highlight = new HashMap<>();
241-
highlight.put("start", 0);
242-
highlight.put("end", 4);
243-
highlights.add(highlight);
244-
245-
resultMap.put("highlights", highlights);
246-
mockResponse.add(resultMap);
247-
248-
listener.onResponse(mockResponse);
249-
} catch (InterruptedException e) {
250-
listener.onFailure(e);
251-
} finally {
252-
latch.countDown();
253-
}
254-
}).start();
255-
256-
return null;
257-
}).when(delayedMlClient).inferenceSentenceHighlighting(any(SentenceHighlightingRequest.class), any());
258-
259-
// Call the method in a separate thread so we can interrupt it
260-
Thread testThread = new Thread(() -> {
261-
try {
262-
customManager.fetchModelResults(MODEL_ID, TEST_QUERY, TEST_CONTENT);
263-
fail("Should have been interrupted");
264-
} catch (OpenSearchException e) {
265-
// Expected exception
266-
assertTrue(e.getMessage().contains("Interrupted while waiting"));
267-
}
268-
});
269-
270-
testThread.start();
271-
272-
// Wait a bit and then interrupt the thread
273-
Thread.sleep(100);
274-
testThread.interrupt();
275-
276-
// Wait for the mock to finish
277-
latch.await(1, TimeUnit.SECONDS);
278-
}
279213
}

src/test/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterTests.java

+7-71
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import java.util.HashMap;
3030
import java.util.List;
3131
import java.util.Map;
32-
import java.util.concurrent.CountDownLatch;
33-
import java.util.concurrent.TimeUnit;
3432

3533
import static org.mockito.ArgumentMatchers.any;
3634
import static org.mockito.Mockito.doAnswer;
@@ -50,7 +48,7 @@ public class NeuralHighlighterTests extends OpenSearchTestCase {
5048
private NeuralHighlighter highlighter;
5149
private MLCommonsClientAccessor mlCommonsClientAccessor;
5250
private MappedFieldType fieldType;
53-
private NeuralHighlighterManager manager;
51+
private NeuralHighlighterEngine manager;
5452

5553
@Override
5654
public void setUp() throws Exception {
@@ -59,7 +57,7 @@ public void setUp() throws Exception {
5957
highlighter = new NeuralHighlighter();
6058
highlighter.initialize(mlCommonsClientAccessor);
6159
fieldType = new TextFieldMapper.TextFieldType(TEST_FIELD);
62-
manager = new NeuralHighlighterManager(mlCommonsClientAccessor);
60+
manager = new NeuralHighlighterEngine(mlCommonsClientAccessor);
6361

6462
// Setup default mock behavior
6563
setupDefaultMockBehavior();
@@ -199,7 +197,10 @@ public void testMultipleInitialization() {
199197
IllegalStateException exception = expectThrows(IllegalStateException.class, () -> testHighlighter.initialize(mlClient2));
200198

201199
assertNotNull(exception);
202-
assertEquals("NeuralHighlighter has already been initialized. Multiple initializations are not permitted.", exception.getMessage());
200+
assertEquals(
201+
"NeuralHighlighterEngine has already been initialized. Multiple initializations are not permitted.",
202+
exception.getMessage()
203+
);
203204
}
204205

205206
public void testUninitializedHighlighter() {
@@ -211,7 +212,7 @@ public void testUninitializedHighlighter() {
211212
assertTrue(exception.getMessage().contains("has not been initialized"));
212213
}
213214

214-
// Tests for the NeuralHighlighterManager class
215+
// Tests for the NeuralHighlighterEngine class
215216

216217
public void testGetModelId() {
217218
Map<String, Object> options = new HashMap<>();
@@ -353,71 +354,6 @@ public void testApplyHighlightingWithInvalidPositions() {
353354
assertEquals("Should only apply valid highlights", "<em>This</em> is a test string", result);
354355
}
355356

356-
public void testFetchHighlightingResultsWithTimeout() throws Exception {
357-
// Create a custom mock that delays the response
358-
MLCommonsClientAccessor delayedMlClient = mock(MLCommonsClientAccessor.class);
359-
NeuralHighlighterManager customManager = new NeuralHighlighterManager(delayedMlClient);
360-
361-
// Use a CountDownLatch to control the test timing
362-
CountDownLatch latch = new CountDownLatch(1);
363-
364-
// Mock response with delay
365-
doAnswer(invocation -> {
366-
ActionListener<List<Map<String, Object>>> listener = invocation.getArgument(1);
367-
368-
// Start a new thread to delay the response
369-
new Thread(() -> {
370-
try {
371-
// Simulate a delay longer than any reasonable timeout
372-
Thread.sleep(500);
373-
374-
// Create mock response
375-
List<Map<String, Object>> mockResponse = new ArrayList<>();
376-
Map<String, Object> resultMap = new HashMap<>();
377-
List<Map<String, Object>> highlights = new ArrayList<>();
378-
379-
Map<String, Object> highlight = new HashMap<>();
380-
highlight.put("start", 0);
381-
highlight.put("end", 4);
382-
highlights.add(highlight);
383-
384-
resultMap.put("highlights", highlights);
385-
mockResponse.add(resultMap);
386-
387-
listener.onResponse(mockResponse);
388-
} catch (InterruptedException e) {
389-
listener.onFailure(e);
390-
} finally {
391-
latch.countDown();
392-
}
393-
}).start();
394-
395-
return null;
396-
}).when(delayedMlClient).inferenceSentenceHighlighting(any(SentenceHighlightingRequest.class), any());
397-
398-
// Call the method in a separate thread so we can interrupt it
399-
Thread testThread = new Thread(() -> {
400-
try {
401-
customManager.fetchModelResults(MODEL_ID, TEST_QUERY, TEST_CONTENT);
402-
fail("Should have been interrupted");
403-
} catch (OpenSearchException e) {
404-
// Expected exception
405-
assertTrue(e.getMessage().contains("Interrupted while waiting"));
406-
}
407-
});
408-
409-
testThread.start();
410-
411-
// Wait a bit and then interrupt the thread
412-
Thread.sleep(100);
413-
testThread.interrupt();
414-
415-
// Wait for the mock to finish
416-
latch.await(1, TimeUnit.SECONDS);
417-
}
418-
419-
// Integration and error handling tests
420-
421357
public void testIntegrationWithTermQuery() {
422358
// Integration test logic for term queries
423359
TermQuery query = new TermQuery(new Term(TEST_FIELD, "test"));

src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public void testNeuralQueryWithNeuralHighlighting() {
251251
List<String> highlightedFields = (List<String>) highlight.get(TEST_TEXT_FIELD_NAME_1);
252252
assertNotNull("Highlight should contain the requested field", highlightedFields);
253253

254-
String highlightedText = highlightedFields.get(0);
254+
String highlightedText = highlightedFields.getFirst();
255255
assertTrue(
256256
"Highlighted text should contain content from the original text",
257257
highlightedText.contains("neural highlighting")

0 commit comments

Comments
 (0)