From 264abc438d15cc9bbb603da6cd41d5cefce35b2e Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 14 Feb 2025 11:34:16 +0800 Subject: [PATCH] Fix conflicts after rebase main Signed-off-by: zane-neo --- CHANGELOG.md | 1 + .../neuralsearch/ml/MLCommonsClientAccessor.java | 6 ++++-- .../neuralsearch/ml/MLCommonsClientAccessorTests.java | 10 +++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b0f4069e..cc2a255fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements - Set neural-search plugin 3.0.0 baseline JDK version to JDK-2 ([#838](https://github.com/opensearch-project/neural-search/pull/838)) +- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index aa354f641..06cdb6690 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -167,11 +167,13 @@ private void retryableInferenceSentencesWithVectorResult( private void retryableInferenceSimilarityWithVectorResult( final SimilarityInferenceRequest inferenceRequest, final int retryTime, - final ActionListener> listener + final ActionListener> listener ) { MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts()); mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> { - final List scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList()); + final List scores = buildVectorFromResponse(mlOutput).stream() + .map(v -> v.getFirst().floatValue()) + .collect(Collectors.toList()); listener.onResponse(scores); }, e -> RetryUtil.handleRetryOrFailure( diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index d1a7b9b69..3fea202d0 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -130,17 +130,17 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener); + accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener); // Verify client.predict is called 4 times (1 initial + 3 retries) Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); // Verify failure is propagated to the listener after all retries - Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException); // Ensure no additional interactions with the listener - Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + Mockito.verifyNoMoreInteractions(similarityResultListener); } public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() { @@ -356,7 +356,7 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener); + accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -372,7 +372,7 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener); + accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));