|
82 | 82 | import org.opensearch.common.settings.Settings;
|
83 | 83 | import org.opensearch.common.util.concurrent.ThreadContext;
|
84 | 84 | import org.opensearch.core.action.ActionListener;
|
| 85 | +import org.opensearch.core.common.breaker.CircuitBreaker; |
| 86 | +import org.opensearch.core.common.breaker.CircuitBreakingException; |
85 | 87 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
86 | 88 | import org.opensearch.ml.breaker.MLCircuitBreakerService;
|
87 |
| -import org.opensearch.ml.breaker.MemoryCircuitBreaker; |
88 | 89 | import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
|
89 | 90 | import org.opensearch.ml.cluster.DiscoveryNodeHelper;
|
90 | 91 | import org.opensearch.ml.common.FunctionName;
|
|
115 | 116 | import org.opensearch.ml.stats.MLStats;
|
116 | 117 | import org.opensearch.ml.stats.suppliers.CounterSupplier;
|
117 | 118 | import org.opensearch.ml.task.MLTaskManager;
|
118 |
| -import org.opensearch.monitor.jvm.JvmService; |
119 | 119 | import org.opensearch.script.ScriptService;
|
120 | 120 | import org.opensearch.test.OpenSearchTestCase;
|
121 | 121 | import org.opensearch.threadpool.ThreadPool;
|
@@ -324,7 +324,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
|
324 | 324 | when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
|
325 | 325 | when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
|
326 | 326 | when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
|
327 |
| - expectedEx.expect(MLException.class); |
| 327 | + expectedEx.expect(CircuitBreakingException.class); |
328 | 328 | expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
|
329 | 329 | modelManager.registerMLModel(registerModelInput, mlTask);
|
330 | 330 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
|
@@ -453,21 +453,30 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
|
453 | 453 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
|
454 | 454 | }
|
455 | 455 |
|
456 |
| - public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() { |
| 456 | + public void testRegisterMLRemoteModel_SkipMemoryCBOpen() { |
457 | 457 | ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
|
458 |
| - MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class)); |
459 |
| - String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!"; |
460 |
| - when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage)); |
461 |
| - |
| 458 | + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
| 459 | + when(mlCircuitBreakerService.checkOpenCB()) |
| 460 | + .thenThrow( |
| 461 | + new CircuitBreakingException( |
| 462 | + "Memory Circuit Breaker is open, please check your resources!", |
| 463 | + CircuitBreaker.Durability.TRANSIENT |
| 464 | + ) |
| 465 | + ); |
| 466 | + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); |
| 467 | + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); |
462 | 468 | MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
|
463 | 469 | MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
|
| 470 | + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); |
| 471 | + doAnswer(invocation -> { |
| 472 | + ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1]; |
| 473 | + indexResponseActionListener.onResponse(indexResponse); |
| 474 | + return null; |
| 475 | + }).when(client).index(any(), any()); |
| 476 | + when(indexResponse.getId()).thenReturn("mockIndexId"); |
464 | 477 | modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
|
465 |
| - |
466 |
| - ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
467 |
| - verify(listener, times(1)).onFailure(argCaptor.capture()); |
468 |
| - Exception e = argCaptor.getValue(); |
469 |
| - assertTrue(e instanceof MLLimitExceededException); |
470 |
| - assertEquals(memCBIsOpenMessage, e.getMessage()); |
| 478 | + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); |
| 479 | + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); |
471 | 480 | }
|
472 | 481 |
|
473 | 482 | public void testIndexRemoteModel() throws PrivilegedActionException {
|
|
0 commit comments