5
5
package org .opensearch .neuralsearch .bwc .rolling ;
6
6
7
7
import org .opensearch .neuralsearch .util .TestUtils ;
8
+ import org .opensearch .ml .common .model .MLModelState ;
8
9
9
10
import java .nio .file .Files ;
10
11
import java .nio .file .Path ;
@@ -28,29 +29,59 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
28
29
case OLD :
29
30
sparseModelId = uploadSparseEncodingModel ();
30
31
loadModel (sparseModelId );
32
+ MLModelState oldModelState = getModelState (sparseModelId );
33
+ logger .info ("Model state in OLD phase: {}" , oldModelState );
34
+ if (oldModelState != MLModelState .LOADED ) {
35
+ logger .error ("Model {} is not in LOADED state in OLD phase. Current state: {}" , sparseModelId , oldModelState );
36
+ waitForModelToLoad (sparseModelId );
37
+ }
31
38
createPipelineForSparseEncodingProcessor (sparseModelId , SPARSE_PIPELINE , 2 );
39
+ logger .info ("Pipeline state in OLD phase: {}" , getIngestionPipeline (SPARSE_PIPELINE ));
32
40
createIndexWithConfiguration (
33
41
indexName ,
34
42
Files .readString (Path .of (classLoader .getResource ("processor/SparseIndexMappings.json" ).toURI ())),
35
43
SPARSE_PIPELINE
36
44
);
37
45
List <Map <String , String >> docs = prepareDataForBulkIngestion (0 , 5 );
38
46
bulkAddDocuments (indexName , TEXT_FIELD_NAME , SPARSE_PIPELINE , docs );
47
+ logger .info ("Document count after OLD phase ingestion: {}" , getDocCount (indexName ));
39
48
validateDocCountAndInfo (indexName , 5 , () -> getDocById (indexName , "4" ), EMBEDDING_FIELD_NAME , Map .class );
40
49
break ;
41
50
case MIXED :
42
51
sparseModelId = TestUtils .getModelId (getIngestionPipeline (SPARSE_PIPELINE ), SPARSE_ENCODING_PROCESSOR );
43
52
loadModel (sparseModelId );
53
+ MLModelState mixedModelState = getModelState (sparseModelId );
54
+ logger .info ("Model state in MIXED phase: {}" , mixedModelState );
55
+ if (mixedModelState != MLModelState .LOADED ) {
56
+ logger .error ("Model {} is not in LOADED state in MIXED phase. Current state: {}" , sparseModelId , mixedModelState );
57
+ waitForModelToLoad (sparseModelId );
58
+ }
59
+ logger .info ("Pipeline state in MIXED phase: {}" , getIngestionPipeline (SPARSE_PIPELINE ));
44
60
List <Map <String , String >> docsForMixed = prepareDataForBulkIngestion (5 , 5 );
61
+ logger .info ("Document count before MIXED phase ingestion: {}" , getDocCount (indexName ));
45
62
bulkAddDocuments (indexName , TEXT_FIELD_NAME , SPARSE_PIPELINE , docsForMixed );
63
+ logger .info ("Document count after MIXED phase ingestion: {}" , getDocCount (indexName ));
46
64
validateDocCountAndInfo (indexName , 10 , () -> getDocById (indexName , "9" ), EMBEDDING_FIELD_NAME , Map .class );
47
65
break ;
48
66
case UPGRADED :
49
67
try {
50
68
sparseModelId = TestUtils .getModelId (getIngestionPipeline (SPARSE_PIPELINE ), SPARSE_ENCODING_PROCESSOR );
51
69
loadModel (sparseModelId );
70
+ MLModelState upgradedModelState = getModelState (sparseModelId );
71
+ logger .info ("Model state in UPGRADED phase: {}" , upgradedModelState );
72
+ if (upgradedModelState != MLModelState .LOADED ) {
73
+ logger .error (
74
+ "Model {} is not in LOADED state in UPGRADED phase. Current state: {}" ,
75
+ sparseModelId ,
76
+ upgradedModelState
77
+ );
78
+ waitForModelToLoad (sparseModelId );
79
+ }
80
+ logger .info ("Pipeline state in UPGRADED phase: {}" , getIngestionPipeline (SPARSE_PIPELINE ));
52
81
List <Map <String , String >> docsForUpgraded = prepareDataForBulkIngestion (10 , 5 );
82
+ logger .info ("Document count before UPGRADED phase ingestion: {}" , getDocCount (indexName ));
53
83
bulkAddDocuments (indexName , TEXT_FIELD_NAME , SPARSE_PIPELINE , docsForUpgraded );
84
+ logger .info ("Document count after UPGRADED phase ingestion: {}" , getDocCount (indexName ));
54
85
validateDocCountAndInfo (indexName , 15 , () -> getDocById (indexName , "14" ), EMBEDDING_FIELD_NAME , Map .class );
55
86
} finally {
56
87
wipeOfTestResources (indexName , SPARSE_PIPELINE , sparseModelId , null );
@@ -60,4 +91,20 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
60
91
throw new IllegalStateException ("Unexpected value: " + getClusterType ());
61
92
}
62
93
}
94
+
95
+ private void waitForModelToLoad (String modelId ) throws Exception {
96
+ int maxAttempts = 30 ; // Maximum number of attempts
97
+ int waitTimeInSeconds = 2 ; // Time to wait between attempts
98
+
99
+ for (int attempt = 0 ; attempt < maxAttempts ; attempt ++) {
100
+ MLModelState state = getModelState (modelId );
101
+ if (state == MLModelState .LOADED ) {
102
+ logger .info ("Model {} is now loaded after {} attempts" , modelId , attempt + 1 );
103
+ return ;
104
+ }
105
+ logger .info ("Waiting for model {} to load. Current state: {}. Attempt {}/{}" , modelId , state , attempt + 1 , maxAttempts );
106
+ Thread .sleep (waitTimeInSeconds * 1000 );
107
+ }
108
+ throw new RuntimeException ("Model " + modelId + " failed to load after " + maxAttempts + " attempts" );
109
+ }
63
110
}
0 commit comments