|
8 | 8 | import java.nio.file.Path;
|
9 | 9 | import java.util.Locale;
|
10 | 10 | import java.util.Optional;
|
| 11 | +import java.util.Set; |
| 12 | + |
11 | 13 | import org.junit.Before;
|
12 | 14 | import org.opensearch.common.settings.Settings;
|
| 15 | +import org.opensearch.ml.common.model.MLModelState; |
13 | 16 | import org.opensearch.neuralsearch.BaseNeuralSearchIT;
|
14 | 17 | import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX;
|
15 | 18 | import static org.opensearch.neuralsearch.util.TestUtils.OLD_CLUSTER;
|
|
23 | 26 |
|
24 | 27 | public abstract class AbstractRollingUpgradeTestCase extends BaseNeuralSearchIT {
|
25 | 28 |
|
| 29 | + private static final Set<MLModelState> READY_FOR_INFERENCE_STATES = Set.of(MLModelState.LOADED, MLModelState.DEPLOYED); |
| 30 | + |
26 | 31 | @Before
|
27 | 32 | protected String getIndexNameForTest() {
|
28 | 33 | // Creating index name by concatenating "neural-bwc-" prefix with test method name
|
@@ -159,4 +164,24 @@ protected void createPipelineForTextChunkingProcessor(String pipelineName) throw
|
159 | 164 | );
|
160 | 165 | createPipelineProcessor(requestBody, pipelineName, "", null);
|
161 | 166 | }
|
| 167 | + |
| 168 | + protected boolean isModelReadyForInference(final MLModelState mlModelState) throws Exception { |
| 169 | + return READY_FOR_INFERENCE_STATES.contains(mlModelState); |
| 170 | + } |
| 171 | + |
| 172 | + protected void waitForModelToLoad(String modelId) throws Exception { |
| 173 | + int maxAttempts = 30; // Maximum number of attempts |
| 174 | + int waitTimeInSeconds = 2; // Time to wait between attempts |
| 175 | + |
| 176 | + for (int attempt = 0; attempt < maxAttempts; attempt++) { |
| 177 | + MLModelState state = getModelState(modelId); |
| 178 | + if (isModelReadyForInference(state)) { |
| 179 | + logger.info("Model {} is now loaded after {} attempts", modelId, attempt + 1); |
| 180 | + return; |
| 181 | + } |
| 182 | + logger.info("Waiting for model {} to load. Current state: {}. Attempt {}/{}", modelId, state, attempt + 1, maxAttempts); |
| 183 | + Thread.sleep(waitTimeInSeconds * 1000); |
| 184 | + } |
| 185 | + throw new RuntimeException("Model " + modelId + " failed to load after " + maxAttempts + " attempts"); |
| 186 | + } |
162 | 187 | }
|
0 commit comments