Skip to content

Commit cc6a6b2

Browse files
Add support for local cache in hybrid query (#663)
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
1 parent 50a6dcf commit cc6a6b2

File tree

5 files changed

+405
-12
lines changed

5 files changed

+405
-12
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
### Features
2020
### Enhancements
2121
### Bug Fixes
22+
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
2223
### Infrastructure
2324
### Documentation
2425
### Maintenance

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

+11-3
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,13 @@ private void updateOriginalFetchResults(
138138
// 3. update original scores to normalized and combined values
139139
// 4. order scores based on normalized and combined values
140140
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
141-
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
141+
// check for the case when results are served from the request cache
142+
boolean requestCache = Objects.nonNull(querySearchResults)
143+
&& !querySearchResults.isEmpty()
144+
&& Objects.nonNull(querySearchResults.get(0).getShardSearchRequest().requestCache())
145+
&& querySearchResults.get(0).getShardSearchRequest().requestCache();
146+
147+
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache);
142148

143149
// create map of docId to index of search hits. This solves (2), duplicates are from
144150
// delimiter and start/stop elements, they all have same valid doc_id. For this map
@@ -168,7 +174,7 @@ private void updateOriginalFetchResults(
168174
fetchSearchResult.hits(updatedSearchHits);
169175
}
170176

171-
private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
177+
private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) {
172178
SearchHits searchHits = fetchSearchResult.hits();
173179
SearchHit[] searchHitArray = searchHits.getHits();
174180
// validate the both collections are of the same size
@@ -177,7 +183,9 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
177183
"score normalization processor cannot produce final query result, fetch query phase returns empty results"
178184
);
179185
}
180-
if (searchHitArray.length != docIds.size()) {
186+
// in case of a cached request the results of the fetch and query phases may differ; the only restriction is
187+
// that the number of query results must be greater than or equal to the number of fetch results
188+
if ((!requestCache && searchHitArray.length != docIds.size()) || requestCache && docIds.size() < searchHitArray.length) {
181189
throw new IllegalStateException(
182190
String.format(
183191
Locale.ROOT,

src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java

+7
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
5454
import org.opensearch.search.fetch.FetchSearchResult;
5555
import org.opensearch.search.fetch.QueryFetchSearchResult;
56+
import org.opensearch.search.internal.ShardSearchRequest;
5657
import org.opensearch.search.query.QuerySearchResult;
5758
import org.opensearch.test.OpenSearchTestCase;
5859
import org.opensearch.threadpool.TestThreadPool;
@@ -401,6 +402,9 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz
401402

402403
QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
403404
queryFetchSearchResult.setShardIndex(shardId);
405+
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
406+
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
407+
querySearchResult.setShardSearchRequest(shardSearchRequest);
404408

405409
queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
406410

@@ -485,6 +489,9 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail()
485489

486490
QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
487491
queryFetchSearchResult.setShardIndex(shardId);
492+
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
493+
when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
494+
querySearchResult.setShardSearchRequest(shardSearchRequest);
488495

489496
queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
490497

src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java

+78-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
*/
55
package org.opensearch.neuralsearch.processor;
66

7+
import static org.mockito.Mockito.mock;
78
import static org.mockito.Mockito.spy;
9+
import static org.mockito.Mockito.when;
810
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
911
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
1012

@@ -29,6 +31,7 @@
2931
import org.opensearch.search.SearchHits;
3032
import org.opensearch.search.SearchShardTarget;
3133
import org.opensearch.search.fetch.FetchSearchResult;
34+
import org.opensearch.search.internal.ShardSearchRequest;
3235
import org.opensearch.search.query.QuerySearchResult;
3336
import org.opensearch.test.OpenSearchTestCase;
3437

@@ -156,6 +159,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
156159
);
157160
querySearchResult.setSearchShardTarget(searchShardTarget);
158161
querySearchResult.setShardIndex(shardId);
162+
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
163+
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
164+
querySearchResult.setShardSearchRequest(shardSearchRequest);
159165
querySearchResults.add(querySearchResult);
160166
SearchHit[] searchHitArray = new SearchHit[] {
161167
new SearchHit(0, "10", Map.of(), Map.of()),
@@ -213,6 +219,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
213219
);
214220
querySearchResult.setSearchShardTarget(searchShardTarget);
215221
querySearchResult.setShardIndex(shardId);
222+
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
223+
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
224+
querySearchResult.setShardSearchRequest(shardSearchRequest);
216225
querySearchResults.add(querySearchResult);
217226
SearchHit[] searchHitArray = new SearchHit[] {
218227
new SearchHit(-1, "10", Map.of(), Map.of()),
@@ -236,7 +245,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
236245
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
237246
}
238247

239-
public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
248+
public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
240249
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
241250
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
242251
);
@@ -270,15 +279,11 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
270279
);
271280
querySearchResult.setSearchShardTarget(searchShardTarget);
272281
querySearchResult.setShardIndex(shardId);
282+
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
283+
when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
284+
querySearchResult.setShardSearchRequest(shardSearchRequest);
273285
querySearchResults.add(querySearchResult);
274-
SearchHit[] searchHitArray = new SearchHit[] {
275-
new SearchHit(-1, "10", Map.of(), Map.of()),
276-
new SearchHit(-1, "10", Map.of(), Map.of()),
277-
new SearchHit(-1, "10", Map.of(), Map.of()),
278-
new SearchHit(-1, "1", Map.of(), Map.of()),
279-
new SearchHit(-1, "2", Map.of(), Map.of()),
280-
new SearchHit(-1, "3", Map.of(), Map.of()) };
281-
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
286+
SearchHits searchHits = getSearchHits();
282287
fetchSearchResult.hits(searchHits);
283288

284289
expectThrows(
@@ -291,4 +296,68 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
291296
)
292297
);
293298
}
299+
300+
public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() {
301+
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
302+
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
303+
);
304+
305+
List<QuerySearchResult> querySearchResults = new ArrayList<>();
306+
FetchSearchResult fetchSearchResult = new FetchSearchResult();
307+
int shardId = 0;
308+
SearchShardTarget searchShardTarget = new SearchShardTarget(
309+
"node",
310+
new ShardId("index", "uuid", shardId),
311+
null,
312+
OriginalIndices.NONE
313+
);
314+
QuerySearchResult querySearchResult = new QuerySearchResult();
315+
querySearchResult.topDocs(
316+
new TopDocsAndMaxScore(
317+
new TopDocs(
318+
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
319+
new ScoreDoc[] {
320+
createStartStopElementForHybridSearchResults(0),
321+
createDelimiterElementForHybridSearchResults(0),
322+
new ScoreDoc(0, 0.5f),
323+
new ScoreDoc(2, 0.3f),
324+
new ScoreDoc(4, 0.25f),
325+
new ScoreDoc(10, 0.2f),
326+
createStartStopElementForHybridSearchResults(0) }
327+
),
328+
0.5f
329+
),
330+
new DocValueFormat[0]
331+
);
332+
querySearchResult.setSearchShardTarget(searchShardTarget);
333+
querySearchResult.setShardIndex(shardId);
334+
ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
335+
when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
336+
querySearchResult.setShardSearchRequest(shardSearchRequest);
337+
querySearchResults.add(querySearchResult);
338+
SearchHits searchHits = getSearchHits();
339+
fetchSearchResult.hits(searchHits);
340+
341+
normalizationProcessorWorkflow.execute(
342+
querySearchResults,
343+
Optional.of(fetchSearchResult),
344+
ScoreNormalizationFactory.DEFAULT_METHOD,
345+
ScoreCombinationFactory.DEFAULT_METHOD
346+
);
347+
348+
TestUtils.assertQueryResultScores(querySearchResults);
349+
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
350+
}
351+
352+
private static SearchHits getSearchHits() {
353+
SearchHit[] searchHitArray = new SearchHit[] {
354+
new SearchHit(-1, "10", Map.of(), Map.of()),
355+
new SearchHit(-1, "10", Map.of(), Map.of()),
356+
new SearchHit(-1, "10", Map.of(), Map.of()),
357+
new SearchHit(-1, "1", Map.of(), Map.of()),
358+
new SearchHit(-1, "2", Map.of(), Map.of()),
359+
new SearchHit(-1, "3", Map.of(), Map.of()) };
360+
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
361+
return searchHits;
362+
}
294363
}

0 commit comments

Comments
 (0)