Skip to content

Commit 6564cec

Browse files
Fixed wrapped bool queries for latest core, fixed failing tests
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
1 parent e238291 commit 6564cec

File tree

7 files changed

+49
-50
lines changed

7 files changed

+49
-50
lines changed

qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/BatchIngestionIT.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.nio.file.Path;
1212
import java.util.List;
1313
import java.util.Map;
14+
import java.util.Set;
1415

1516
import static org.opensearch.neuralsearch.util.BatchIngestionUtils.prepareDataForBulkIngestion;
1617
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
@@ -20,6 +21,7 @@ public class BatchIngestionIT extends AbstractRollingUpgradeTestCase {
2021
private static final String SPARSE_PIPELINE = "BatchIngestionIT_sparse_pipeline_rolling";
2122
private static final String TEXT_FIELD_NAME = "passage_text";
2223
private static final String EMBEDDING_FIELD_NAME = "passage_embedding";
24+
private static final Set<MLModelState> READY_FOR_INFERENCE_STATES = Set.of(MLModelState.LOADED, MLModelState.DEPLOYED);
2325

2426
public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exception {
2527
waitForClusterHealthGreen(NODES_BWC_CLUSTER, 90);
@@ -98,7 +100,7 @@ private void waitForModelToLoad(String modelId) throws Exception {
98100

99101
for (int attempt = 0; attempt < maxAttempts; attempt++) {
100102
MLModelState state = getModelState(modelId);
101-
if (state == MLModelState.LOADED) {
103+
if (READY_FOR_INFERENCE_STATES.contains(state)) {
102104
logger.info("Model {} is now loaded after {} attempts", modelId, attempt + 1);
103105
return;
104106
}

src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java

+23-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import java.util.Map;
1515
import java.util.Objects;
1616
import java.util.concurrent.Callable;
17+
import java.util.stream.Collectors;
1718

1819
import org.apache.lucene.search.BooleanClause;
1920
import org.apache.lucene.search.BooleanQuery;
@@ -42,6 +43,26 @@ public final class HybridQuery extends Query implements Iterable<Query> {
4243
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
4344
*/
4445
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final HybridQueryContext hybridQueryContext) {
46+
this(
47+
subQueries,
48+
hybridQueryContext,
49+
filterQueries == null
50+
? null
51+
: filterQueries.stream().map(query -> new BooleanClause(query, BooleanClause.Occur.FILTER)).collect(Collectors.toList())
52+
);
53+
}
54+
55+
/**
56+
* Create new instance of hybrid query object based on collection of sub queries and boolean clauses that are used as filters for each sub-query
57+
* @param subQueries collection of sub queries, must not be null or empty
58+
* @param hybridQueryContext context object associated with this hybrid query
59+
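* @param booleanClauses list of boolean clauses that are added to each sub query; if this is null or empty sub queries will be executed as is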
60+
*/
61+
public HybridQuery(
62+
final Collection<Query> subQueries,
63+
final HybridQueryContext hybridQueryContext,
64+
final List<BooleanClause> booleanClauses
65+
) {
4566
Objects.requireNonNull(subQueries, "collection of queries must not be null");
4667
if (subQueries.isEmpty()) {
4768
throw new IllegalArgumentException("collection of queries must not be empty");
@@ -50,14 +71,14 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
5071
if (Objects.nonNull(paginationDepth) && paginationDepth == 0) {
5172
throw new IllegalArgumentException("pagination_depth must not be zero");
5273
}
53-
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
74+
if (Objects.isNull(booleanClauses) || booleanClauses.isEmpty()) {
5475
this.subQueries = new ArrayList<>(subQueries);
5576
} else {
5677
List<Query> modifiedSubQueries = new ArrayList<>();
5778
for (Query subQuery : subQueries) {
5879
BooleanQuery.Builder builder = new BooleanQuery.Builder();
5980
builder.add(subQuery, BooleanClause.Occur.MUST);
60-
filterQueries.forEach(filterQuery -> builder.add(filterQuery, BooleanClause.Occur.FILTER));
81+
booleanClauses.forEach(builder::add); // add each clause exactly once; passing the whole list inside forEach re-added every clause per element
6182
modifiedSubQueries.add(builder.build());
6283
}
6384
this.subQueries = modifiedSubQueries;

src/main/java/org/opensearch/neuralsearch/query/NeuralKNNQuery.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
package org.opensearch.neuralsearch.query;
66

77
import lombok.Getter;
8-
import org.apache.lucene.index.IndexReader;
98
import org.apache.lucene.search.IndexSearcher;
109
import org.apache.lucene.search.Query;
1110
import org.apache.lucene.search.QueryVisitor;
@@ -45,8 +44,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
4544
}
4645

4746
@Override
48-
public Query rewrite(IndexReader reader) throws IOException {
49-
Query rewritten = knnQuery.rewrite(reader);
47+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
48+
Query rewritten = knnQuery.rewrite(indexSearcher);
5049
if (rewritten == knnQuery) {
5150
return this;
5251
}

src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java

+4-7
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,11 @@ protected Query extractHybridQuery(final SearchContext searchContext, final Quer
7979
if (isHybridQueryWrappedInBooleanQuery(searchContext, query)) {
8080
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
8181
if (!(booleanClauses.get(0).query() instanceof HybridQuery)) {
82-
throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level query");
82+
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
8383
}
84-
HybridQuery hybridQuery = (HybridQuery) booleanClauses.stream().findFirst().get().query();
85-
List<Query> filterQueries = booleanClauses.stream()
86-
.filter(clause -> BooleanClause.Occur.FILTER == clause.occur())
87-
.map(BooleanClause::query)
88-
.collect(Collectors.toList());
89-
HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries, hybridQuery.getQueryContext());
84+
HybridQuery hybridQuery = (HybridQuery) booleanClauses.get(0).query();
85+
List<BooleanClause> filterQueries = booleanClauses.stream().skip(1).collect(Collectors.toList());
86+
HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), hybridQuery.getQueryContext(), filterQueries);
9087
return hybridQueryWithFilter;
9188
}
9289
return query;

src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java

+2-26
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,9 @@ public class HybridQueryUtil {
2424
* This method validates whether the query object is an instance of hybrid query
2525
*/
2626
public static boolean isHybridQuery(final Query query, final SearchContext searchContext) {
27-
if (query instanceof HybridQuery) {
27+
if (query instanceof HybridQuery
28+
|| (Objects.nonNull(searchContext.parsedQuery()) && searchContext.parsedQuery().query() instanceof HybridQuery)) {
2829
return true;
29-
} else if (isWrappedHybridQuery(query)) {
30-
/* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code
31-
https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370.
32-
main reason for that is performance optimization, at time of writing we are ok with losing on performance if that unblocks
33-
hybrid query for indexes with nested field types.
34-
in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for
35-
this search request.
36-
below is sample structure of such query:
37-
38-
Boolean {
39-
should: {
40-
hybrid: {
41-
sub_query1 {}
42-
sub_query2 {}
43-
}
44-
}
45-
filter: {
46-
exists: {
47-
field: "_primary_term"
48-
}
49-
}
50-
}
51-
*/
52-
// we have already checked if query in instance of Boolean in higher level else if condition
53-
return hasNestedFieldOrNestedDocs(query, searchContext) || hasAliasFilter(query, searchContext);
5430
}
5531
return false;
5632
}

src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java

+6-8
Original file line numberDiff line numberDiff line change
@@ -546,20 +546,18 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() {
546546
hit1DetailsForHit1DetailsForHit1DetailsForHit1.get("description")
547547
);
548548
assertTrue((double) hit1DetailsForHit1DetailsForHit1DetailsForHit1.get("value") > 0.0f);
549-
assertEquals(3, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").size());
549+
assertEquals(2, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").size());
550550

551-
assertEquals("boost", getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("description"));
552-
assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("value") > 0.0f);
553551
assertEquals(
554552
"idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:",
555-
getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("description")
553+
getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("description")
556554
);
557-
assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("value") > 0.0f);
555+
assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("value") > 0.0f);
558556
assertEquals(
559557
"tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:",
560-
getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(2).get("description")
558+
getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("description")
561559
);
562-
assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(2).get("value") > 0.0f);
560+
assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("value") > 0.0f);
563561

564562
// hit 4
565563
Map<String, Object> searchHit4 = nestedHits.get(3);
@@ -588,7 +586,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() {
588586
hit1DetailsForHit1DetailsForHit1DetailsForHit4.get("description")
589587
);
590588
assertTrue((double) hit1DetailsForHit1DetailsForHit1DetailsForHit4.get("value") > 0.0f);
591-
assertEquals(3, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit4, "details").size());
589+
assertEquals(2, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit4, "details").size());
592590

593591
// hit 6
594592
Map<String, Object> searchHit6 = nestedHits.get(5);

src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java

+9-3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.opensearch.index.mapper.MapperService;
5858
import org.opensearch.index.mapper.TextFieldMapper;
5959
import org.opensearch.index.query.BoolQueryBuilder;
60+
import org.opensearch.index.query.ParsedQuery;
6061
import org.opensearch.index.query.QueryBuilders;
6162
import org.opensearch.index.query.DisMaxQueryBuilder;
6263
import org.opensearch.index.query.QueryShardContext;
@@ -702,8 +703,8 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur
702703

703704
when(searchContext.query()).thenReturn(query);
704705

705-
IllegalStateException exception = expectThrows(
706-
IllegalStateException.class,
706+
IllegalArgumentException exception = expectThrows(
707+
IllegalArgumentException.class,
707708
() -> hybridQueryPhaseSearcher.searchWith(
708709
searchContext,
709710
contextIndexSearcher,
@@ -716,7 +717,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur
716717

717718
org.hamcrest.MatcherAssert.assertThat(
718719
exception.getMessage(),
719-
containsString("cannot process hybrid query due to incorrect structure of top level query")
720+
containsString("hybrid query must be a top level query and cannot be wrapped into other queries")
720721
);
721722

722723
releaseResources(directory, w, reader);
@@ -818,6 +819,8 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then
818819
Query query = builder.build();
819820

820821
when(searchContext.query()).thenReturn(query);
822+
Query hybridQuery = queryBuilder.toQuery(mockQueryShardContext);
823+
when(searchContext.parsedQuery()).thenReturn(new ParsedQuery(hybridQuery));
821824

822825
CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager = HybridCollectorManager
823826
.createHybridCollectorManager(searchContext);
@@ -1110,6 +1113,9 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the
11101113
when(searchContext.query()).thenReturn(query);
11111114
when(searchContext.aliasFilter()).thenReturn(termFilter);
11121115

1116+
Query hybridQuery = queryBuilder.toQuery(mockQueryShardContext);
1117+
when(searchContext.parsedQuery()).thenReturn(new ParsedQuery(hybridQuery));
1118+
11131119
CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager = HybridCollectorManager
11141120
.createHybridCollectorManager(searchContext);
11151121
Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();

0 commit comments

Comments
 (0)