@@ -55,7 +55,7 @@ public class MLCommonsClientAccessor {
55
55
public void inferenceSentence (
56
56
@ NonNull final String modelId ,
57
57
@ NonNull final String inputText ,
58
- @ NonNull final ActionListener <List <Float >> listener
58
+ @ NonNull final ActionListener <List <Number >> listener
59
59
) {
60
60
61
61
inferenceSentences (
@@ -87,7 +87,7 @@ public void inferenceSentence(
87
87
*/
88
88
public void inferenceSentences (
89
89
@ NonNull final TextInferenceRequest inferenceRequest ,
90
- @ NonNull final ActionListener <List <List <Float >>> listener
90
+ @ NonNull final ActionListener <List <List <Number >>> listener
91
91
) {
92
92
retryableInferenceSentencesWithVectorResult (inferenceRequest , 0 , listener );
93
93
}
@@ -107,7 +107,7 @@ public void inferenceSentencesWithMapResult(
107
107
* @param inferenceRequest {@link InferenceRequest}
108
108
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
109
109
*/
110
- public void inferenceSentencesMap (@ NonNull MapInferenceRequest inferenceRequest , @ NonNull final ActionListener <List <Float >> listener ) {
110
+ public void inferenceSentencesMap (@ NonNull MapInferenceRequest inferenceRequest , @ NonNull final ActionListener <List <Number >> listener ) {
111
111
retryableInferenceSentencesWithSingleVectorResult (inferenceRequest , 0 , listener );
112
112
}
113
113
@@ -148,11 +148,11 @@ private void retryableInferenceSentencesWithMapResult(
148
148
private void retryableInferenceSentencesWithVectorResult (
149
149
final TextInferenceRequest inferenceRequest ,
150
150
final int retryTime ,
151
- final ActionListener <List <List <Float >>> listener
151
+ final ActionListener <List <List <Number >>> listener
152
152
) {
153
153
MLInput mlInput = createMLTextInput (inferenceRequest .getTargetResponseFilters (), inferenceRequest .getInputTexts ());
154
154
mlClient .predict (inferenceRequest .getModelId (), mlInput , ActionListener .wrap (mlOutput -> {
155
- final List <List <Float >> vector = buildVectorFromResponse (mlOutput );
155
+ final List <List <Number >> vector = buildVectorFromResponse (mlOutput );
156
156
listener .onResponse (vector );
157
157
},
158
158
e -> RetryUtil .handleRetryOrFailure (
@@ -167,11 +167,11 @@ private void retryableInferenceSentencesWithVectorResult(
167
167
private void retryableInferenceSimilarityWithVectorResult (
168
168
final SimilarityInferenceRequest inferenceRequest ,
169
169
final int retryTime ,
170
- final ActionListener <List <Float >> listener
170
+ final ActionListener <List <Number >> listener
171
171
) {
172
172
MLInput mlInput = createMLTextPairsInput (inferenceRequest .getQueryText (), inferenceRequest .getInputTexts ());
173
173
mlClient .predict (inferenceRequest .getModelId (), mlInput , ActionListener .wrap (mlOutput -> {
174
- final List <Float > scores = buildVectorFromResponse (mlOutput ).stream ().map (v -> v .get (0 )).collect (Collectors .toList ());
174
+ final List <Number > scores = buildVectorFromResponse (mlOutput ).stream ().map (v -> v .get (0 )).collect (Collectors .toList ());
175
175
listener .onResponse (scores );
176
176
},
177
177
e -> RetryUtil .handleRetryOrFailure (
@@ -194,14 +194,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
194
194
return new MLInput (FunctionName .TEXT_SIMILARITY , null , inputDataset );
195
195
}
196
196
197
- private List <List <Float >> buildVectorFromResponse (MLOutput mlOutput ) {
198
- final List <List <Float >> vector = new ArrayList <>();
197
+ private < T extends Number > List <List <T >> buildVectorFromResponse (MLOutput mlOutput ) {
198
+ final List <List <T >> vector = new ArrayList <>();
199
199
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) mlOutput ;
200
200
final List <ModelTensors > tensorOutputList = modelTensorOutput .getMlModelOutputs ();
201
201
for (final ModelTensors tensors : tensorOutputList ) {
202
202
final List <ModelTensor > tensorsList = tensors .getMlModelTensors ();
203
203
for (final ModelTensor tensor : tensorsList ) {
204
- vector .add (Arrays .stream (tensor .getData ()).map (value -> (Float ) value ).collect (Collectors .toList ()));
204
+ vector .add (Arrays .stream (tensor .getData ()).map (value -> (T ) value ).collect (Collectors .toList ()));
205
205
}
206
206
}
207
207
return vector ;
@@ -225,19 +225,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
225
225
return resultMaps ;
226
226
}
227
227
228
- private List <Float > buildSingleVectorFromResponse (final MLOutput mlOutput ) {
229
- final List <List <Float >> vector = buildVectorFromResponse (mlOutput );
228
+ private < T extends Number > List <T > buildSingleVectorFromResponse (final MLOutput mlOutput ) {
229
+ final List <List <T >> vector = buildVectorFromResponse (mlOutput );
230
230
return vector .isEmpty () ? new ArrayList <>() : vector .get (0 );
231
231
}
232
232
233
233
private void retryableInferenceSentencesWithSingleVectorResult (
234
234
final MapInferenceRequest inferenceRequest ,
235
235
final int retryTime ,
236
- final ActionListener <List <Float >> listener
236
+ final ActionListener <List <Number >> listener
237
237
) {
238
238
MLInput mlInput = createMLMultimodalInput (inferenceRequest .getTargetResponseFilters (), inferenceRequest .getInputObjects ());
239
239
mlClient .predict (inferenceRequest .getModelId (), mlInput , ActionListener .wrap (mlOutput -> {
240
- final List <Float > vector = buildSingleVectorFromResponse (mlOutput );
240
+ final List <Number > vector = buildSingleVectorFromResponse (mlOutput );
241
241
log .debug ("Inference Response for input sentence is : {} " , vector );
242
242
listener .onResponse (vector );
243
243
},
0 commit comments