From 454bca04e2b0c0ac86d6f59e5a66a7092d2bdb91 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 16 Apr 2024 21:57:26 -0700 Subject: [PATCH 1/9] Add Advannce Post Filter Integ Tests Scenarios Signed-off-by: Varun Jain --- .../AbstractRestartUpgradeRestTestCase.java | 10 +- .../neuralsearch/bwc/HybridSearchIT.java | 12 +- .../neuralsearch/bwc/MultiModalSearchIT.java | 6 +- .../bwc/NeuralQueryEnricherProcessorIT.java | 8 +- .../bwc/NeuralSparseSearchIT.java | 8 +- .../neuralsearch/bwc/SemanticSearchIT.java | 6 +- .../bwc/AbstractRollingUpgradeTestCase.java | 16 +- .../neuralsearch/bwc/HybridSearchIT.java | 12 +- .../neuralsearch/bwc/MultiModalSearchIT.java | 6 +- .../bwc/NeuralQueryEnricherProcessorIT.java | 8 +- .../bwc/NeuralSparseSearchIT.java | 10 +- .../neuralsearch/bwc/SemanticSearchIT.java | 6 +- .../NeuralQueryEnricherProcessorIT.java | 6 +- .../processor/NormalizationProcessorIT.java | 8 +- .../NormalizationProcessorTests.java | 2 +- .../NormalizationProcessorWorkflowTests.java | 2 +- .../processor/ScoreCombinationIT.java | 16 +- .../ScoreCombinationTechniqueTests.java | 2 +- .../processor/ScoreNormalizationIT.java | 14 +- .../processor/TextChunkingProcessorIT.java | 2 +- .../rerank/MLOpenSearchRerankProcessorIT.java | 2 +- .../query/HybridQueryAggregationsIT.java | 37 +- .../query/HybridQueryBuilderTests.java | 2 +- .../neuralsearch/query/HybridQueryIT.java | 10 +- .../query/HybridQueryPostFilterIT.java | 418 ++++++++++++++++++ .../query/NeuralQueryBuilderTests.java | 2 +- .../neuralsearch/query/NeuralQueryIT.java | 10 +- .../query/NeuralSparseQueryBuilderTests.java | 2 +- .../query/NeuralSparseQueryIT.java | 4 +- .../query/HybridQueryPhaseSearcherTests.java | 8 +- .../neuralsearch/BaseNeuralSearchIT.java | 16 +- .../OpenSearchSecureRestTestCase.java | 10 +- .../util/AggregationsTestUtils.java | 0 .../util/NeuralSearchClusterTestUtils.java | 0 .../neuralsearch/{ => util}/TestUtils.java | 35 +- 35 files changed, 569 insertions(+), 147 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java rename src/{test => testFixtures}/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java (100%) rename src/{test => testFixtures}/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java (100%) rename src/testFixtures/java/org/opensearch/neuralsearch/{ => util}/TestUtils.java (90%) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java index c2d2657f4..a3cfd4a04 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java @@ -11,11 +11,11 @@ import org.junit.Before; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.NEURAL_SEARCH_BWC_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.CLIENT_TIMEOUT_VALUE; -import static org.opensearch.neuralsearch.TestUtils.RESTART_UPGRADE_OLD_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.BWC_VERSION; -import static org.opensearch.neuralsearch.TestUtils.generateModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.CLIENT_TIMEOUT_VALUE; +import static org.opensearch.neuralsearch.util.TestUtils.RESTART_UPGRADE_OLD_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.BWC_VERSION; +import static org.opensearch.neuralsearch.util.TestUtils.generateModelId; import org.opensearch.test.rest.OpenSearchRestTestCase; public abstract class AbstractRestartUpgradeRestTestCase extends BaseNeuralSearchIT { diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 48735182a..f5289fe79 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -11,12 +11,12 @@ import java.util.List; import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.getModelId; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index e6749d778..1d9dde2c6 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class MultiModalSearchIT extends AbstractRestartUpgradeRestTestCase { diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java index 876b2b0d7..02edb486c 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java @@ -4,12 +4,12 @@ */ package org.opensearch.neuralsearch.bwc; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; import org.opensearch.common.settings.Settings; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java index 22bd4a281..8ec54711a 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java @@ -10,10 +10,10 @@ import java.util.Map; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.neuralsearch.TestUtils; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; +import org.opensearch.neuralsearch.util.TestUtils; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; public class NeuralSparseSearchIT extends AbstractRestartUpgradeRestTestCase { diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java index ec5938cd9..27ca7f42d 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.getModelId; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class SemanticSearchIT extends AbstractRestartUpgradeRestTestCase { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java index 16ed2d229..a3ad530dc 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java @@ -11,14 +11,14 @@ import org.junit.Before; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.NEURAL_SEARCH_BWC_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.OLD_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.MIXED_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.UPGRADED_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.ROLLING_UPGRADE_FIRST_ROUND; -import static org.opensearch.neuralsearch.TestUtils.BWCSUITE_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.BWC_VERSION; -import static org.opensearch.neuralsearch.TestUtils.generateModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.OLD_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.MIXED_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.UPGRADED_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.ROLLING_UPGRADE_FIRST_ROUND; +import static org.opensearch.neuralsearch.util.TestUtils.BWCSUITE_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.BWC_VERSION; +import static org.opensearch.neuralsearch.util.TestUtils.generateModelId; import org.opensearch.test.rest.OpenSearchRestTestCase; public abstract class AbstractRollingUpgradeTestCase extends BaseNeuralSearchIT { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 292540820..903ffc9be 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -10,12 +10,12 @@ import java.util.List; import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index b91ec1322..e10ddd17e 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class MultiModalSearchIT extends AbstractRollingUpgradeTestCase { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java index 281c78821..c9897447e 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.bwc; import org.opensearch.common.settings.Settings; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; @@ -13,9 +13,9 @@ import java.nio.file.Path; import java.util.List; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; public class NeuralQueryEnricherProcessorIT extends AbstractRollingUpgradeTestCase { // add prefix to avoid conflicts with other IT class, since we don't wipe resources after first round diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java index 70513686b..e461508e8 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java @@ -10,11 +10,11 @@ import java.util.Map; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.neuralsearch.TestUtils; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import org.opensearch.neuralsearch.util.TestUtils; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; public class NeuralSparseSearchIT extends AbstractRollingUpgradeTestCase { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java index 51e548474..b9f7b15a9 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class SemanticSearchIT extends AbstractRollingUpgradeTestCase { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java index 0f4c49f27..baee337ce 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java @@ -4,9 +4,9 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.util.Collections; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 750278ca3..05eb6829c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.ArrayList; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index dd185e227..7c443a825 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -39,7 +39,7 @@ import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 2f880ce74..5d88ffed9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -21,7 +21,7 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index b3478984c..e1360474c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -6,11 +6,11 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; -import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.util.TestUtils.assertWeightedScores; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.Arrays; @@ -24,9 +24,9 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import com.google.common.primitives.Floats; import lombok.SneakyThrows; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 4f76c666e..da9b34f22 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -4,7 +4,7 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 175ea08fe..ff1a2001c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.Arrays; @@ -20,9 +20,9 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import com.google.common.primitives.Floats; import lombok.SneakyThrows; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java index dd517aa17..42f37d01a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java @@ -24,7 +24,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; public class TextChunkingProcessorIT extends BaseNeuralSearchIT { private static final String INDEX_NAME = "text_chunking_test_index"; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index f7dd7b647..fcb946d84 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -24,7 +24,7 @@ import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; @Log4j2 public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index 4647ebf5f..2a2fc7f34 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -27,25 +27,21 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.stream.IntStream; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; -import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; +import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQuery; /** * Integration tests for base scenarios when aggregations are combined with hybrid query */ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT { - private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = - "test-neural-aggs-pipeline-multi-doc-index-multiple-shards"; - private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-hybrid-aggs-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-hybrid-aggs-multi-doc-index-single-shard"; private static final String TEST_QUERY_TEXT3 = "hello"; private static final String TEST_QUERY_TEXT4 = "everyone"; private static final String TEST_QUERY_TEXT5 = "welcome"; @@ -53,7 +49,7 @@ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT { private static final String TEST_DOC_TEXT2 = "Hi to this place"; private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; - private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-aggregation-pipeline"; private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; private static final String TEST_DOC_TEXT5 = "People keep telling me orange but I still prefer pink"; private static final String TEST_DOC_TEXT6 = "She traveled because it cost the same as therapy and was a lot more enjoyable"; @@ -786,29 +782,6 @@ private Map executeQueryAndGetAggsResults( return searchResponseAsMap; } - private void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { - assertEquals(expected, getHitCount(searchResponseAsMap)); - - List> hits1NestedList = getNestedHits(searchResponseAsMap); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hits1NestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - - // verify that scores are in desc order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); - - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(expected, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - } - private HybridQueryBuilder createHybridQueryBuilder(boolean isComplex) { if (isComplex) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 7beb02dcc..8ff552698 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -12,7 +12,7 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; -import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index b198a51ee..be6942232 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -7,11 +7,11 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.opensearch.index.query.QueryBuilders.matchQuery; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.ArrayList; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java new file mode 100644 index 000000000..ac4eaab35 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -0,0 +1,418 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.MatchNoneQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQuery; + +public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-hybrid-post-filter-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = + "test-hybrid-post-filter-multi-doc-index-single-shard"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-post-filter-pipeline"; + private static final String INTEGER_FIELD_1 = "stock"; + private static final String TEXT_FIELD_1 = "name"; + private static final String KEYWORD_FIELD_2 = "category"; + private static final String TEXT_FIELD_NAME_1_VALUE = "Dunes part 2"; + private static final String TEXT_FIELD_NAME_2_VALUE = "Dunes part 1"; + private static final String TEXT_FIELD_NAME_3_VALUE = "Mission Impossible 1"; + private static final String TEXT_FIELD_NAME_4_VALUE = "Mission Impossible 2"; + private static final String TEXT_FIELD_NAME_5_VALUE = "The Terminal"; + private static final String TEXT_FIELD_NAME_6_VALUE = "Avengers"; + private static final int INTEGER_FIELD_STOCK_1_VALUE = 25; + private static final int INTEGER_FIELD_STOCK_2_VALUE = 22; + private static final int INTEGER_FIELD_STOCK_3_VALUE = 256; + private static final int INTEGER_FIELD_STOCK_4_VALUE = 25; + private static final int INTEGER_FIELD_STOCK_5_VALUE = 20; + private static final String KEYWORD_FIELD_CATEGORY_1_VALUE = "Drama"; + private static final String KEYWORD_FIELD_CATEGORY_2_VALUE = "Action"; + private static final String KEYWORD_FIELD_CATEGORY_3_VALUE = "Sci-fi"; + private static final String AVG_AGGREGATION_NAME = "avg_stock_size"; + private static boolean setUpIsDone = false; + + @Before + public void setUp() throws Exception { + super.setUp(); + if (setUpIsDone) { + return; + } + updateClusterSettings(); + setUpIsDone = true; + } + + @SneakyThrows + public void testPostFilterOnIndexWithSingleShard_WhenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + prepareResourcesBeforeTestExecution(1); + testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPostFilterOnIndexWithSingleShard_WhenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + prepareResourcesBeforeTestExecution(1); + testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + prepareResourcesBeforeTestExecution(3); + testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + prepareResourcesBeforeTestExecution(3); + testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + private void testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderScenario1("mission", "part", 400, 200); + QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithRangeQuery(400, 230); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + testResults(searchResponseAsMap, 1, 0, 230, 400); + } + + @SneakyThrows + private void testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(String indexName) { + // Case 1 + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderScenario1("mission", "part", 400, 200); + QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithBoolShouldQuery("impossible", 400, 230); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + testResults(searchResponseAsMap, 2, 1, 230, 400); + // Case 2 + AggregationBuilder aggsBuilder = createAggregations(); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + List.of(aggsBuilder), + postFilterQuery + ); + testResults(searchResponseAsMap, 2, 1, 230, 400); + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map aggValue = getAggregationValues(aggregations, AVG_AGGREGATION_NAME); + assertEquals(1, aggValue.size()); + // Case 3 + postFilterQuery = createPostFilterQueryBuilderWithBoolMustQuery("terminal", 400, 230); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + testResults(searchResponseAsMap, 0, 0, 230, 400); + // Case 4 + hybridQueryBuilder = createHybridQueryBuilderScenario2("hero", 5000, 1000); + postFilterQuery = createPostFilterQueryBuilderWithBoolShouldQuery("impossible", 400, 230); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + testResults(searchResponseAsMap, 0, 0, 230, 400); + } + + @SneakyThrows + private void testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderScenario1("mission", "part", 400, 200); + QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(true); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + testResults(searchResponseAsMap, 4, 3, 230, 400); + + postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(false); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + testResults(searchResponseAsMap, 0, 0, 230, 400); + } + + private void testResults( + Map searchResponseAsMap, + int resultsExpected, + int postFilterResultsValidationExpected, + int lte, + int gte + ) { + assertHitResultsFromQuery(resultsExpected, searchResponseAsMap); + List> hitsNestedList = getNestedHits(searchResponseAsMap); + + List docIndexes = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + Map source = (Map) oneHit.get("_source"); + int docIndex = (int) source.get(INTEGER_FIELD_1); + docIndexes.add(docIndex); + } + assertEquals(postFilterResultsValidationExpected, docIndexes.stream().filter(docIndex -> docIndex < lte || docIndex > gte).count()); + } + + @SneakyThrows + void prepareResourcesBeforeTestExecution(int numShards) { + if (numShards == 1) { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, numShards); + } else { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, numShards); + } + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + } + + @SneakyThrows + private void initializeIndexIfNotExists(String indexName, int numShards) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_2), List.of(), numShards), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1), + Collections.singletonList(TEXT_FIELD_NAME_1_VALUE), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_STOCK_1_VALUE), + List.of(KEYWORD_FIELD_2), + List.of(KEYWORD_FIELD_CATEGORY_1_VALUE), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1), + Collections.singletonList(TEXT_FIELD_NAME_2_VALUE), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_STOCK_2_VALUE), + List.of(KEYWORD_FIELD_2), + List.of(KEYWORD_FIELD_CATEGORY_1_VALUE), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "3", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1), + Collections.singletonList(TEXT_FIELD_NAME_3_VALUE), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_STOCK_3_VALUE), + List.of(KEYWORD_FIELD_2), + List.of(KEYWORD_FIELD_CATEGORY_2_VALUE), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1), + Collections.singletonList(TEXT_FIELD_NAME_4_VALUE), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_STOCK_4_VALUE), + List.of(KEYWORD_FIELD_2), + List.of(KEYWORD_FIELD_CATEGORY_2_VALUE), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1), + Collections.singletonList(TEXT_FIELD_NAME_5_VALUE), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_STOCK_5_VALUE), + List.of(KEYWORD_FIELD_2), + List.of(KEYWORD_FIELD_CATEGORY_1_VALUE), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1), + Collections.singletonList(TEXT_FIELD_NAME_6_VALUE), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_STOCK_5_VALUE), + List.of(KEYWORD_FIELD_2), + List.of(KEYWORD_FIELD_CATEGORY_3_VALUE), + List.of(), + List.of() + ); + } + } + + private HybridQueryBuilder createHybridQueryBuilderScenario1(String text, String value, int lte, int gte) { + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1, text); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1, value); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder).add(termQueryBuilder).add(rangeQueryBuilder); + return hybridQueryBuilder; + } + + private HybridQueryBuilder createHybridQueryBuilderScenario2(String text, int lte, int gte) { + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1, text); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder).add(rangeQueryBuilder); + return hybridQueryBuilder; + } + + private QueryBuilder createPostFilterQueryBuilderWithRangeQuery(int lte, int gte) { + return QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + } + + private QueryBuilder createPostFilterQueryBuilderWithBoolShouldQuery(String query, int lte, int gte) { + QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1, query); + return QueryBuilders.boolQuery().should(rangeQuery).should(matchQuery); + } + + private QueryBuilder createPostFilterQueryBuilderWithBoolMustQuery(String query, int lte, int gte) { + QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1, query); + return QueryBuilders.boolQuery().must(rangeQuery).must(matchQuery); + } + + private QueryBuilder createPostFilterQueryBuilderWithMatchAllOrNoneQuery(boolean isMatchAll) { + if (isMatchAll) { + return QueryBuilders.matchAllQuery(); + } + + MatchNoneQueryBuilder matchNoneQueryBuilder = new MatchNoneQueryBuilder(); + return new MatchNoneQueryBuilder(); + } + + private AggregationBuilder createAggregations() { + return AggregationBuilders.avg(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_1); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index dd63abbea..1fa7e94c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -12,7 +12,7 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; -import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index 9cc9dda71..2e4c766aa 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -4,11 +4,11 @@ */ package org.opensearch.neuralsearch.query; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; import java.util.Collections; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 89bcd57d7..bc1203f0e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -9,7 +9,7 @@ import static org.mockito.Mockito.mock; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; -import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD; diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java index d43d252b9..418e64801 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.query; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; import java.util.List; import java.util.Map; @@ -15,7 +15,7 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import lombok.SneakyThrows; diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e489b3c4a..88bf01dec 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -462,7 +462,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -570,7 +570,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -638,7 +638,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -781,7 +781,7 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 1c21e5f5e..baecf2932 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -56,14 +56,14 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import com.google.common.collect.ImmutableList; -import static org.opensearch.neuralsearch.TestUtils.MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.TestUtils.MAX_RETRY; -import static org.opensearch.neuralsearch.TestUtils.MAX_TIME_OUT_INTERVAL; +import static org.opensearch.neuralsearch.util.TestUtils.MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.MAX_RETRY; +import static org.opensearch.neuralsearch.util.TestUtils.MAX_TIME_OUT_INTERVAL; import lombok.AllArgsConstructor; import lombok.Getter; diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java b/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java index 133f42daf..a43f77917 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java @@ -7,11 +7,11 @@ import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.neuralsearch.TestUtils.NEURAL_SEARCH_BWC_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.OPENDISTRO_SECURITY; -import static org.opensearch.neuralsearch.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.SECURITY_AUDITLOG_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.SKIP_DELETE_MODEL_INDEX; +import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.OPENDISTRO_SECURITY; +import static org.opensearch.neuralsearch.util.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.SECURITY_AUDITLOG_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.SKIP_DELETE_MODEL_INDEX; import java.io.IOException; import java.util.Collections; diff --git a/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java similarity index 100% rename from src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java rename to src/testFixtures/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java similarity index 100% rename from src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java rename to src/testFixtures/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java similarity index 90% rename from src/testFixtures/java/org/opensearch/neuralsearch/TestUtils.java rename to src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index a6f4a3e0f..2a37c165e 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -2,13 +2,15 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch; +package org.opensearch.neuralsearch.util; import com.carrotsearch.randomizedtesting.RandomizedTest; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; import static org.opensearch.test.OpenSearchTestCase.randomFloat; import java.util.ArrayList; @@ -299,6 +301,29 @@ public static void assertFetchResultScores(FetchSearchResult fetchSearchResult, assertEquals(0.001f, minScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION); } + public static void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { + assertEquals(expected, getHitCount(searchResponseAsMap)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(expected, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + private static List> getNestedHits(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return (List>) hitsMap.get("hits"); @@ -314,6 +339,13 @@ private static Optional getMaxScore(Map searchResponseAsM return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } + @SuppressWarnings("unchecked") + private static int getHitCount(final Map searchResponseAsMap) { + Map hits1map = (Map) searchResponseAsMap.get("hits"); + List hits1List = (List) hits1map.get("hits"); + return hits1List.size(); + } + public static String getModelId(Map pipeline, String processor) { assertNotNull(pipeline); ArrayList> processors = (ArrayList>) pipeline.get("processors"); @@ -326,5 +358,4 @@ public static String getModelId(Map pipeline, String processor) public static String generateModelId() { return "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8); } - } From eacd0f1aba9401fccd0bb6fed9462d1d2bbcb7ff Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 16 Apr 2024 22:22:06 -0700 Subject: [PATCH 2/9] Reverting comment in HybridQueryPhaseSearcher Signed-off-by: Varun Jain --- .../search/query/HybridQueryPhaseSearcherTests.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 88bf01dec..e489b3c4a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -462,7 +462,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -570,7 +570,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -638,7 +638,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -781,7 +781,7 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); - // when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); + when(indexMetadata.getCustomData(eq(IndexMetadata.REMOTE_STORE_CUSTOM_KEY))).thenReturn(null); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); From 4b4f3bc545f946ddfdec0add7dbdf3a73a260cb5 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Thu, 18 Apr 2024 13:57:40 -0700 Subject: [PATCH 3/9] Adding Martin Comments Signed-off-by: Varun Jain --- .../query/HybridQueryPostFilterIT.java | 445 +++++++++++++++--- .../neuralsearch/util/TestUtils.java | 8 +- 2 files changed, 387 insertions(+), 66 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index dccbb778d..5db8db446 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -33,12 +33,12 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final String INTEGER_FIELD_1 = "stock"; private static final String TEXT_FIELD_1 = "name"; private static final String KEYWORD_FIELD_2 = "category"; - private static final String TEXT_FIELD_NAME_1_VALUE = "Dunes part 2"; - private static final String TEXT_FIELD_NAME_2_VALUE = "Dunes part 1"; - private static final String TEXT_FIELD_NAME_3_VALUE = "Mission Impossible 1"; - private static final String TEXT_FIELD_NAME_4_VALUE = "Mission Impossible 2"; - private static final String TEXT_FIELD_NAME_5_VALUE = "The Terminal"; - private static final String TEXT_FIELD_NAME_6_VALUE = "Avengers"; + private static final String TEXT_FIELD_VALUE_1 = "Dunes part 2"; + private static final String TEXT_FIELD_VALUE_2 = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_3 = "Mission Impossible 1"; + private static final String TEXT_FIELD_VALUE_4 = "Mission Impossible 2"; + private static final String TEXT_FIELD_VALUE_5 = "The Terminal"; + private static final String TEXT_FIELD_VALUE_6 = "Avengers"; private static final int INTEGER_FIELD_STOCK_1_VALUE = 25; private static final int INTEGER_FIELD_STOCK_2_VALUE = 22; private static final int INTEGER_FIELD_STOCK_3_VALUE = 256; @@ -49,6 +49,12 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final String KEYWORD_FIELD_CATEGORY_3_VALUE = "Sci-fi"; private static final String AVG_AGGREGATION_NAME = "avg_stock_size"; private static boolean setUpIsDone = false; + private static final int shards_count_in_single_node_cluster = 1; + private static final int shards_count_in_multi_node_cluster = 3; + private static final int lte_of_range_in_hybrid_query = 400; + private static final int gte_of_range_in_hybrid_query = 200; + private static final int lte_of_range_in_post_filter_query = 400; + private static final int gte_of_range_in_post_filter_query = 230; @Before public void setUp() throws Exception { @@ -61,69 +67,106 @@ public void setUp() throws Exception { } @SneakyThrows - public void testPostFilterOnIndexWithSingleShard_WhenConcurrentSearchEnabled_thenSuccessful() { + public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", true); - prepareResourcesBeforeTestExecution(1); - testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); - testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); - testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD - ); + prepareResourcesBeforeTestExecution(shards_count_in_single_node_cluster); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); } } @SneakyThrows - public void testPostFilterOnIndexWithSingleShard_WhenConcurrentSearchDisabled_thenSuccessful() { + public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", false); - prepareResourcesBeforeTestExecution(1); - testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); - testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); - testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD - ); + prepareResourcesBeforeTestExecution(shards_count_in_single_node_cluster); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); } } @SneakyThrows - public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() { + public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", true); - prepareResourcesBeforeTestExecution(3); - testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); - testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); - testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS - ); + prepareResourcesBeforeTestExecution(shards_count_in_multi_node_cluster); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); } } @SneakyThrows - public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchDisabled_thenSuccessful() { + public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchDisabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", false); - prepareResourcesBeforeTestExecution(3); - testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); - testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); - testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS - ); + prepareResourcesBeforeTestExecution(shards_count_in_multi_node_cluster); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); } } @SneakyThrows - private void testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessful(String indexName) { - HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery("mission", "part", 400, 200); - QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithRangeQuery(400, 230); + private void testPostFilterRangeQuery(String indexName) { + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "mission" + } + }, + { + "term":{ + "name":{ + "value":"part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + + }, + "post_filter":{ + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + } + }*/ + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + lte_of_range_in_hybrid_query, + gte_of_range_in_hybrid_query + ); + QueryBuilder postFilterQuery = createQueryBuilderWithRangeQuery( + lte_of_range_in_post_filter_query, + gte_of_range_in_post_filter_query + ); Map searchResponseAsMap = search( indexName, @@ -134,14 +177,72 @@ private void testPostFilterRangeQuery_WhenMatchTermAndRangeQueries_thenSuccessfu null, postFilterQuery ); - testResults(searchResponseAsMap, 1, 0, 230, 400); + assertHybridQueryResults(searchResponseAsMap, 1, 0, gte_of_range_in_post_filter_query, lte_of_range_in_hybrid_query); } @SneakyThrows - private void testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful(String indexName) { + private void testPostFilterBoolQuery(String indexName) { // Case 1 - HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery("mission", "part", 400, 200); - QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithBoolShouldQuery("impossible", 400, 230); + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "mission" + } + }, + { + "term":{ + "name":{ + "value":"part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + + }, + "post_filter":{ + "bool":{ + "should":[ + { + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + }, + { + "match":{ + "name":"impossible" + } + } + + ] + } + } + }*/ + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + lte_of_range_in_hybrid_query, + gte_of_range_in_hybrid_query + ); + QueryBuilder postFilterQuery = createQueryBuilderWithBoolShouldQuery( + "impossible", + lte_of_range_in_post_filter_query, + gte_of_range_in_post_filter_query + ); Map searchResponseAsMap = search( indexName, @@ -152,8 +253,62 @@ private void testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful null, postFilterQuery ); - testResults(searchResponseAsMap, 2, 1, 230, 400); + assertHybridQueryResults(searchResponseAsMap, 2, 1, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); // Case 2 + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "mission" + } + }, + { + "term":{ + "name":{ + "value":"part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + + }, + "aggs": { + "avg_stock_size": { + "avg": { "field": "stock" } + } + }, + "post_filter":{ + "bool":{ + "should":[ + { + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + }, + { + "match":{ + "name":"impossible" + } + } + + ] + } + } + }*/ AggregationBuilder aggsBuilder = createAggregations(); searchResponseAsMap = search( indexName, @@ -164,14 +319,67 @@ private void testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful List.of(aggsBuilder), postFilterQuery ); - testResults(searchResponseAsMap, 2, 1, 230, 400); + assertHybridQueryResults(searchResponseAsMap, 2, 1, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); Map aggregations = getAggregations(searchResponseAsMap); assertNotNull(aggregations); Map aggValue = getAggregationValues(aggregations, AVG_AGGREGATION_NAME); assertEquals(1, aggValue.size()); // Case 3 - postFilterQuery = createPostFilterQueryBuilderWithBoolMustQuery("terminal", 400, 230); + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "mission" + } + }, + { + "term":{ + "name":{ + "value":"part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + + }, + "post_filter":{ + "bool":{ + "must":[ + { + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + }, + { + "match":{ + "name":"terminal" + } + } + + ] + } + } + }*/ + postFilterQuery = createQueryBuilderWithBoolMustQuery( + "terminal", + lte_of_range_in_post_filter_query, + gte_of_range_in_post_filter_query + ); searchResponseAsMap = search( indexName, hybridQueryBuilder, @@ -181,10 +389,56 @@ private void testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful null, postFilterQuery ); - testResults(searchResponseAsMap, 0, 0, 230, 400); + assertHybridQueryResults(searchResponseAsMap, 0, 0, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); // Case 4 + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "hero" + } + }, + { + "range": { + "stock": { + "gte": 1000, + "lte": 5000 + } + } + } + ] + } + + }, + "post_filter":{ + "bool":{ + "should":[ + { + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + }, + { + "match":{ + "name":"impossible" + } + } + + ] + } + } + }*/ hybridQueryBuilder = createHybridQueryBuilderScenarioWithMatchAndRangeQuery("hero", 5000, 1000); - postFilterQuery = createPostFilterQueryBuilderWithBoolShouldQuery("impossible", 400, 230); + postFilterQuery = createQueryBuilderWithBoolShouldQuery( + "impossible", + lte_of_range_in_post_filter_query, + gte_of_range_in_post_filter_query + ); searchResponseAsMap = search( indexName, hybridQueryBuilder, @@ -194,12 +448,48 @@ private void testPostFilterBoolQuery_WhenMatchTermAndRangeQueries_thenSuccessful null, postFilterQuery ); - testResults(searchResponseAsMap, 0, 0, 230, 400); + assertHybridQueryResults(searchResponseAsMap, 0, 0, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); } @SneakyThrows - private void testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_thenSuccessful(String indexName) { - HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery("mission", "part", 400, 200); + private void testPostFilterMatchAllAndNoneQuery(String indexName) { + /*{ + "query": { + "hybrid": { + "queries": [ + { + "match": { + "name": "mission" + } + }, + { + "term": { + "name": { + "value": "part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + }, + "post_filter": { + "match_all": {} + } + }*/ + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + gte_of_range_in_hybrid_query, + lte_of_range_in_hybrid_query + ); QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(true); Map searchResponseAsMap = search( @@ -211,8 +501,39 @@ private void testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_the null, postFilterQuery ); - testResults(searchResponseAsMap, 4, 3, 230, 400); - + assertHybridQueryResults(searchResponseAsMap, 4, 3, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + + /*{ + "query": { + "hybrid": { + "queries": [ + { + "match": { + "name": "mission" + } + }, + { + "term": { + "name": { + "value": "part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + }, + "post_filter": { + "match_none": {} + } + }*/ postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(false); searchResponseAsMap = search( indexName, @@ -223,10 +544,10 @@ private void testPostFilterMatchAllAndNoneQuery_WhenMatchTermAndRangeQueries_the null, postFilterQuery ); - testResults(searchResponseAsMap, 0, 0, 230, 400); + assertHybridQueryResults(searchResponseAsMap, 0, 0, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); } - private void testResults( + private void assertHybridQueryResults( Map searchResponseAsMap, int resultsExpected, int postFilterResultsValidationExpected, @@ -271,7 +592,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_NAME_1_VALUE), + Collections.singletonList(TEXT_FIELD_VALUE_1), List.of(), List.of(), List.of(INTEGER_FIELD_1), @@ -288,7 +609,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_NAME_2_VALUE), + Collections.singletonList(TEXT_FIELD_VALUE_2), List.of(), List.of(), List.of(INTEGER_FIELD_1), @@ -305,7 +626,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_NAME_3_VALUE), + Collections.singletonList(TEXT_FIELD_VALUE_3), List.of(), List.of(), List.of(INTEGER_FIELD_1), @@ -322,7 +643,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_NAME_4_VALUE), + Collections.singletonList(TEXT_FIELD_VALUE_4), List.of(), List.of(), List.of(INTEGER_FIELD_1), @@ -339,7 +660,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_NAME_5_VALUE), + Collections.singletonList(TEXT_FIELD_VALUE_5), List.of(), List.of(), List.of(INTEGER_FIELD_1), @@ -356,7 +677,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_NAME_6_VALUE), + Collections.singletonList(TEXT_FIELD_VALUE_6), List.of(), List.of(), List.of(INTEGER_FIELD_1), @@ -386,17 +707,17 @@ private HybridQueryBuilder createHybridQueryBuilderScenarioWithMatchAndRangeQuer return hybridQueryBuilder; } - private QueryBuilder createPostFilterQueryBuilderWithRangeQuery(int lte, int gte) { + private QueryBuilder createQueryBuilderWithRangeQuery(int lte, int gte) { return QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); } - private QueryBuilder createPostFilterQueryBuilderWithBoolShouldQuery(String query, int lte, int gte) { + private QueryBuilder createQueryBuilderWithBoolShouldQuery(String query, int lte, int gte) { QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1, query); return QueryBuilders.boolQuery().should(rangeQuery).should(matchQuery); } - private QueryBuilder createPostFilterQueryBuilderWithBoolMustQuery(String query, int lte, int gte) { + private QueryBuilder createQueryBuilderWithBoolMustQuery(String query, int lte, int gte) { QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1, query); return QueryBuilders.boolQuery().must(rangeQuery).must(matchQuery); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index ab1a192cb..0534f85bf 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -305,10 +305,10 @@ public static void assertFetchResultScores(FetchSearchResult fetchSearchResult, public static void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { assertEquals(expected, getHitCount(searchResponseAsMap)); - List> hits1NestedList = getNestedHits(searchResponseAsMap); + List> hitsNestedList = getNestedHits(searchResponseAsMap); List ids = new ArrayList<>(); List scores = new ArrayList<>(); - for (Map oneHit : hits1NestedList) { + for (Map oneHit : hitsNestedList) { ids.add((String) oneHit.get("_id")); scores.add((Double) oneHit.get("_score")); } @@ -342,8 +342,8 @@ private static Optional getMaxScore(Map searchResponseAsM @SuppressWarnings("unchecked") private static int getHitCount(final Map searchResponseAsMap) { - Map hitsmap = (Map) searchResponseAsMap.get("hits"); - List hitsList = (List) hitsmap.get("hits"); + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + List hitsList = (List) hitsMap.get("hits"); return hitsList.size(); } From 27cc05d82c66bc74765af6cab4c40bfa5441ed23 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Thu, 18 Apr 2024 14:02:34 -0700 Subject: [PATCH 4/9] Adding Martin Comments Signed-off-by: Varun Jain --- .../opensearch/neuralsearch/query/HybridQueryPostFilterIT.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 5db8db446..2a5df3ed1 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -56,6 +56,8 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final int lte_of_range_in_post_filter_query = 400; private static final int gte_of_range_in_post_filter_query = 230; + // @Before is a workaround to save extra update cluster settings call to the cluster. + // @BeforeClass throws RuntimeException with initializationError @Before public void setUp() throws Exception { super.setUp(); From 4987f774283b76d3a964001d87388e5f669c9173 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 19 Apr 2024 23:42:58 -0700 Subject: [PATCH 5/9] Addressing Martin Comments Signed-off-by: Varun Jain --- .../query/HybridQueryPostFilterIT.java | 83 +++++++++---------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 2a5df3ed1..0b7a90b25 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -49,12 +49,12 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final String KEYWORD_FIELD_CATEGORY_3_VALUE = "Sci-fi"; private static final String AVG_AGGREGATION_NAME = "avg_stock_size"; private static boolean setUpIsDone = false; - private static final int shards_count_in_single_node_cluster = 1; - private static final int shards_count_in_multi_node_cluster = 3; - private static final int lte_of_range_in_hybrid_query = 400; - private static final int gte_of_range_in_hybrid_query = 200; - private static final int lte_of_range_in_post_filter_query = 400; - private static final int gte_of_range_in_post_filter_query = 230; + private static final int SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER = 1; + private static final int SHARDS_COUNT_IN_MULTI_NODE_CLUSTER = 3; + private static final int LTE_OF_RANGE_IN_HYBRID_QUERY = 400; + private static final int GTE_OF_RANGE_IN_HYBRID_QUERY = 200; + private static final int LTE_OF_RANGE_IN_POST_FILTER_QUERY = 400; + private static final int GTE_OF_RANGE_IN_POST_FILTER_QUERY = 230; // @Before is a workaround to save extra update cluster settings call to the cluster. // @BeforeClass throws RuntimeException with initializationError @@ -72,10 +72,10 @@ public void setUp() throws Exception { public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", true); - prepareResourcesBeforeTestExecution(shards_count_in_single_node_cluster); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); - testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); } @@ -85,10 +85,10 @@ public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_the public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", false); - prepareResourcesBeforeTestExecution(shards_count_in_single_node_cluster); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); - testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); } @@ -98,10 +98,10 @@ public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_th public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", true); - prepareResourcesBeforeTestExecution(shards_count_in_multi_node_cluster); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); - testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); } @@ -111,10 +111,10 @@ public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_ public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchDisabled_thenSuccessful() { try { updateClusterSettings("search.concurrent_segment_search.enabled", false); - prepareResourcesBeforeTestExecution(shards_count_in_multi_node_cluster); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); - testPostFilterMatchAllAndNoneQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); } @@ -162,12 +162,12 @@ private void testPostFilterRangeQuery(String indexName) { HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( "mission", "part", - lte_of_range_in_hybrid_query, - gte_of_range_in_hybrid_query + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY ); QueryBuilder postFilterQuery = createQueryBuilderWithRangeQuery( - lte_of_range_in_post_filter_query, - gte_of_range_in_post_filter_query + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY ); Map searchResponseAsMap = search( @@ -179,7 +179,7 @@ private void testPostFilterRangeQuery(String indexName) { null, postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 1, 0, gte_of_range_in_post_filter_query, lte_of_range_in_hybrid_query); + assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } @SneakyThrows @@ -237,13 +237,13 @@ private void testPostFilterBoolQuery(String indexName) { HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( "mission", "part", - lte_of_range_in_hybrid_query, - gte_of_range_in_hybrid_query + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY ); QueryBuilder postFilterQuery = createQueryBuilderWithBoolShouldQuery( "impossible", - lte_of_range_in_post_filter_query, - gte_of_range_in_post_filter_query + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY ); Map searchResponseAsMap = search( @@ -255,7 +255,7 @@ private void testPostFilterBoolQuery(String indexName) { null, postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 2, 1, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); // Case 2 /*{ "query": { @@ -311,7 +311,7 @@ private void testPostFilterBoolQuery(String indexName) { } } }*/ - AggregationBuilder aggsBuilder = createAggregations(); + AggregationBuilder aggsBuilder = createAvgAggregation(); searchResponseAsMap = search( indexName, hybridQueryBuilder, @@ -321,7 +321,7 @@ private void testPostFilterBoolQuery(String indexName) { List.of(aggsBuilder), postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 2, 1, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); Map aggregations = getAggregations(searchResponseAsMap); assertNotNull(aggregations); @@ -379,8 +379,8 @@ private void testPostFilterBoolQuery(String indexName) { }*/ postFilterQuery = createQueryBuilderWithBoolMustQuery( "terminal", - lte_of_range_in_post_filter_query, - gte_of_range_in_post_filter_query + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY ); searchResponseAsMap = search( indexName, @@ -391,7 +391,7 @@ private void testPostFilterBoolQuery(String indexName) { null, postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 0, 0, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); // Case 4 /*{ "query": { @@ -438,8 +438,8 @@ private void testPostFilterBoolQuery(String indexName) { hybridQueryBuilder = createHybridQueryBuilderScenarioWithMatchAndRangeQuery("hero", 5000, 1000); postFilterQuery = createQueryBuilderWithBoolShouldQuery( "impossible", - lte_of_range_in_post_filter_query, - gte_of_range_in_post_filter_query + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY ); searchResponseAsMap = search( indexName, @@ -450,11 +450,11 @@ private void testPostFilterBoolQuery(String indexName) { null, postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 0, 0, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } @SneakyThrows - private void testPostFilterMatchAllAndNoneQuery(String indexName) { + private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { /*{ "query": { "hybrid": { @@ -489,8 +489,8 @@ private void testPostFilterMatchAllAndNoneQuery(String indexName) { HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( "mission", "part", - gte_of_range_in_hybrid_query, - lte_of_range_in_hybrid_query + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY ); QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(true); @@ -503,7 +503,7 @@ private void testPostFilterMatchAllAndNoneQuery(String indexName) { null, postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 4, 3, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + assertHybridQueryResults(searchResponseAsMap, 4, 3, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); /*{ "query": { @@ -546,7 +546,7 @@ private void testPostFilterMatchAllAndNoneQuery(String indexName) { null, postFilterQuery ); - assertHybridQueryResults(searchResponseAsMap, 0, 0, gte_of_range_in_post_filter_query, lte_of_range_in_post_filter_query); + assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } private void assertHybridQueryResults( @@ -726,15 +726,10 @@ private QueryBuilder createQueryBuilderWithBoolMustQuery(String query, int lte, } private QueryBuilder createPostFilterQueryBuilderWithMatchAllOrNoneQuery(boolean isMatchAll) { - if (isMatchAll) { - return QueryBuilders.matchAllQuery(); - } - - MatchNoneQueryBuilder matchNoneQueryBuilder = new MatchNoneQueryBuilder(); - return new MatchNoneQueryBuilder(); + return isMatchAll ? QueryBuilders.matchAllQuery() : new MatchNoneQueryBuilder(); } - private AggregationBuilder createAggregations() { + private AggregationBuilder createAvgAggregation() { return AggregationBuilders.avg(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_1); } From 8910cbfdb20881be4f19388c8b28f5e1734a1bb0 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 19 Apr 2024 23:54:48 -0700 Subject: [PATCH 6/9] Resolving conflicts Signed-off-by: Varun Jain --- .../query/aggregation/BaseAggregationsWithHybridQueryIT.java | 2 +- .../query/aggregation/BucketAggregationsWithHybridQueryIT.java | 2 +- .../query/aggregation/MetricAggregationsWithHybridQueryIT.java | 2 +- .../aggregation/PipelineAggregationsWithHybridQueryIT.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java index 59d649919..521606fda 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java @@ -15,7 +15,7 @@ import java.util.Set; import java.util.stream.IntStream; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java index 0e59e4d53..ce8854eed 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java @@ -26,7 +26,7 @@ import java.util.List; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java index 5e1b00aba..36c853984 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java @@ -18,7 +18,7 @@ import java.util.List; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java index a83954510..168dce1e0 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java @@ -23,7 +23,7 @@ import java.util.List; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; From 9106cb3748186eb43d61d46f46e422c325b6f9a1 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Sat, 20 Apr 2024 00:13:44 -0700 Subject: [PATCH 7/9] Renaming Variables Signed-off-by: Varun Jain --- .../query/HybridQueryPostFilterIT.java | 143 +++++++++--------- 1 file changed, 75 insertions(+), 68 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 0b7a90b25..5672d748e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -30,24 +30,24 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-hybrid-post-filter-multi-doc-index-single-shard"; private static final String SEARCH_PIPELINE = "phase-results-hybrid-post-filter-pipeline"; - private static final String INTEGER_FIELD_1 = "stock"; - private static final String TEXT_FIELD_1 = "name"; - private static final String KEYWORD_FIELD_2 = "category"; - private static final String TEXT_FIELD_VALUE_1 = "Dunes part 2"; - private static final String TEXT_FIELD_VALUE_2 = "Dunes part 1"; - private static final String TEXT_FIELD_VALUE_3 = "Mission Impossible 1"; - private static final String TEXT_FIELD_VALUE_4 = "Mission Impossible 2"; - private static final String TEXT_FIELD_VALUE_5 = "The Terminal"; - private static final String TEXT_FIELD_VALUE_6 = "Avengers"; - private static final int INTEGER_FIELD_STOCK_1_VALUE = 25; - private static final int INTEGER_FIELD_STOCK_2_VALUE = 22; - private static final int INTEGER_FIELD_STOCK_3_VALUE = 256; - private static final int INTEGER_FIELD_STOCK_4_VALUE = 25; - private static final int INTEGER_FIELD_STOCK_5_VALUE = 20; - private static final String KEYWORD_FIELD_CATEGORY_1_VALUE = "Drama"; - private static final String KEYWORD_FIELD_CATEGORY_2_VALUE = "Action"; - private static final String KEYWORD_FIELD_CATEGORY_3_VALUE = "Sci-fi"; - private static final String AVG_AGGREGATION_NAME = "avg_stock_size"; + private static final String INTEGER_FIELD_1_STOCK = "stock"; + private static final String TEXT_FIELD_1_NAME = "name"; + private static final String KEYWORD_FIELD_2_CATEGORY = "category"; + private static final String TEXT_FIELD_VALUE_1_DUNES = "Dunes part 2"; + private static final String TEXT_FIELD_VALUE_2_DUNES = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_3_MI_1 = "Mission Impossible 1"; + private static final String TEXT_FIELD_VALUE_4_MI_2 = "Mission Impossible 2"; + private static final String TEXT_FIELD_VALUE_5_TERMINAL = "The Terminal"; + private static final String TEXT_FIELD_VALUE_6_AVENGERS = "Avengers"; + private static final int INTEGER_FIELD_STOCK_1_25 = 25; + private static final int INTEGER_FIELD_STOCK_2_22 = 22; + private static final int INTEGER_FIELD_STOCK_3_256 = 256; + private static final int INTEGER_FIELD_STOCK_4_25 = 25; + private static final int INTEGER_FIELD_STOCK_5_20 = 20; + private static final String KEYWORD_FIELD_CATEGORY_1_DRAMA = "Drama"; + private static final String KEYWORD_FIELD_CATEGORY_2_ACTION = "Action"; + private static final String KEYWORD_FIELD_CATEGORY_3_SCI_FI = "Sci-fi"; + private static final String STOCK_AVG_AGGREGATION_NAME = "avg_stock_size"; private static boolean setUpIsDone = false; private static final int SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER = 1; private static final int SHARDS_COUNT_IN_MULTI_NODE_CLUSTER = 3; @@ -325,7 +325,7 @@ private void testPostFilterBoolQuery(String indexName) { Map aggregations = getAggregations(searchResponseAsMap); assertNotNull(aggregations); - Map aggValue = getAggregationValues(aggregations, AVG_AGGREGATION_NAME); + Map aggValue = getAggregationValues(aggregations, STOCK_AVG_AGGREGATION_NAME); assertEquals(1, aggValue.size()); // Case 3 /*{ @@ -563,7 +563,7 @@ private void assertHybridQueryResults( for (Map oneHit : hitsNestedList) { assertNotNull(oneHit.get("_source")); Map source = (Map) oneHit.get("_source"); - int docIndex = (int) source.get(INTEGER_FIELD_1); + int docIndex = (int) source.get(INTEGER_FIELD_1_STOCK); docIndexes.add(docIndex); } assertEquals(postFilterResultsValidationExpected, docIndexes.stream().filter(docIndex -> docIndex < lte || docIndex > gte).count()); @@ -584,7 +584,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { if (!indexExists(indexName)) { createIndexWithConfiguration( indexName, - buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_2), List.of(), numShards), + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(), + numShards + ), "" ); @@ -593,14 +600,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { "1", List.of(), List.of(), - Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_VALUE_1), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_1_DUNES), List.of(), List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_STOCK_1_VALUE), - List.of(KEYWORD_FIELD_2), - List.of(KEYWORD_FIELD_CATEGORY_1_VALUE), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_1_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), List.of(), List.of() ); @@ -610,14 +617,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { "2", List.of(), List.of(), - Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_VALUE_2), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_2_DUNES), List.of(), List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_STOCK_2_VALUE), - List.of(KEYWORD_FIELD_2), - List.of(KEYWORD_FIELD_CATEGORY_1_VALUE), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_2_22), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), List.of(), List.of() ); @@ -627,14 +634,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { "3", List.of(), List.of(), - Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_VALUE_3), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_3_MI_1), List.of(), List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_STOCK_3_VALUE), - List.of(KEYWORD_FIELD_2), - List.of(KEYWORD_FIELD_CATEGORY_2_VALUE), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_3_256), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), List.of(), List.of() ); @@ -644,14 +651,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { "4", List.of(), List.of(), - Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_VALUE_4), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_4_MI_2), List.of(), List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_STOCK_4_VALUE), - List.of(KEYWORD_FIELD_2), - List.of(KEYWORD_FIELD_CATEGORY_2_VALUE), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_4_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), List.of(), List.of() ); @@ -661,14 +668,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { "5", List.of(), List.of(), - Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_VALUE_5), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_5_TERMINAL), List.of(), List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_STOCK_5_VALUE), - List.of(KEYWORD_FIELD_2), - List.of(KEYWORD_FIELD_CATEGORY_1_VALUE), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), List.of(), List.of() ); @@ -678,14 +685,14 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { "6", List.of(), List.of(), - Collections.singletonList(TEXT_FIELD_1), - Collections.singletonList(TEXT_FIELD_VALUE_6), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_6_AVENGERS), List.of(), List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_STOCK_5_VALUE), - List.of(KEYWORD_FIELD_2), - List.of(KEYWORD_FIELD_CATEGORY_3_VALUE), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_3_SCI_FI), List.of(), List.of() ); @@ -693,35 +700,35 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { } private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) { - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1, text); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1, value); - RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(matchQueryBuilder).add(termQueryBuilder).add(rangeQueryBuilder); return hybridQueryBuilder; } private HybridQueryBuilder createHybridQueryBuilderScenarioWithMatchAndRangeQuery(String text, int lte, int gte) { - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1, text); - RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(matchQueryBuilder).add(rangeQueryBuilder); return hybridQueryBuilder; } private QueryBuilder createQueryBuilderWithRangeQuery(int lte, int gte) { - return QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); + return QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); } private QueryBuilder createQueryBuilderWithBoolShouldQuery(String query, int lte, int gte) { - QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); - QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1, query); + QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, query); return QueryBuilders.boolQuery().should(rangeQuery).should(matchQuery); } private QueryBuilder createQueryBuilderWithBoolMustQuery(String query, int lte, int gte) { - QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(gte).lte(lte); - QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1, query); + QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, query); return QueryBuilders.boolQuery().must(rangeQuery).must(matchQuery); } @@ -730,7 +737,7 @@ private QueryBuilder createPostFilterQueryBuilderWithMatchAllOrNoneQuery(boolean } private AggregationBuilder createAvgAggregation() { - return AggregationBuilders.avg(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_1); + return AggregationBuilders.avg(STOCK_AVG_AGGREGATION_NAME).field(INTEGER_FIELD_1_STOCK); } } From 3871f54f973f1688c90e0d5ec2e05890343ca1f9 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 22 Apr 2024 10:31:07 -0700 Subject: [PATCH 8/9] Addressing Martin Comments Signed-off-by: Varun Jain --- .../query/HybridQueryPostFilterIT.java | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 5672d748e..ecfed28c8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -9,7 +9,7 @@ import java.util.List; import java.util.Map; import lombok.SneakyThrows; -import org.junit.Before; +import org.junit.BeforeClass; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.MatchNoneQueryBuilder; @@ -33,8 +33,8 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final String INTEGER_FIELD_1_STOCK = "stock"; private static final String TEXT_FIELD_1_NAME = "name"; private static final String KEYWORD_FIELD_2_CATEGORY = "category"; - private static final String TEXT_FIELD_VALUE_1_DUNES = "Dunes part 2"; - private static final String TEXT_FIELD_VALUE_2_DUNES = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_1_DUNES = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_2_DUNES = "Dunes part 2"; private static final String TEXT_FIELD_VALUE_3_MI_1 = "Mission Impossible 1"; private static final String TEXT_FIELD_VALUE_4_MI_2 = "Mission Impossible 2"; private static final String TEXT_FIELD_VALUE_5_TERMINAL = "The Terminal"; @@ -56,16 +56,14 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final int LTE_OF_RANGE_IN_POST_FILTER_QUERY = 400; private static final int GTE_OF_RANGE_IN_POST_FILTER_QUERY = 230; - // @Before is a workaround to save extra update cluster settings call to the cluster. - // @BeforeClass throws RuntimeException with initializationError - @Before - public void setUp() throws Exception { - super.setUp(); - if (setUpIsDone) { - return; - } - updateClusterSettings(); - setUpIsDone = true; + @BeforeClass + @SneakyThrows + public static void setUpCluster() { + // we need new instance because we're calling non-static methods from static method. + // main purpose is to minimize network calls, initialization is only needed once + HybridQueryPostFilterIT instance = new HybridQueryPostFilterIT(); + instance.initClient(); + instance.updateClusterSettings(); } @SneakyThrows @@ -601,7 +599,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1_NAME), - Collections.singletonList(TEXT_FIELD_VALUE_1_DUNES), + Collections.singletonList(TEXT_FIELD_VALUE_2_DUNES), List.of(), List.of(), List.of(INTEGER_FIELD_1_STOCK), @@ -618,7 +616,7 @@ private void initializeIndexIfNotExists(String indexName, int numShards) { List.of(), List.of(), Collections.singletonList(TEXT_FIELD_1_NAME), - Collections.singletonList(TEXT_FIELD_VALUE_2_DUNES), + Collections.singletonList(TEXT_FIELD_VALUE_1_DUNES), List.of(), List.of(), List.of(INTEGER_FIELD_1_STOCK), From 64e9d4601c4bc379cd67e9090e72200b9de89f38 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 22 Apr 2024 11:11:35 -0700 Subject: [PATCH 9/9] Addressing Martin Comments Signed-off-by: Varun Jain --- .../query/HybridQueryPostFilterIT.java | 222 +++--------------- 1 file changed, 29 insertions(+), 193 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index ecfed28c8..7d33d07fe 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -48,7 +48,6 @@ public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { private static final String KEYWORD_FIELD_CATEGORY_2_ACTION = "Action"; private static final String KEYWORD_FIELD_CATEGORY_3_SCI_FI = "Sci-fi"; private static final String STOCK_AVG_AGGREGATION_NAME = "avg_stock_size"; - private static boolean setUpIsDone = false; private static final int SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER = 1; private static final int SHARDS_COUNT_IN_MULTI_NODE_CLUSTER = 3; private static final int LTE_OF_RANGE_IN_HYBRID_QUERY = 400; @@ -118,9 +117,7 @@ public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchDisabled } } - @SneakyThrows - private void testPostFilterRangeQuery(String indexName) { - /*{ + /*{ "query": { "hybrid":{ "queries":[ @@ -156,7 +153,9 @@ private void testPostFilterRangeQuery(String indexName) { } } } - }*/ + }*/ + @SneakyThrows + private void testPostFilterRangeQuery(String indexName) { HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( "mission", "part", @@ -180,10 +179,7 @@ private void testPostFilterRangeQuery(String indexName) { assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } - @SneakyThrows - private void testPostFilterBoolQuery(String indexName) { - // Case 1 - /*{ + /*{ "query": { "hybrid":{ "queries":[ @@ -210,6 +206,11 @@ private void testPostFilterBoolQuery(String indexName) { ] } + }, + "aggs": { + "avg_stock_size": { + "avg": { "field": "stock" } + } }, "post_filter":{ "bool":{ @@ -231,7 +232,11 @@ private void testPostFilterBoolQuery(String indexName) { ] } } - }*/ + }*/ + @SneakyThrows + private void testPostFilterBoolQuery(String indexName) { + // Case 1 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Range and a + // Match Query). HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( "mission", "part", @@ -254,61 +259,8 @@ private void testPostFilterBoolQuery(String indexName) { postFilterQuery ); assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); - // Case 2 - /*{ - "query": { - "hybrid":{ - "queries":[ - { - "match":{ - "name": "mission" - } - }, - { - "term":{ - "name":{ - "value":"part" - } - } - }, - { - "range": { - "stock": { - "gte": 200, - "lte": 400 - } - } - } - ] - } - - }, - "aggs": { - "avg_stock_size": { - "avg": { "field": "stock" } - } - }, - "post_filter":{ - "bool":{ - "should":[ - { - "range": { - "stock": { - "gte": 230, - "lte": 400 - } - } - }, - { - "match":{ - "name":"impossible" - } - } - - ] - } - } - }*/ + // Case 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query), aggregation (Average stock price + // `avg_stock_price`) and a post filter query (Range Query and a Match Query). AggregationBuilder aggsBuilder = createAvgAggregation(); searchResponseAsMap = search( indexName, @@ -325,56 +277,8 @@ private void testPostFilterBoolQuery(String indexName) { Map aggValue = getAggregationValues(aggregations, STOCK_AVG_AGGREGATION_NAME); assertEquals(1, aggValue.size()); - // Case 3 - /*{ - "query": { - "hybrid":{ - "queries":[ - { - "match":{ - "name": "mission" - } - }, - { - "term":{ - "name":{ - "value":"part" - } - } - }, - { - "range": { - "stock": { - "gte": 200, - "lte": 400 - } - } - } - ] - } - - }, - "post_filter":{ - "bool":{ - "must":[ - { - "range": { - "stock": { - "gte": 230, - "lte": 400 - } - } - }, - { - "match":{ - "name":"terminal" - } - } - - ] - } - } - }*/ + // Case 3 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Bool Query with + // a must clause(Range Query and a Match Query)). postFilterQuery = createQueryBuilderWithBoolMustQuery( "terminal", LTE_OF_RANGE_IN_POST_FILTER_QUERY, @@ -390,49 +294,8 @@ private void testPostFilterBoolQuery(String indexName) { postFilterQuery ); assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); - // Case 4 - /*{ - "query": { - "hybrid":{ - "queries":[ - { - "match":{ - "name": "hero" - } - }, - { - "range": { - "stock": { - "gte": 1000, - "lte": 5000 - } - } - } - ] - } - - }, - "post_filter":{ - "bool":{ - "should":[ - { - "range": { - "stock": { - "gte": 230, - "lte": 400 - } - } - }, - { - "match":{ - "name":"impossible" - } - } - - ] - } - } - }*/ + // Case 4 A Query with a combination of hybrid query (Match Query, Range Query) and a post filter query (Bool Query with a should + // clause(Range Query and a Match Query)). hybridQueryBuilder = createHybridQueryBuilderScenarioWithMatchAndRangeQuery("hero", 5000, 1000); postFilterQuery = createQueryBuilderWithBoolShouldQuery( "impossible", @@ -451,9 +314,7 @@ private void testPostFilterBoolQuery(String indexName) { assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); } - @SneakyThrows - private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { - /*{ + /*{ "query": { "hybrid": { "queries": [ @@ -483,7 +344,11 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { "post_filter": { "match_all": {} } - }*/ + }*/ + @SneakyThrows + private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { + // CASE 1 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Match ALL + // Query). HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( "mission", "part", @@ -503,37 +368,8 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { ); assertHybridQueryResults(searchResponseAsMap, 4, 3, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); - /*{ - "query": { - "hybrid": { - "queries": [ - { - "match": { - "name": "mission" - } - }, - { - "term": { - "name": { - "value": "part" - } - } - }, - { - "range": { - "stock": { - "gte": 200, - "lte": 400 - } - } - } - ] - } - }, - "post_filter": { - "match_none": {} - } - }*/ + // CASE 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Match NONE + // Query). postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(false); searchResponseAsMap = search( indexName,