Skip to content

Commit 3ea3562

Browse files
committed
Update fetchModelResults name and optimize highlight result handle
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
1 parent a272d56 commit 3ea3562

File tree

3 files changed

+29
-43
lines changed

3 files changed

+29
-43
lines changed

src/main/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterManager.java

+26-40
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public String getModelId(Map<String, Object> options) {
104104
* @return Formatted text with highlighting
105105
*/
106106
public String getHighlightedSentences(String modelId, String question, String context) {
107-
List<Map<String, Object>> result = fetchHighlightingResults(modelId, question, context);
107+
List<Map<String, Object>> result = fetchModelResults(modelId, question, context);
108108
if (result == null) {
109109
return StringUtils.EMPTY;
110110
}
@@ -120,7 +120,7 @@ public String getHighlightedSentences(String modelId, String question, String co
120120
* @param context The document text
121121
* @return The highlighting results
122122
*/
123-
public List<Map<String, Object>> fetchHighlightingResults(String modelId, String question, String context) {
123+
public List<Map<String, Object>> fetchModelResults(String modelId, String question, String context) {
124124

125125
CountDownLatch latch = new CountDownLatch(1);
126126
AtomicReference<List<Map<String, Object>>> resultRef = new AtomicReference<>();
@@ -165,44 +165,31 @@ public List<Map<String, Object>> fetchHighlightingResults(String modelId, String
165165
* @return Formatted text with highlighting
166166
*/
167167
public String applyHighlighting(String context, List<Map<String, Object>> highlightResults) {
168-
// Collect all valid highlight positions
169168
List<int[]> validHighlights = new ArrayList<>();
170169

171-
// Process each highlight result
172-
for (Map<String, Object> result : highlightResults) {
173-
if (result == null) {
174-
continue;
175-
}
176-
177-
// Get the "highlights" list from the result
178-
Object highlightsObj = result.get(MODEL_INFERENCE_RESULT_KEY);
179-
180-
// Safely check if the object is a List
181-
if (!(highlightsObj instanceof List<?> highlightsList)) {
182-
continue;
183-
}
184-
185-
// Process the list safely
186-
if (highlightsList.isEmpty()) {
187-
continue;
188-
}
189-
190-
// Process each item in the list, checking if it's a Map
191-
for (Object item : highlightsList) {
192-
if (item instanceof Map<?, ?> map) {
193-
// Create a type-safe map
194-
Map<String, Object> safeMap = new java.util.HashMap<>();
195-
for (Map.Entry<?, ?> entry : map.entrySet()) {
196-
safeMap.put(entry.getKey().toString(), entry.getValue());
197-
}
198-
199-
// Extract and validate positions
200-
Object startObj = safeMap.get(MODEL_INFERENCE_RESULT_START_KEY);
201-
Object endObj = safeMap.get(MODEL_INFERENCE_RESULT_END_KEY);
202-
203-
int[] positions = validateHighlightPositions(startObj, endObj, context.length());
204-
if (positions != null) {
205-
validHighlights.add(positions);
170+
if (highlightResults != null && !highlightResults.isEmpty()) {
171+
Map<String, Object> result = highlightResults.getFirst();
172+
if (result != null) {
173+
// Get the "highlights" list from the result
174+
Object highlightsObj = result.get(MODEL_INFERENCE_RESULT_KEY);
175+
176+
if (highlightsObj instanceof List<?> highlightsList) {
177+
for (Object item : highlightsList) {
178+
if (item instanceof Map<?, ?> map) {
179+
Map<String, Object> safeMap = new java.util.HashMap<>();
180+
for (Map.Entry<?, ?> entry : map.entrySet()) {
181+
safeMap.put(entry.getKey().toString(), entry.getValue());
182+
}
183+
184+
// Extract and validate positions
185+
Object startObj = safeMap.get(MODEL_INFERENCE_RESULT_START_KEY);
186+
Object endObj = safeMap.get(MODEL_INFERENCE_RESULT_END_KEY);
187+
188+
int[] positions = validateHighlightPositions(startObj, endObj, context.length());
189+
if (positions != null) {
190+
validHighlights.add(positions);
191+
}
192+
}
206193
}
207194
}
208195
}
@@ -216,7 +203,6 @@ public String applyHighlighting(String context, List<Map<String, Object>> highli
216203
// Sort highlights by start position (ascending)
217204
validHighlights.sort(Comparator.comparingInt(pos -> pos[0]));
218205

219-
// Construct the highlighted text in O(n) time
220206
return constructHighlightedText(context, validHighlights);
221207
}
222208

@@ -273,7 +259,7 @@ private List<int[]> mergeOverlappingHighlights(List<int[]> highlights) {
273259
}
274260

275261
List<int[]> merged = new ArrayList<>();
276-
int[] current = highlights.get(0);
262+
int[] current = highlights.getFirst();
277263
merged.add(current);
278264

279265
for (int i = 1; i < highlights.size(); i++) {

src/test/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterManagerTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ public void testApplyHighlightingWithInvalidPositions() {
214214
assertEquals("Should only apply valid highlights", "<em>This</em> is a test string", result);
215215
}
216216

217-
public void testFetchHighlightingResultsWithTimeout() throws Exception {
217+
public void testFetchModelResultsWithTimeout() throws Exception {
218218
// Create a custom mock that delays the response
219219
MLCommonsClientAccessor delayedMlClient = mock(MLCommonsClientAccessor.class);
220220
NeuralHighlighterManager customManager = new NeuralHighlighterManager(delayedMlClient);
@@ -259,7 +259,7 @@ public void testFetchHighlightingResultsWithTimeout() throws Exception {
259259
// Call the method in a separate thread so we can interrupt it
260260
Thread testThread = new Thread(() -> {
261261
try {
262-
customManager.fetchHighlightingResults(MODEL_ID, TEST_QUERY, TEST_CONTENT);
262+
customManager.fetchModelResults(MODEL_ID, TEST_QUERY, TEST_CONTENT);
263263
fail("Should have been interrupted");
264264
} catch (OpenSearchException e) {
265265
// Expected exception

src/test/java/org/opensearch/neuralsearch/highlight/NeuralHighlighterTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ public void testFetchHighlightingResultsWithTimeout() throws Exception {
398398
// Call the method in a separate thread so we can interrupt it
399399
Thread testThread = new Thread(() -> {
400400
try {
401-
customManager.fetchHighlightingResults(MODEL_ID, TEST_QUERY, TEST_CONTENT);
401+
customManager.fetchModelResults(MODEL_ID, TEST_QUERY, TEST_CONTENT);
402402
fail("Should have been interrupted");
403403
} catch (OpenSearchException e) {
404404
// Expected exception

0 commit comments

Comments
 (0)