Skip to content

Commit 56fc80a

Browse files
committed
Support different embedding types of model response
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 3c7f275 commit 56fc80a

File tree

7 files changed

+44
-40
lines changed

7 files changed

+44
-40
lines changed

src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ public class VectorUtil {
2121
* @param vectorAsList {@link List} of {@link Float}'s representing the vector
2222
* @return array of floats produced from input list
2323
*/
24-
public static float[] vectorAsListToArray(List<Float> vectorAsList) {
24+
public static float[] vectorAsListToArray(List<Number> vectorAsList) {
2525
float[] vector = new float[vectorAsList.size()];
2626
for (int i = 0; i < vectorAsList.size(); i++) {
27-
vector[i] = vectorAsList.get(i);
27+
vector[i] = vectorAsList.get(i).floatValue();
2828
}
2929
return vector;
3030
}

src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

+15-14
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public class MLCommonsClientAccessor {
5252
public void inferenceSentence(
5353
@NonNull final String modelId,
5454
@NonNull final String inputText,
55-
@NonNull final ActionListener<List<Float>> listener
55+
@NonNull final ActionListener<List<Number>> listener
5656
) {
5757
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
5858
if (response.size() != 1) {
@@ -82,7 +82,7 @@ public void inferenceSentence(
8282
public void inferenceSentences(
8383
@NonNull final String modelId,
8484
@NonNull final List<String> inputText,
85-
@NonNull final ActionListener<List<List<Float>>> listener
85+
@NonNull final ActionListener<List<List<Number>>> listener
8686
) {
8787
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
8888
}
@@ -103,7 +103,7 @@ public void inferenceSentences(
103103
@NonNull final List<String> targetResponseFilters,
104104
@NonNull final String modelId,
105105
@NonNull final List<String> inputText,
106-
@NonNull final ActionListener<List<List<Float>>> listener
106+
@NonNull final ActionListener<List<List<Number>>> listener
107107
) {
108108
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
109109
}
@@ -128,7 +128,7 @@ public void inferenceSentencesWithMapResult(
128128
public void inferenceSentences(
129129
@NonNull final String modelId,
130130
@NonNull final Map<String, String> inputObjects,
131-
@NonNull final ActionListener<List<Float>> listener
131+
@NonNull final ActionListener<List<Number>> listener
132132
) {
133133
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
134134
}
@@ -177,11 +177,11 @@ private void retryableInferenceSentencesWithVectorResult(
177177
final String modelId,
178178
final List<String> inputText,
179179
final int retryTime,
180-
final ActionListener<List<List<Float>>> listener
180+
final ActionListener<List<List<Number>>> listener
181181
) {
182182
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
183183
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
184-
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
184+
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
185185
listener.onResponse(vector);
186186
}, e -> {
187187
if (RetryUtil.shouldRetry(e, retryTime)) {
@@ -202,7 +202,8 @@ private void retryableInferenceSimilarityWithVectorResult(
202202
) {
203203
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
204204
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
205-
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
205+
final List<List<Float>> tensors = buildVectorFromResponse(mlOutput);
206+
final List<Float> scores = tensors.stream().map(v -> v.get(0)).collect(Collectors.toList());
206207
listener.onResponse(scores);
207208
}, e -> {
208209
if (RetryUtil.shouldRetry(e, retryTime)) {
@@ -224,14 +225,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
224225
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
225226
}
226227

227-
private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
228-
final List<List<Float>> vector = new ArrayList<>();
228+
private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
229+
final List<List<T>> vector = new ArrayList<>();
229230
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
230231
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
231232
for (final ModelTensors tensors : tensorOutputList) {
232233
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
233234
for (final ModelTensor tensor : tensorsList) {
234-
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
235+
vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
235236
}
236237
}
237238
return vector;
@@ -255,8 +256,8 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
255256
return resultMaps;
256257
}
257258

258-
private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
259-
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
259+
private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput mlOutput) {
260+
final List<List<T>> vector = buildVectorFromResponse(mlOutput);
260261
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
261262
}
262263

@@ -265,11 +266,11 @@ private void retryableInferenceSentencesWithSingleVectorResult(
265266
final String modelId,
266267
final Map<String, String> inputObjects,
267268
final int retryTime,
268-
final ActionListener<List<Float>> listener
269+
final ActionListener<List<Number>> listener
269270
) {
270271
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
271272
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
272-
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
273+
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
273274
log.debug("Inference Response for input sentence is : {} ", vector);
274275
listener.onResponse(vector);
275276
}, e -> {

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
124124

125125
}
126126

127-
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
127+
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
128128
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
129129
log.debug("Text embedding result fetched, starting build vector output!");
130130
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
@@ -164,7 +164,7 @@ Map<String, String> buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge
164164

165165
@SuppressWarnings({ "unchecked" })
166166
@VisibleForTesting
167-
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
167+
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
168168
Map<String, Object> result = new LinkedHashMap<>();
169169
result.put(knnKey, modelTensorList);
170170
return result;

src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
public class VectorUtilTests extends OpenSearchTestCase {
1313

1414
public void testVectorAsListToArray() {
15-
List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
15+
List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
1616
float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);
1717

1818
assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
1919
for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
20-
assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f);
20+
assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
2121
}
2222

23-
List<Float> vectorAsList_withNoElements = Collections.emptyList();
23+
List<Number> vectorAsList_withNoElements = Collections.emptyList();
2424
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
2525
assertEquals(0, vectorAsArray_withNoElements.length);
2626
}

src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java

+17-14
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,13 @@
3636
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {
3737

3838
@Mock
39-
private ActionListener<List<List<Float>>> resultListener;
39+
private ActionListener<List<List<Number>>> resultListener;
4040

4141
@Mock
42-
private ActionListener<List<Float>> singleSentenceResultListener;
42+
private ActionListener<List<Number>> singleSentenceResultListener;
43+
44+
@Mock
45+
private ActionListener<List<Float>> similarityResultListener;
4346

4447
@Mock
4548
private MachineLearningNodeClient client;
@@ -53,7 +56,7 @@ public void setup() {
5356
}
5457

5558
public void testInferenceSentence_whenValidInput_thenSuccess() {
56-
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
59+
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
5760
Mockito.doAnswer(invocation -> {
5861
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
5962
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
6972
}
7073

7174
public void testInferenceSentences_whenValidInputThenSuccess() {
72-
final List<List<Float>> vectorList = new ArrayList<>();
75+
final List<List<Number>> vectorList = new ArrayList<>();
7376
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
7477
Mockito.doAnswer(invocation -> {
7578
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
@@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
8588
}
8689

8790
public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
88-
final List<List<Float>> vectorList = new ArrayList<>();
91+
final List<List<Number>> vectorList = new ArrayList<>();
8992
vectorList.add(Collections.emptyList());
9093
Mockito.doAnswer(invocation -> {
9194
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
@@ -278,7 +281,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
278281
}
279282

280283
public void testInferenceMultimodal_whenValidInput_thenSuccess() {
281-
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
284+
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
282285
Mockito.doAnswer(invocation -> {
283286
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
284287
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -337,13 +340,13 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
337340
TestCommonConstants.MODEL_ID,
338341
"is it sunny",
339342
List.of("it is sunny today", "roses are red"),
340-
singleSentenceResultListener
343+
similarityResultListener
341344
);
342345

343346
Mockito.verify(client)
344347
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
345-
Mockito.verify(singleSentenceResultListener).onResponse(vector);
346-
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
348+
Mockito.verify(similarityResultListener).onResponse(vector);
349+
Mockito.verifyNoMoreInteractions(similarityResultListener);
347350
}
348351

349352
public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
@@ -358,13 +361,13 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
358361
TestCommonConstants.MODEL_ID,
359362
"is it sunny",
360363
List.of("it is sunny today", "roses are red"),
361-
singleSentenceResultListener
364+
similarityResultListener
362365
);
363366

364367
Mockito.verify(client)
365368
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
366-
Mockito.verify(singleSentenceResultListener).onFailure(exception);
367-
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
369+
Mockito.verify(similarityResultListener).onFailure(exception);
370+
Mockito.verifyNoMoreInteractions(similarityResultListener);
368371
}
369372

370373
public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() {
@@ -382,12 +385,12 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi
382385
TestCommonConstants.MODEL_ID,
383386
"is it sunny",
384387
List.of("it is sunny today", "roses are red"),
385-
singleSentenceResultListener
388+
similarityResultListener
386389
);
387390

388391
Mockito.verify(client, times(4))
389392
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
390-
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
393+
Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);
391394
}
392395

393396
private ModelTensorOutput createModelTensorOutput(final Float[] output) {

src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -646,10 +646,10 @@ public void testHashAndEquals() {
646646
@SneakyThrows
647647
public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
648648
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K);
649-
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
649+
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
650650
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
651651
doAnswer(invocation -> {
652-
ActionListener<List<Float>> listener = invocation.getArgument(2);
652+
ActionListener<List<Number>> listener = invocation.getArgument(2);
653653
listener.onResponse(expectedVector);
654654
return null;
655655
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());
@@ -682,10 +682,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
682682
.queryImage(IMAGE_TEXT)
683683
.modelId(MODEL_ID)
684684
.k(K);
685-
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
685+
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
686686
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
687687
doAnswer(invocation -> {
688-
ActionListener<List<Float>> listener = invocation.getArgument(2);
688+
ActionListener<List<Number>> listener = invocation.getArgument(2);
689689
listener.onResponse(expectedVector);
690690
return null;
691691
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());

src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ protected float[] runInference(final String modelId, final String queryText) {
268268
List<Object> output = (List<Object>) result.get("output");
269269
assertEquals(1, output.size());
270270
Map<String, Object> map = (Map<String, Object>) output.get(0);
271-
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
271+
List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
272272
return vectorAsListToArray(data);
273273
}
274274

0 commit comments

Comments
 (0)