@@ -52,7 +52,7 @@ public class MLCommonsClientAccessor {
52
52
public void inferenceSentence (
53
53
@ NonNull final String modelId ,
54
54
@ NonNull final String inputText ,
55
- @ NonNull final ActionListener <List <Float >> listener
55
+ @ NonNull final ActionListener <List <Number >> listener
56
56
) {
57
57
inferenceSentences (TARGET_RESPONSE_FILTERS , modelId , List .of (inputText ), ActionListener .wrap (response -> {
58
58
if (response .size () != 1 ) {
@@ -82,7 +82,7 @@ public void inferenceSentence(
82
82
public void inferenceSentences (
83
83
@ NonNull final String modelId ,
84
84
@ NonNull final List <String > inputText ,
85
- @ NonNull final ActionListener <List <List <Float >>> listener
85
+ @ NonNull final ActionListener <List <List <Number >>> listener
86
86
) {
87
87
inferenceSentences (TARGET_RESPONSE_FILTERS , modelId , inputText , listener );
88
88
}
@@ -103,7 +103,7 @@ public void inferenceSentences(
103
103
@ NonNull final List <String > targetResponseFilters ,
104
104
@ NonNull final String modelId ,
105
105
@ NonNull final List <String > inputText ,
106
- @ NonNull final ActionListener <List <List <Float >>> listener
106
+ @ NonNull final ActionListener <List <List <Number >>> listener
107
107
) {
108
108
retryableInferenceSentencesWithVectorResult (targetResponseFilters , modelId , inputText , 0 , listener );
109
109
}
@@ -128,7 +128,7 @@ public void inferenceSentencesWithMapResult(
128
128
public void inferenceSentences (
129
129
@ NonNull final String modelId ,
130
130
@ NonNull final Map <String , String > inputObjects ,
131
- @ NonNull final ActionListener <List <Float >> listener
131
+ @ NonNull final ActionListener <List <Number >> listener
132
132
) {
133
133
retryableInferenceSentencesWithSingleVectorResult (TARGET_RESPONSE_FILTERS , modelId , inputObjects , 0 , listener );
134
134
}
@@ -177,11 +177,11 @@ private void retryableInferenceSentencesWithVectorResult(
177
177
final String modelId ,
178
178
final List <String > inputText ,
179
179
final int retryTime ,
180
- final ActionListener <List <List <Float >>> listener
180
+ final ActionListener <List <List <Number >>> listener
181
181
) {
182
182
MLInput mlInput = createMLTextInput (targetResponseFilters , inputText );
183
183
mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
184
- final List <List <Float >> vector = buildVectorFromResponse (mlOutput );
184
+ final List <List <Number >> vector = buildVectorFromResponse (mlOutput );
185
185
listener .onResponse (vector );
186
186
}, e -> {
187
187
if (RetryUtil .shouldRetry (e , retryTime )) {
@@ -202,7 +202,8 @@ private void retryableInferenceSimilarityWithVectorResult(
202
202
) {
203
203
MLInput mlInput = createMLTextPairsInput (queryText , inputText );
204
204
mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
205
- final List <Float > scores = buildVectorFromResponse (mlOutput ).stream ().map (v -> v .get (0 )).collect (Collectors .toList ());
205
+ final List <List <Float >> tensors = buildVectorFromResponse (mlOutput );
206
+ final List <Float > scores = tensors .stream ().map (v -> v .get (0 )).collect (Collectors .toList ());
206
207
listener .onResponse (scores );
207
208
}, e -> {
208
209
if (RetryUtil .shouldRetry (e , retryTime )) {
@@ -224,14 +225,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
224
225
return new MLInput (FunctionName .TEXT_SIMILARITY , null , inputDataset );
225
226
}
226
227
227
- private List <List <Float >> buildVectorFromResponse (MLOutput mlOutput ) {
228
- final List <List <Float >> vector = new ArrayList <>();
228
+ private < T extends Number > List <List <T >> buildVectorFromResponse (MLOutput mlOutput ) {
229
+ final List <List <T >> vector = new ArrayList <>();
229
230
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput ) mlOutput ;
230
231
final List <ModelTensors > tensorOutputList = modelTensorOutput .getMlModelOutputs ();
231
232
for (final ModelTensors tensors : tensorOutputList ) {
232
233
final List <ModelTensor > tensorsList = tensors .getMlModelTensors ();
233
234
for (final ModelTensor tensor : tensorsList ) {
234
- vector .add (Arrays .stream (tensor .getData ()).map (value -> (Float ) value ).collect (Collectors .toList ()));
235
+ vector .add (Arrays .stream (tensor .getData ()).map (value -> (T ) value ).collect (Collectors .toList ()));
235
236
}
236
237
}
237
238
return vector ;
@@ -255,8 +256,8 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
255
256
return resultMaps ;
256
257
}
257
258
258
- private List <Float > buildSingleVectorFromResponse (final MLOutput mlOutput ) {
259
- final List <List <Float >> vector = buildVectorFromResponse (mlOutput );
259
+ private < T extends Number > List <T > buildSingleVectorFromResponse (final MLOutput mlOutput ) {
260
+ final List <List <T >> vector = buildVectorFromResponse (mlOutput );
260
261
return vector .isEmpty () ? new ArrayList <>() : vector .get (0 );
261
262
}
262
263
@@ -265,11 +266,11 @@ private void retryableInferenceSentencesWithSingleVectorResult(
265
266
final String modelId ,
266
267
final Map <String , String > inputObjects ,
267
268
final int retryTime ,
268
- final ActionListener <List <Float >> listener
269
+ final ActionListener <List <Number >> listener
269
270
) {
270
271
MLInput mlInput = createMLMultimodalInput (targetResponseFilters , inputObjects );
271
272
mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
272
- final List <Float > vector = buildSingleVectorFromResponse (mlOutput );
273
+ final List <Number > vector = buildSingleVectorFromResponse (mlOutput );
273
274
log .debug ("Inference Response for input sentence is : {} " , vector );
274
275
listener .onResponse (vector );
275
276
}, e -> {
0 commit comments