import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.processor.InferenceRequest;
import org.opensearch.neuralsearch.processor.MapInferenceRequest;
import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.util.RetryUtil;

import lombok.NonNull;
38
42
@ RequiredArgsConstructor
39
43
@ Log4j2
40
44
public class MLCommonsClientAccessor {
41
- private static final List <String > TARGET_RESPONSE_FILTERS = List .of ("sentence_embedding" );
42
45
private final MachineLearningNodeClient mlClient ;
43
46
44
47
/**
45
48
* Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating
46
49
* point vector as a response.
47
50
*
48
51
* @param modelId {@link String}
49
- * @param inputText {@link List} of {@link String} on which inference needs to happen
52
+ * @param inputText {@link String}
50
53
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
51
54
*/
52
55
public void inferenceSentence (
53
56
@ NonNull final String modelId ,
54
57
@ NonNull final String inputText ,
55
58
@ NonNull final ActionListener <List <Float >> listener
56
59
) {
57
- inferenceSentences (TARGET_RESPONSE_FILTERS , modelId , List .of (inputText ), ActionListener .wrap (response -> {
58
- if (response .size () != 1 ) {
59
- listener .onFailure (
60
- new IllegalStateException (
61
- "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response .size () + "]"
62
- )
63
- );
64
- return ;
65
- }
66
60
67
- listener .onResponse (response .get (0 ));
68
- }, listener ::onFailure ));
69
- }
61
+ inferenceSentences (
62
+ TextInferenceRequest .builder ().modelId (modelId ).inputTexts (List .of (inputText )).build (),
63
+ ActionListener .wrap (response -> {
64
+ if (response .size () != 1 ) {
65
+ listener .onFailure (
66
+ new IllegalStateException (
67
+ "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response .size () + "]"
68
+ )
69
+ );
70
+ return ;
71
+ }
70
72
71
- /**
72
- * Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the
73
- * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
74
- * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of
75
- * inputText. We are not making this function generic enough to take any function or TaskType as currently we
76
- * need to run only TextEmbedding tasks only.
77
- *
78
- * @param modelId {@link String}
79
- * @param inputText {@link List} of {@link String} on which inference needs to happen
80
- * @param listener {@link ActionListener} which will be called when prediction is completed or errored out
81
- */
82
- public void inferenceSentences (
83
- @ NonNull final String modelId ,
84
- @ NonNull final List <String > inputText ,
85
- @ NonNull final ActionListener <List <List <Float >>> listener
86
- ) {
87
- inferenceSentences (TARGET_RESPONSE_FILTERS , modelId , inputText , listener );
73
+ listener .onResponse (response .getFirst ());
74
+ }, listener ::onFailure )
75
+ );
88
76
}
89
77
90
78
/**
@@ -94,121 +82,102 @@ public void inferenceSentences(
94
82
* inputText. We are not making this function generic enough to take any function or TaskType as currently we
95
83
* need to run only TextEmbedding tasks only.
96
84
*
97
- * @param targetResponseFilters {@link List} of {@link String} which filters out the responses
98
- * @param modelId {@link String}
99
- * @param inputText {@link List} of {@link String} on which inference needs to happen
85
+ * @param inferenceRequest {@link InferenceRequest}
100
86
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
101
87
*/
102
88
public void inferenceSentences (
103
- @ NonNull final List <String > targetResponseFilters ,
104
- @ NonNull final String modelId ,
105
- @ NonNull final List <String > inputText ,
89
+ @ NonNull final TextInferenceRequest inferenceRequest ,
106
90
@ NonNull final ActionListener <List <List <Float >>> listener
107
91
) {
108
- retryableInferenceSentencesWithVectorResult (targetResponseFilters , modelId , inputText , 0 , listener );
92
+ retryableInferenceSentencesWithVectorResult (inferenceRequest , 0 , listener );
109
93
}
110
94
111
95
public void inferenceSentencesWithMapResult (
112
- @ NonNull final String modelId ,
113
- @ NonNull final List <String > inputText ,
96
+ @ NonNull final TextInferenceRequest inferenceRequest ,
114
97
@ NonNull final ActionListener <List <Map <String , ?>>> listener
115
98
) {
116
- retryableInferenceSentencesWithMapResult (modelId , inputText , 0 , listener );
99
+ retryableInferenceSentencesWithMapResult (inferenceRequest , 0 , listener );
117
100
}
118
101
119
102
/**
120
103
* Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the
121
104
* custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
122
105
* using the actionListener which will have a list of floats in the order of inputText.
123
106
*
124
- * @param modelId {@link String}
125
- * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen
107
+ * @param inferenceRequest {@link InferenceRequest}
126
108
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
127
109
*/
128
- public void inferenceSentences (
129
- @ NonNull final String modelId ,
130
- @ NonNull final Map <String , String > inputObjects ,
131
- @ NonNull final ActionListener <List <Float >> listener
132
- ) {
133
- retryableInferenceSentencesWithSingleVectorResult (TARGET_RESPONSE_FILTERS , modelId , inputObjects , 0 , listener );
110
+ public void inferenceSentencesMap (@ NonNull MapInferenceRequest inferenceRequest , @ NonNull final ActionListener <List <Float >> listener ) {
111
+ retryableInferenceSentencesWithSingleVectorResult (inferenceRequest , 0 , listener );
134
112
}
135
113
136
114
/**
137
115
* Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the
138
116
* {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing
139
117
* the similarity scores of the texts w.r.t. the query text, in the order of the input texts.
140
118
*
141
- * @param modelId {@link String} ML-Commons Model Id
142
- * @param queryText {@link String} The query to compare all the inputText to
143
- * @param inputText {@link List} of {@link String} The texts to compare to the query
119
+ * @param inferenceRequest {@link InferenceRequest}
144
120
* @param listener {@link ActionListener} receives the result of the inference
145
121
*/
146
122
public void inferenceSimilarity (
147
- @ NonNull final String modelId ,
148
- @ NonNull final String queryText ,
149
- @ NonNull final List <String > inputText ,
123
+ @ NonNull SimilarityInferenceRequest inferenceRequest ,
150
124
@ NonNull final ActionListener <List <Float >> listener
151
125
) {
152
- retryableInferenceSimilarityWithVectorResult (modelId , queryText , inputText , 0 , listener );
126
+ retryableInferenceSimilarityWithVectorResult (inferenceRequest , 0 , listener );
153
127
}
154
128
155
129
private void retryableInferenceSentencesWithMapResult (
156
- final String modelId ,
157
- final List <String > inputText ,
130
+ final TextInferenceRequest inferenceRequest ,
158
131
final int retryTime ,
159
132
final ActionListener <List <Map <String , ?>>> listener
160
133
) {
161
- MLInput mlInput = createMLTextInput (null , inputText );
162
- mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
134
+ MLInput mlInput = createMLTextInput (null , inferenceRequest . getInputTexts () );
135
+ mlClient .predict (inferenceRequest . getModelId () , mlInput , ActionListener .wrap (mlOutput -> {
163
136
final List <Map <String , ?>> result = buildMapResultFromResponse (mlOutput );
164
137
listener .onResponse (result );
165
138
},
166
139
e -> RetryUtil .handleRetryOrFailure (
167
140
e ,
168
141
retryTime ,
169
- () -> retryableInferenceSentencesWithMapResult (modelId , inputText , retryTime + 1 , listener ),
142
+ () -> retryableInferenceSentencesWithMapResult (inferenceRequest , retryTime + 1 , listener ),
170
143
listener
171
144
)
172
145
));
173
146
}
174
147
175
148
private void retryableInferenceSentencesWithVectorResult (
176
- final List <String > targetResponseFilters ,
177
- final String modelId ,
178
- final List <String > inputText ,
149
+ final TextInferenceRequest inferenceRequest ,
179
150
final int retryTime ,
180
151
final ActionListener <List <List <Float >>> listener
181
152
) {
182
- MLInput mlInput = createMLTextInput (targetResponseFilters , inputText );
183
- mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
153
+ MLInput mlInput = createMLTextInput (inferenceRequest . getTargetResponseFilters (), inferenceRequest . getInputTexts () );
154
+ mlClient .predict (inferenceRequest . getModelId () , mlInput , ActionListener .wrap (mlOutput -> {
184
155
final List <List <Float >> vector = buildVectorFromResponse (mlOutput );
185
156
listener .onResponse (vector );
186
157
},
187
158
e -> RetryUtil .handleRetryOrFailure (
188
159
e ,
189
160
retryTime ,
190
- () -> retryableInferenceSentencesWithVectorResult (targetResponseFilters , modelId , inputText , retryTime + 1 , listener ),
161
+ () -> retryableInferenceSentencesWithVectorResult (inferenceRequest , retryTime + 1 , listener ),
191
162
listener
192
163
)
193
164
));
194
165
}
195
166
196
167
private void retryableInferenceSimilarityWithVectorResult (
197
- final String modelId ,
198
- final String queryText ,
199
- final List <String > inputText ,
168
+ final SimilarityInferenceRequest inferenceRequest ,
200
169
final int retryTime ,
201
170
final ActionListener <List <Float >> listener
202
171
) {
203
- MLInput mlInput = createMLTextPairsInput (queryText , inputText );
204
- mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
172
+ MLInput mlInput = createMLTextPairsInput (inferenceRequest . getQueryText (), inferenceRequest . getInputTexts () );
173
+ mlClient .predict (inferenceRequest . getModelId () , mlInput , ActionListener .wrap (mlOutput -> {
205
174
final List <Float > scores = buildVectorFromResponse (mlOutput ).stream ().map (v -> v .get (0 )).collect (Collectors .toList ());
206
175
listener .onResponse (scores );
207
176
},
208
177
e -> RetryUtil .handleRetryOrFailure (
209
178
e ,
210
179
retryTime ,
211
- () -> retryableInferenceSimilarityWithVectorResult (modelId , queryText , inputText , retryTime + 1 , listener ),
180
+ () -> retryableInferenceSimilarityWithVectorResult (inferenceRequest , retryTime + 1 , listener ),
212
181
listener
213
182
)
214
183
));
@@ -262,28 +231,20 @@ private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
262
231
}
263
232
264
233
private void retryableInferenceSentencesWithSingleVectorResult (
265
- final List <String > targetResponseFilters ,
266
- final String modelId ,
267
- final Map <String , String > inputObjects ,
234
+ final MapInferenceRequest inferenceRequest ,
268
235
final int retryTime ,
269
236
final ActionListener <List <Float >> listener
270
237
) {
271
- MLInput mlInput = createMLMultimodalInput (targetResponseFilters , inputObjects );
272
- mlClient .predict (modelId , mlInput , ActionListener .wrap (mlOutput -> {
238
+ MLInput mlInput = createMLMultimodalInput (inferenceRequest . getTargetResponseFilters (), inferenceRequest . getInputObjects () );
239
+ mlClient .predict (inferenceRequest . getModelId () , mlInput , ActionListener .wrap (mlOutput -> {
273
240
final List <Float > vector = buildSingleVectorFromResponse (mlOutput );
274
241
log .debug ("Inference Response for input sentence is : {} " , vector );
275
242
listener .onResponse (vector );
276
243
},
277
244
e -> RetryUtil .handleRetryOrFailure (
278
245
e ,
279
246
retryTime ,
280
- () -> retryableInferenceSentencesWithSingleVectorResult (
281
- targetResponseFilters ,
282
- modelId ,
283
- inputObjects ,
284
- retryTime + 1 ,
285
- listener
286
- ),
247
+ () -> retryableInferenceSentencesWithSingleVectorResult (inferenceRequest , retryTime + 1 , listener ),
287
248
listener
288
249
)
289
250
));
0 commit comments