Skip to content

Commit ebfb058

Browse files
fen-qin (Fen Qin)
and
Fen Qin
authored
code refactoring on MLCommonsClientAccessor request (#1178)
* code refactoring on MLCommonsClientAccessor request Signed-off-by: Fen Qin <mfenqin@amazon.com> * fix unit tests Signed-off-by: Fen Qin <mfenqin@amazon.com> * Implements inheritance pattern to break down inference request Signed-off-by: Fen Qin <mfenqin@amazon.com> --------- Signed-off-by: Fen Qin <mfenqin@amazon.com> Co-authored-by: Fen Qin <mfenqin@amazon.com>
1 parent 7dc84b5 commit ebfb058

20 files changed

+364
-251
lines changed

src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

+46-85
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
import org.opensearch.ml.common.output.model.ModelTensor;
2727
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2828
import org.opensearch.ml.common.output.model.ModelTensors;
29+
import org.opensearch.neuralsearch.processor.InferenceRequest;
30+
import org.opensearch.neuralsearch.processor.MapInferenceRequest;
31+
import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
32+
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
2933
import org.opensearch.neuralsearch.util.RetryUtil;
3034

3135
import lombok.NonNull;
@@ -38,53 +42,37 @@
3842
@RequiredArgsConstructor
3943
@Log4j2
4044
public class MLCommonsClientAccessor {
41-
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
4245
private final MachineLearningNodeClient mlClient;
4346

4447
/**
4548
* Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating
4649
* point vector as a response.
4750
*
4851
* @param modelId {@link String}
49-
* @param inputText {@link List} of {@link String} on which inference needs to happen
52+
* @param inputText {@link String}
5053
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
5154
*/
5255
public void inferenceSentence(
5356
@NonNull final String modelId,
5457
@NonNull final String inputText,
5558
@NonNull final ActionListener<List<Float>> listener
5659
) {
57-
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
58-
if (response.size() != 1) {
59-
listener.onFailure(
60-
new IllegalStateException(
61-
"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
62-
)
63-
);
64-
return;
65-
}
6660

67-
listener.onResponse(response.get(0));
68-
}, listener::onFailure));
69-
}
61+
inferenceSentences(
62+
TextInferenceRequest.builder().modelId(modelId).inputTexts(List.of(inputText)).build(),
63+
ActionListener.wrap(response -> {
64+
if (response.size() != 1) {
65+
listener.onFailure(
66+
new IllegalStateException(
67+
"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
68+
)
69+
);
70+
return;
71+
}
7072

71-
/**
72-
* Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the
73-
* custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
74-
* using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of
75-
* inputText. We are not making this function generic enough to take any function or TaskType as currently we
76-
* need to run only TextEmbedding tasks only.
77-
*
78-
* @param modelId {@link String}
79-
* @param inputText {@link List} of {@link String} on which inference needs to happen
80-
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
81-
*/
82-
public void inferenceSentences(
83-
@NonNull final String modelId,
84-
@NonNull final List<String> inputText,
85-
@NonNull final ActionListener<List<List<Float>>> listener
86-
) {
87-
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
73+
listener.onResponse(response.getFirst());
74+
}, listener::onFailure)
75+
);
8876
}
8977

9078
/**
@@ -94,121 +82,102 @@ public void inferenceSentences(
9482
* inputText. We are not making this function generic enough to take any function or TaskType as currently we
9583
* need to run only TextEmbedding tasks only.
9684
*
97-
* @param targetResponseFilters {@link List} of {@link String} which filters out the responses
98-
* @param modelId {@link String}
99-
* @param inputText {@link List} of {@link String} on which inference needs to happen
85+
* @param inferenceRequest {@link InferenceRequest}
10086
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
10187
*/
10288
public void inferenceSentences(
103-
@NonNull final List<String> targetResponseFilters,
104-
@NonNull final String modelId,
105-
@NonNull final List<String> inputText,
89+
@NonNull final TextInferenceRequest inferenceRequest,
10690
@NonNull final ActionListener<List<List<Float>>> listener
10791
) {
108-
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
92+
retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
10993
}
11094

11195
public void inferenceSentencesWithMapResult(
112-
@NonNull final String modelId,
113-
@NonNull final List<String> inputText,
96+
@NonNull final TextInferenceRequest inferenceRequest,
11497
@NonNull final ActionListener<List<Map<String, ?>>> listener
11598
) {
116-
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
99+
retryableInferenceSentencesWithMapResult(inferenceRequest, 0, listener);
117100
}
118101

119102
/**
120103
* Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the
121104
* custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
122105
* using the actionListener which will have a list of floats in the order of inputText.
123106
*
124-
* @param modelId {@link String}
125-
* @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen
107+
* @param inferenceRequest {@link InferenceRequest}
126108
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
127109
*/
128-
public void inferenceSentences(
129-
@NonNull final String modelId,
130-
@NonNull final Map<String, String> inputObjects,
131-
@NonNull final ActionListener<List<Float>> listener
132-
) {
133-
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
110+
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
111+
retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
134112
}
135113

136114
/**
137115
* Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the
138116
* {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing
139117
* the similarity scores of the texts w.r.t. the query text, in the order of the input texts.
140118
*
141-
* @param modelId {@link String} ML-Commons Model Id
142-
* @param queryText {@link String} The query to compare all the inputText to
143-
* @param inputText {@link List} of {@link String} The texts to compare to the query
119+
* @param inferenceRequest {@link InferenceRequest}
144120
* @param listener {@link ActionListener} receives the result of the inference
145121
*/
146122
public void inferenceSimilarity(
147-
@NonNull final String modelId,
148-
@NonNull final String queryText,
149-
@NonNull final List<String> inputText,
123+
@NonNull SimilarityInferenceRequest inferenceRequest,
150124
@NonNull final ActionListener<List<Float>> listener
151125
) {
152-
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
126+
retryableInferenceSimilarityWithVectorResult(inferenceRequest, 0, listener);
153127
}
154128

155129
private void retryableInferenceSentencesWithMapResult(
156-
final String modelId,
157-
final List<String> inputText,
130+
final TextInferenceRequest inferenceRequest,
158131
final int retryTime,
159132
final ActionListener<List<Map<String, ?>>> listener
160133
) {
161-
MLInput mlInput = createMLTextInput(null, inputText);
162-
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
134+
MLInput mlInput = createMLTextInput(null, inferenceRequest.getInputTexts());
135+
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
163136
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
164137
listener.onResponse(result);
165138
},
166139
e -> RetryUtil.handleRetryOrFailure(
167140
e,
168141
retryTime,
169-
() -> retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener),
142+
() -> retryableInferenceSentencesWithMapResult(inferenceRequest, retryTime + 1, listener),
170143
listener
171144
)
172145
));
173146
}
174147

175148
private void retryableInferenceSentencesWithVectorResult(
176-
final List<String> targetResponseFilters,
177-
final String modelId,
178-
final List<String> inputText,
149+
final TextInferenceRequest inferenceRequest,
179150
final int retryTime,
180151
final ActionListener<List<List<Float>>> listener
181152
) {
182-
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
183-
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
153+
MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
154+
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
184155
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
185156
listener.onResponse(vector);
186157
},
187158
e -> RetryUtil.handleRetryOrFailure(
188159
e,
189160
retryTime,
190-
() -> retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener),
161+
() -> retryableInferenceSentencesWithVectorResult(inferenceRequest, retryTime + 1, listener),
191162
listener
192163
)
193164
));
194165
}
195166

196167
private void retryableInferenceSimilarityWithVectorResult(
197-
final String modelId,
198-
final String queryText,
199-
final List<String> inputText,
168+
final SimilarityInferenceRequest inferenceRequest,
200169
final int retryTime,
201170
final ActionListener<List<Float>> listener
202171
) {
203-
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
204-
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
172+
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
173+
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
205174
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
206175
listener.onResponse(scores);
207176
},
208177
e -> RetryUtil.handleRetryOrFailure(
209178
e,
210179
retryTime,
211-
() -> retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener),
180+
() -> retryableInferenceSimilarityWithVectorResult(inferenceRequest, retryTime + 1, listener),
212181
listener
213182
)
214183
));
@@ -262,28 +231,20 @@ private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
262231
}
263232

264233
private void retryableInferenceSentencesWithSingleVectorResult(
265-
final List<String> targetResponseFilters,
266-
final String modelId,
267-
final Map<String, String> inputObjects,
234+
final MapInferenceRequest inferenceRequest,
268235
final int retryTime,
269236
final ActionListener<List<Float>> listener
270237
) {
271-
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
272-
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
238+
MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
239+
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
273240
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
274241
log.debug("Inference Response for input sentence is : {} ", vector);
275242
listener.onResponse(vector);
276243
},
277244
e -> RetryUtil.handleRetryOrFailure(
278245
e,
279246
retryTime,
280-
() -> retryableInferenceSentencesWithSingleVectorResult(
281-
targetResponseFilters,
282-
modelId,
283-
inputObjects,
284-
retryTime + 1,
285-
listener
286-
),
247+
() -> retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, retryTime + 1, listener),
287248
listener
288249
)
289250
));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import java.util.List;
8+
9+
import lombok.Builder;
10+
import lombok.Getter;
11+
import lombok.NoArgsConstructor;
12+
import lombok.NonNull;
13+
import lombok.Setter;
14+
import lombok.experimental.SuperBuilder;
15+
16+
@SuperBuilder
17+
@NoArgsConstructor
18+
@Getter
19+
@Setter
20+
/**
21+
* Base abstract class for inference requests.
22+
* This class contains common fields and behaviors shared across different types of inference requests.
23+
*/
24+
public abstract class InferenceRequest {
25+
/**
26+
* Unique identifier for the model to be used for inference.
27+
* This field is required and cannot be null.
28+
*/
29+
@NonNull
30+
private String modelId;
31+
/**
32+
* List of targetResponseFilters to be applied.
33+
* Defaults value if not specified.
34+
*/
35+
@Builder.Default
36+
private List<String> targetResponseFilters = List.of("sentence_embedding");
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import java.util.Map;
8+
import lombok.Getter;
9+
import lombok.NoArgsConstructor;
10+
import lombok.Setter;
11+
import lombok.experimental.SuperBuilder;
12+
13+
/**
14+
* Implementation of InferenceRequest for inputObjects based inference requests.
15+
* Use this class when the input data consists of key-value pairs.
16+
*
17+
* @see InferenceRequest
18+
*/
19+
@SuperBuilder
20+
@NoArgsConstructor
21+
@Getter
22+
@Setter
23+
public class MapInferenceRequest extends InferenceRequest {
24+
private Map<String, String> inputObjects;
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import lombok.NoArgsConstructor;
8+
import lombok.Getter;
9+
import lombok.Setter;
10+
import lombok.experimental.SuperBuilder;
11+
12+
/**
13+
* Implementation of InferenceRequest for similarity based text inference requests.
14+
*
15+
* @see TextInferenceRequest
16+
*/
17+
@SuperBuilder
18+
@NoArgsConstructor
19+
@Getter
20+
@Setter
21+
public class SimilarityInferenceRequest extends TextInferenceRequest {
22+
private String queryText;
23+
}

src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java

+21-15
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,30 @@ public void doExecute(
5959
List<String> inferenceList,
6060
BiConsumer<IngestDocument, Exception> handler
6161
) {
62-
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
63-
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
64-
.stream()
65-
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
66-
.toList();
67-
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
68-
handler.accept(ingestDocument, null);
69-
}, e -> { handler.accept(null, e); }));
62+
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
63+
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
64+
ActionListener.wrap(resultMaps -> {
65+
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
66+
.stream()
67+
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
68+
.toList();
69+
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
70+
handler.accept(ingestDocument, null);
71+
}, e -> { handler.accept(null, e); })
72+
);
7073
}
7174

7275
@Override
7376
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
74-
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
75-
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
76-
.stream()
77-
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
78-
.toList();
79-
handler.accept(sparseVectors);
80-
}, onException));
77+
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
78+
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
79+
ActionListener.wrap(resultMaps -> {
80+
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
81+
.stream()
82+
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
83+
.toList();
84+
handler.accept(sparseVectors);
85+
}, onException)
86+
);
8187
}
8288
}

0 commit comments

Comments
 (0)