36
36
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {
37
37
38
38
@ Mock
39
- private ActionListener <List <List <Float >>> resultListener ;
39
+ private ActionListener <List <List <Number >>> resultListener ;
40
40
41
41
@ Mock
42
- private ActionListener <List <Float >> singleSentenceResultListener ;
42
+ private ActionListener <List <Number >> singleSentenceResultListener ;
43
+
44
+ @ Mock
45
+ private ActionListener <List <Float >> similarityResultListener ;
43
46
44
47
@ Mock
45
48
private MachineLearningNodeClient client ;
@@ -53,7 +56,7 @@ public void setup() {
53
56
}
54
57
55
58
public void testInferenceSentence_whenValidInput_thenSuccess () {
56
- final List <Float > vector = new ArrayList <>(List .of (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
59
+ final List <Number > vector = new ArrayList <>(List .of (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
57
60
Mockito .doAnswer (invocation -> {
58
61
final ActionListener <MLOutput > actionListener = invocation .getArgument (2 );
59
62
actionListener .onResponse (createModelTensorOutput (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
@@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
69
72
}
70
73
71
74
public void testInferenceSentences_whenValidInputThenSuccess () {
72
- final List <List <Float >> vectorList = new ArrayList <>();
75
+ final List <List <Number >> vectorList = new ArrayList <>();
73
76
vectorList .add (Arrays .asList (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
74
77
Mockito .doAnswer (invocation -> {
75
78
final ActionListener <MLOutput > actionListener = invocation .getArgument (2 );
@@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
85
88
}
86
89
87
90
public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList () {
88
- final List <List <Float >> vectorList = new ArrayList <>();
91
+ final List <List <Number >> vectorList = new ArrayList <>();
89
92
vectorList .add (Collections .emptyList ());
90
93
Mockito .doAnswer (invocation -> {
91
94
final ActionListener <MLOutput > actionListener = invocation .getArgument (2 );
@@ -127,17 +130,17 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
127
130
return null ;
128
131
}).when (client ).predict (Mockito .eq (TestCommonConstants .MODEL_ID ), Mockito .isA (MLInput .class ), Mockito .isA (ActionListener .class ));
129
132
130
- accessor .inferenceSimilarity (TestCommonConstants .SIMILARITY_INFERENCE_REQUEST , singleSentenceResultListener );
133
+ accessor .inferenceSimilarity (TestCommonConstants .SIMILARITY_INFERENCE_REQUEST , similarityResultListener );
131
134
132
135
// Verify client.predict is called 4 times (1 initial + 3 retries)
133
136
Mockito .verify (client , times (4 ))
134
137
.predict (Mockito .eq (TestCommonConstants .MODEL_ID ), Mockito .isA (MLInput .class ), Mockito .isA (ActionListener .class ));
135
138
136
139
// Verify failure is propagated to the listener after all retries
137
- Mockito .verify (singleSentenceResultListener ).onFailure (nodeNodeConnectedException );
140
+ Mockito .verify (similarityResultListener ).onFailure (nodeNodeConnectedException );
138
141
139
142
// Ensure no additional interactions with the listener
140
- Mockito .verifyNoMoreInteractions (singleSentenceResultListener );
143
+ Mockito .verifyNoMoreInteractions (similarityResultListener );
141
144
}
142
145
143
146
public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure () {
@@ -288,7 +291,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
288
291
}
289
292
290
293
public void testInferenceMultimodal_whenValidInput_thenSuccess () {
291
- final List <Float > vector = new ArrayList <>(List .of (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
294
+ final List <Number > vector = new ArrayList <>(List .of (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
292
295
Mockito .doAnswer (invocation -> {
293
296
final ActionListener <MLOutput > actionListener = invocation .getArgument (2 );
294
297
actionListener .onResponse (createModelTensorOutput (TestCommonConstants .PREDICT_VECTOR_ARRAY ));
@@ -353,12 +356,12 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
353
356
return null ;
354
357
}).when (client ).predict (Mockito .eq (TestCommonConstants .MODEL_ID ), Mockito .isA (MLInput .class ), Mockito .isA (ActionListener .class ));
355
358
356
- accessor .inferenceSimilarity (TestCommonConstants .SIMILARITY_INFERENCE_REQUEST , singleSentenceResultListener );
359
+ accessor .inferenceSimilarity (TestCommonConstants .SIMILARITY_INFERENCE_REQUEST , similarityResultListener );
357
360
358
361
Mockito .verify (client )
359
362
.predict (Mockito .eq (TestCommonConstants .MODEL_ID ), Mockito .isA (MLInput .class ), Mockito .isA (ActionListener .class ));
360
- Mockito .verify (singleSentenceResultListener ).onResponse (vector );
361
- Mockito .verifyNoMoreInteractions (singleSentenceResultListener );
363
+ Mockito .verify (similarityResultListener ).onResponse (vector );
364
+ Mockito .verifyNoMoreInteractions (similarityResultListener );
362
365
}
363
366
364
367
public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail () {
@@ -369,12 +372,12 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
369
372
return null ;
370
373
}).when (client ).predict (Mockito .eq (TestCommonConstants .MODEL_ID ), Mockito .isA (MLInput .class ), Mockito .isA (ActionListener .class ));
371
374
372
- accessor .inferenceSimilarity (TestCommonConstants .SIMILARITY_INFERENCE_REQUEST , singleSentenceResultListener );
375
+ accessor .inferenceSimilarity (TestCommonConstants .SIMILARITY_INFERENCE_REQUEST , similarityResultListener );
373
376
374
377
Mockito .verify (client )
375
378
.predict (Mockito .eq (TestCommonConstants .MODEL_ID ), Mockito .isA (MLInput .class ), Mockito .isA (ActionListener .class ));
376
- Mockito .verify (singleSentenceResultListener ).onFailure (exception );
377
- Mockito .verifyNoMoreInteractions (singleSentenceResultListener );
379
+ Mockito .verify (similarityResultListener ).onFailure (exception );
380
+ Mockito .verifyNoMoreInteractions (similarityResultListener );
378
381
}
379
382
380
383
private ModelTensorOutput createModelTensorOutput (final Float [] output ) {
0 commit comments