29
29
import java .util .HashMap ;
30
30
import java .util .List ;
31
31
import java .util .Map ;
32
- import java .util .concurrent .CountDownLatch ;
33
- import java .util .concurrent .TimeUnit ;
34
32
35
33
import static org .mockito .ArgumentMatchers .any ;
36
34
import static org .mockito .Mockito .doAnswer ;
@@ -50,7 +48,7 @@ public class NeuralHighlighterTests extends OpenSearchTestCase {
50
48
private NeuralHighlighter highlighter ;
51
49
private MLCommonsClientAccessor mlCommonsClientAccessor ;
52
50
private MappedFieldType fieldType ;
53
- private NeuralHighlighterManager manager ;
51
+ private NeuralHighlighterEngine manager ;
54
52
55
53
@ Override
56
54
public void setUp () throws Exception {
@@ -59,7 +57,7 @@ public void setUp() throws Exception {
59
57
highlighter = new NeuralHighlighter ();
60
58
highlighter .initialize (mlCommonsClientAccessor );
61
59
fieldType = new TextFieldMapper .TextFieldType (TEST_FIELD );
62
- manager = new NeuralHighlighterManager (mlCommonsClientAccessor );
60
+ manager = new NeuralHighlighterEngine (mlCommonsClientAccessor );
63
61
64
62
// Setup default mock behavior
65
63
setupDefaultMockBehavior ();
@@ -199,7 +197,10 @@ public void testMultipleInitialization() {
199
197
IllegalStateException exception = expectThrows (IllegalStateException .class , () -> testHighlighter .initialize (mlClient2 ));
200
198
201
199
assertNotNull (exception );
202
- assertEquals ("NeuralHighlighter has already been initialized. Multiple initializations are not permitted." , exception .getMessage ());
200
+ assertEquals (
201
+ "NeuralHighlighterEngine has already been initialized. Multiple initializations are not permitted." ,
202
+ exception .getMessage ()
203
+ );
203
204
}
204
205
205
206
public void testUninitializedHighlighter () {
@@ -211,7 +212,7 @@ public void testUninitializedHighlighter() {
211
212
assertTrue (exception .getMessage ().contains ("has not been initialized" ));
212
213
}
213
214
214
- // Tests for the NeuralHighlighterManager class
215
+ // Tests for the NeuralHighlighterEngine class
215
216
216
217
public void testGetModelId () {
217
218
Map <String , Object > options = new HashMap <>();
@@ -353,71 +354,6 @@ public void testApplyHighlightingWithInvalidPositions() {
353
354
assertEquals ("Should only apply valid highlights" , "<em>This</em> is a test string" , result );
354
355
}
355
356
356
- public void testFetchHighlightingResultsWithTimeout () throws Exception {
357
- // Create a custom mock that delays the response
358
- MLCommonsClientAccessor delayedMlClient = mock (MLCommonsClientAccessor .class );
359
- NeuralHighlighterManager customManager = new NeuralHighlighterManager (delayedMlClient );
360
-
361
- // Use a CountDownLatch to control the test timing
362
- CountDownLatch latch = new CountDownLatch (1 );
363
-
364
- // Mock response with delay
365
- doAnswer (invocation -> {
366
- ActionListener <List <Map <String , Object >>> listener = invocation .getArgument (1 );
367
-
368
- // Start a new thread to delay the response
369
- new Thread (() -> {
370
- try {
371
- // Simulate a delay longer than any reasonable timeout
372
- Thread .sleep (500 );
373
-
374
- // Create mock response
375
- List <Map <String , Object >> mockResponse = new ArrayList <>();
376
- Map <String , Object > resultMap = new HashMap <>();
377
- List <Map <String , Object >> highlights = new ArrayList <>();
378
-
379
- Map <String , Object > highlight = new HashMap <>();
380
- highlight .put ("start" , 0 );
381
- highlight .put ("end" , 4 );
382
- highlights .add (highlight );
383
-
384
- resultMap .put ("highlights" , highlights );
385
- mockResponse .add (resultMap );
386
-
387
- listener .onResponse (mockResponse );
388
- } catch (InterruptedException e ) {
389
- listener .onFailure (e );
390
- } finally {
391
- latch .countDown ();
392
- }
393
- }).start ();
394
-
395
- return null ;
396
- }).when (delayedMlClient ).inferenceSentenceHighlighting (any (SentenceHighlightingRequest .class ), any ());
397
-
398
- // Call the method in a separate thread so we can interrupt it
399
- Thread testThread = new Thread (() -> {
400
- try {
401
- customManager .fetchModelResults (MODEL_ID , TEST_QUERY , TEST_CONTENT );
402
- fail ("Should have been interrupted" );
403
- } catch (OpenSearchException e ) {
404
- // Expected exception
405
- assertTrue (e .getMessage ().contains ("Interrupted while waiting" ));
406
- }
407
- });
408
-
409
- testThread .start ();
410
-
411
- // Wait a bit and then interrupt the thread
412
- Thread .sleep (100 );
413
- testThread .interrupt ();
414
-
415
- // Wait for the mock to finish
416
- latch .await (1 , TimeUnit .SECONDS );
417
- }
418
-
419
- // Integration and error handling tests
420
-
421
357
public void testIntegrationWithTermQuery () {
422
358
// Integration test logic for term queries
423
359
TermQuery query = new TermQuery (new Term (TEST_FIELD , "test" ));
0 commit comments