Skip to content

Commit 6402127

Browse files
committed
Support different embedding types of model response
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 0769ad7 commit 6402127

File tree

11 files changed

+1025
-36
lines changed

11 files changed

+1025
-36
lines changed

logs/20240604.txt

Whitespace-only changes.

src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java

+2-2
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,10 @@ public class VectorUtil {
2121
* @param vectorAsList {@link List} of {@link Float}'s representing the vector
2222
* @return array of floats produced from input list
2323
*/
24-
public static float[] vectorAsListToArray(List<Float> vectorAsList) {
24+
public static float[] vectorAsListToArray(List<Number> vectorAsList) {
2525
float[] vector = new float[vectorAsList.size()];
2626
for (int i = 0; i < vectorAsList.size(); i++) {
27-
vector[i] = vectorAsList.get(i);
27+
vector[i] = vectorAsList.get(i).floatValue();
2828
}
2929
return vector;
3030
}

src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

+14-14
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ public class MLCommonsClientAccessor {
5555
public void inferenceSentence(
5656
@NonNull final String modelId,
5757
@NonNull final String inputText,
58-
@NonNull final ActionListener<List<Float>> listener
58+
@NonNull final ActionListener<List<Number>> listener
5959
) {
6060

6161
inferenceSentences(
@@ -87,7 +87,7 @@ public void inferenceSentence(
8787
*/
8888
public void inferenceSentences(
8989
@NonNull final TextInferenceRequest inferenceRequest,
90-
@NonNull final ActionListener<List<List<Float>>> listener
90+
@NonNull final ActionListener<List<List<Number>>> listener
9191
) {
9292
retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
9393
}
@@ -107,7 +107,7 @@ public void inferenceSentencesWithMapResult(
107107
* @param inferenceRequest {@link InferenceRequest}
108108
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
109109
*/
110-
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
110+
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Number>> listener) {
111111
retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
112112
}
113113

@@ -148,11 +148,11 @@ private void retryableInferenceSentencesWithMapResult(
148148
private void retryableInferenceSentencesWithVectorResult(
149149
final TextInferenceRequest inferenceRequest,
150150
final int retryTime,
151-
final ActionListener<List<List<Float>>> listener
151+
final ActionListener<List<List<Number>>> listener
152152
) {
153153
MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
154154
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
155-
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
155+
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
156156
listener.onResponse(vector);
157157
},
158158
e -> RetryUtil.handleRetryOrFailure(
@@ -167,11 +167,11 @@ private void retryableInferenceSentencesWithVectorResult(
167167
private void retryableInferenceSimilarityWithVectorResult(
168168
final SimilarityInferenceRequest inferenceRequest,
169169
final int retryTime,
170-
final ActionListener<List<Float>> listener
170+
final ActionListener<List<Number>> listener
171171
) {
172172
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
173173
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
174-
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
174+
final List<Number> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
175175
listener.onResponse(scores);
176176
},
177177
e -> RetryUtil.handleRetryOrFailure(
@@ -194,14 +194,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
194194
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
195195
}
196196

197-
private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
198-
final List<List<Float>> vector = new ArrayList<>();
197+
private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
198+
final List<List<T>> vector = new ArrayList<>();
199199
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
200200
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
201201
for (final ModelTensors tensors : tensorOutputList) {
202202
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
203203
for (final ModelTensor tensor : tensorsList) {
204-
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
204+
vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
205205
}
206206
}
207207
return vector;
@@ -225,19 +225,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
225225
return resultMaps;
226226
}
227227

228-
private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
229-
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
228+
private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput mlOutput) {
229+
final List<List<T>> vector = buildVectorFromResponse(mlOutput);
230230
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
231231
}
232232

233233
private void retryableInferenceSentencesWithSingleVectorResult(
234234
final MapInferenceRequest inferenceRequest,
235235
final int retryTime,
236-
final ActionListener<List<Float>> listener
236+
final ActionListener<List<Number>> listener
237237
) {
238238
MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
239239
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
240-
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
240+
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
241241
log.debug("Inference Response for input sentence is : {} ", vector);
242242
listener.onResponse(vector);
243243
},

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

+2-2
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
127127

128128
}
129129

130-
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
130+
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
131131
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
132132
log.debug("Text embedding result fetched, starting build vector output!");
133133
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
@@ -167,7 +167,7 @@ Map<String, String> buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge
167167

168168
@SuppressWarnings({ "unchecked" })
169169
@VisibleForTesting
170-
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
170+
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
171171
Map<String, Object> result = new LinkedHashMap<>();
172172
result.put(knnKey, modelTensorList);
173173
return result;

src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java

+3-3
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,15 @@
1212
public class VectorUtilTests extends OpenSearchTestCase {
1313

1414
public void testVectorAsListToArray() {
15-
List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
15+
List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
1616
float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);
1717

1818
assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
1919
for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
20-
assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f);
20+
assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
2121
}
2222

23-
List<Float> vectorAsList_withNoElements = Collections.emptyList();
23+
List<Number> vectorAsList_withNoElements = Collections.emptyList();
2424
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
2525
assertEquals(0, vectorAsArray_withNoElements.length);
2626
}

src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java

+13-10
Original file line number | Diff line number | Diff line change
@@ -36,10 +36,13 @@
3636
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {
3737

3838
@Mock
39-
private ActionListener<List<List<Float>>> resultListener;
39+
private ActionListener<List<List<Number>>> resultListener;
4040

4141
@Mock
42-
private ActionListener<List<Float>> singleSentenceResultListener;
42+
private ActionListener<List<Number>> singleSentenceResultListener;
43+
44+
@Mock
45+
private ActionListener<List<Float>> similarityResultListener;
4346

4447
@Mock
4548
private MachineLearningNodeClient client;
@@ -53,7 +56,7 @@ public void setup() {
5356
}
5457

5558
public void testInferenceSentence_whenValidInput_thenSuccess() {
56-
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
59+
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
5760
Mockito.doAnswer(invocation -> {
5861
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
5962
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
6972
}
7073

7174
public void testInferenceSentences_whenValidInputThenSuccess() {
72-
final List<List<Float>> vectorList = new ArrayList<>();
75+
final List<List<Number>> vectorList = new ArrayList<>();
7376
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
7477
Mockito.doAnswer(invocation -> {
7578
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
@@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
8588
}
8689

8790
public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
88-
final List<List<Float>> vectorList = new ArrayList<>();
91+
final List<List<Number>> vectorList = new ArrayList<>();
8992
vectorList.add(Collections.emptyList());
9093
Mockito.doAnswer(invocation -> {
9194
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
@@ -288,7 +291,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
288291
}
289292

290293
public void testInferenceMultimodal_whenValidInput_thenSuccess() {
291-
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
294+
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
292295
Mockito.doAnswer(invocation -> {
293296
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
294297
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -357,8 +360,8 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
357360

358361
Mockito.verify(client)
359362
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
360-
Mockito.verify(singleSentenceResultListener).onResponse(vector);
361-
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
363+
Mockito.verify(similarityResultListener).onResponse(vector);
364+
Mockito.verifyNoMoreInteractions(similarityResultListener);
362365
}
363366

364367
public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
@@ -373,8 +376,8 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
373376

374377
Mockito.verify(client)
375378
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
376-
Mockito.verify(singleSentenceResultListener).onFailure(exception);
377-
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
379+
Mockito.verify(similarityResultListener).onFailure(exception);
380+
Mockito.verifyNoMoreInteractions(similarityResultListener);
378381
}
379382

380383
private ModelTensorOutput createModelTensorOutput(final Float[] output) {

src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java

+4-4
Original file line number | Diff line number | Diff line change
@@ -771,10 +771,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
771771
.modelId(MODEL_ID)
772772
.k(K)
773773
.build();
774-
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
774+
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
775775
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
776776
doAnswer(invocation -> {
777-
ActionListener<List<Float>> listener = invocation.getArgument(1);
777+
ActionListener<List<Number>> listener = invocation.getArgument(1);
778778
listener.onResponse(expectedVector);
779779
return null;
780780
}).when(mlCommonsClientAccessor)
@@ -810,10 +810,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
810810
.modelId(MODEL_ID)
811811
.k(K)
812812
.build();
813-
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
813+
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
814814
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
815815
doAnswer(invocation -> {
816-
ActionListener<List<Float>> listener = invocation.getArgument(1);
816+
ActionListener<List<Number>> listener = invocation.getArgument(1);
817817
listener.onResponse(expectedVector);
818818
return null;
819819
}).when(mlCommonsClientAccessor)

src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -312,7 +312,7 @@ protected float[] runInference(final String modelId, final String queryText) {
312312
List<Object> output = (List<Object>) result.get("output");
313313
assertEquals(1, output.size());
314314
Map<String, Object> map = (Map<String, Object>) output.get(0);
315-
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
315+
List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
316316
return vectorAsListToArray(data);
317317
}
318318

0 commit comments

Comments
 (0)