76
76
import org .opensearch .action .update .UpdateRequest ;
77
77
import org .opensearch .action .update .UpdateResponse ;
78
78
import org .opensearch .client .Client ;
79
+ import org .opensearch .cluster .service .ClusterApplierService ;
79
80
import org .opensearch .cluster .service .ClusterService ;
80
81
import org .opensearch .common .settings .ClusterSettings ;
81
82
import org .opensearch .common .settings .Settings ;
82
83
import org .opensearch .common .util .concurrent .ThreadContext ;
83
84
import org .opensearch .core .action .ActionListener ;
85
+ import org .opensearch .core .common .breaker .CircuitBreaker ;
86
+ import org .opensearch .core .common .breaker .CircuitBreakingException ;
84
87
import org .opensearch .core .xcontent .NamedXContentRegistry ;
85
88
import org .opensearch .ml .breaker .MLCircuitBreakerService ;
86
- import org .opensearch .ml .breaker .MemoryCircuitBreaker ;
87
89
import org .opensearch .ml .breaker .ThresholdCircuitBreaker ;
88
90
import org .opensearch .ml .cluster .DiscoveryNodeHelper ;
89
91
import org .opensearch .ml .common .FunctionName ;
114
116
import org .opensearch .ml .stats .MLStats ;
115
117
import org .opensearch .ml .stats .suppliers .CounterSupplier ;
116
118
import org .opensearch .ml .task .MLTaskManager ;
117
- import org .opensearch .monitor .jvm .JvmService ;
118
119
import org .opensearch .script .ScriptService ;
119
120
import org .opensearch .test .OpenSearchTestCase ;
120
121
import org .opensearch .threadpool .ThreadPool ;
@@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {
177
178
private ScriptService scriptService ;
178
179
179
180
@ Mock
180
- private MLTask pretrainedMLTask ;
181
+ ClusterApplierService clusterApplierService ;
181
182
182
183
@ Before
183
184
public void setup () throws URISyntaxException {
@@ -196,7 +197,7 @@ public void setup() throws URISyntaxException {
196
197
ML_COMMONS_MONITORING_REQUEST_COUNT ,
197
198
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE
198
199
);
199
- clusterService = spy (new ClusterService (settings , clusterSettings , null ));
200
+ clusterService = spy (new ClusterService (settings , clusterSettings , null , clusterApplierService ));
200
201
xContentRegistry = NamedXContentRegistry .EMPTY ;
201
202
202
203
modelName = "model_name1" ;
@@ -323,7 +324,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
323
324
when (mlCircuitBreakerService .checkOpenCB ()).thenReturn (thresholdCircuitBreaker );
324
325
when (thresholdCircuitBreaker .getName ()).thenReturn ("Disk Circuit Breaker" );
325
326
when (thresholdCircuitBreaker .getThreshold ()).thenReturn (87 );
326
- expectedEx .expect (MLException .class );
327
+ expectedEx .expect (CircuitBreakingException .class );
327
328
expectedEx .expectMessage ("Disk Circuit Breaker is open, please check your resources!" );
328
329
modelManager .registerMLModel (registerModelInput , mlTask );
329
330
verify (mlTaskManager ).updateMLTask (anyString (), anyMap (), anyLong (), anyBoolean ());
@@ -452,21 +453,30 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
452
453
verify (mlTaskManager ).updateMLTask (anyString (), anyMap (), anyLong (), anyBoolean ());
453
454
}
454
455
455
- public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail () {
456
+ public void testRegisterMLRemoteModel_SkipMemoryCBOpen () {
456
457
ActionListener <MLRegisterModelResponse > listener = mock (ActionListener .class );
457
- MemoryCircuitBreaker memCB = new MemoryCircuitBreaker (mock (JvmService .class ));
458
- String memCBIsOpenMessage = memCB .getName () + " is open, please check your resources!" ;
459
- when (mlCircuitBreakerService .checkOpenCB ()).thenThrow (new MLLimitExceededException (memCBIsOpenMessage ));
460
-
458
+ doNothing ().when (mlTaskManager ).checkLimitAndAddRunningTask (any (), any ());
459
+ when (mlCircuitBreakerService .checkOpenCB ())
460
+ .thenThrow (
461
+ new CircuitBreakingException (
462
+ "Memory Circuit Breaker is open, please check your resources!" ,
463
+ CircuitBreaker .Durability .TRANSIENT
464
+ )
465
+ );
466
+ when (threadPool .executor (REGISTER_THREAD_POOL )).thenReturn (taskExecutorService );
467
+ when (modelHelper .isModelAllowed (any (), any ())).thenReturn (true );
461
468
MLRegisterModelInput pretrainedInput = mockRemoteModelInput (true );
462
469
MLTask pretrainedTask = MLTask .builder ().taskId ("pretrained" ).modelId ("pretrained" ).functionName (FunctionName .REMOTE ).build ();
470
+ mock_MLIndicesHandler_initModelIndex (mlIndicesHandler , true );
471
+ doAnswer (invocation -> {
472
+ ActionListener <IndexResponse > indexResponseActionListener = (ActionListener <IndexResponse >) invocation .getArguments ()[1 ];
473
+ indexResponseActionListener .onResponse (indexResponse );
474
+ return null ;
475
+ }).when (client ).index (any (), any ());
476
+ when (indexResponse .getId ()).thenReturn ("mockIndexId" );
463
477
modelManager .registerMLRemoteModel (pretrainedInput , pretrainedTask , listener );
464
-
465
- ArgumentCaptor <Exception > argCaptor = ArgumentCaptor .forClass (Exception .class );
466
- verify (listener , times (1 )).onFailure (argCaptor .capture ());
467
- Exception e = argCaptor .getValue ();
468
- assertTrue (e instanceof MLLimitExceededException );
469
- assertEquals (memCBIsOpenMessage , e .getMessage ());
478
+ assertEquals (pretrainedTask .getFunctionName (), FunctionName .REMOTE );
479
+ verify (mlTaskManager ).updateMLTask (anyString (), anyMap (), anyLong (), anyBoolean ());
470
480
}
471
481
472
482
public void testIndexRemoteModel () throws PrivilegedActionException {
0 commit comments