|
80 | 80 | import org.opensearch.common.settings.Settings;
|
81 | 81 | import org.opensearch.common.util.concurrent.ThreadContext;
|
82 | 82 | import org.opensearch.core.action.ActionListener;
|
| 83 | +import org.opensearch.core.common.breaker.CircuitBreaker; |
| 84 | +import org.opensearch.core.common.breaker.CircuitBreakingException; |
83 | 85 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
84 | 86 | import org.opensearch.ml.breaker.MLCircuitBreakerService;
|
85 |
| -import org.opensearch.ml.breaker.MemoryCircuitBreaker; |
86 | 87 | import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
|
87 | 88 | import org.opensearch.ml.cluster.DiscoveryNodeHelper;
|
88 | 89 | import org.opensearch.ml.common.FunctionName;
|
|
113 | 114 | import org.opensearch.ml.stats.MLStats;
|
114 | 115 | import org.opensearch.ml.stats.suppliers.CounterSupplier;
|
115 | 116 | import org.opensearch.ml.task.MLTaskManager;
|
116 |
| -import org.opensearch.monitor.jvm.JvmService; |
117 | 117 | import org.opensearch.script.ScriptService;
|
118 | 118 | import org.opensearch.test.OpenSearchTestCase;
|
119 | 119 | import org.opensearch.threadpool.ThreadPool;
|
@@ -322,7 +322,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
|
322 | 322 | when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
|
323 | 323 | when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
|
324 | 324 | when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
|
325 |
| - expectedEx.expect(MLException.class); |
| 325 | + expectedEx.expect(CircuitBreakingException.class); |
326 | 326 | expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
|
327 | 327 | modelManager.registerMLModel(registerModelInput, mlTask);
|
328 | 328 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
|
@@ -451,21 +451,32 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
|
451 | 451 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
|
452 | 452 | }
|
453 | 453 |
|
454 |
| - public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() { |
| 454 | + public void testRegisterMLRemoteModel_SkipMemoryCBOpen() { |
455 | 455 | ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
|
456 |
| - MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class)); |
457 |
| - String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!"; |
458 |
| - when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage)); |
| 456 | + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
| 457 | + when(mlCircuitBreakerService.checkOpenCB()) |
| 458 | + .thenThrow( |
| 459 | + new CircuitBreakingException( |
| 460 | + "Memory Circuit Breaker is open, please check your resources!", |
| 461 | + CircuitBreaker.Durability.TRANSIENT |
| 462 | + ) |
| 463 | + ); |
| 464 | + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); |
| 465 | + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); |
459 | 466 |
|
460 | 467 | MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
|
461 | 468 | MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
|
| 469 | + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); |
| 470 | + doAnswer(invocation -> { |
| 471 | + ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1]; |
| 472 | + indexResponseActionListener.onResponse(indexResponse); |
| 473 | + return null; |
| 474 | + }).when(client).index(any(), any()); |
| 475 | + when(indexResponse.getId()).thenReturn("mockIndexId"); |
462 | 476 | modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
|
463 | 477 |
|
464 |
| - ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
465 |
| - verify(listener, times(1)).onFailure(argCaptor.capture()); |
466 |
| - Exception e = argCaptor.getValue(); |
467 |
| - assertTrue(e instanceof MLLimitExceededException); |
468 |
| - assertEquals(memCBIsOpenMessage, e.getMessage()); |
| 478 | + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); |
| 479 | + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); |
469 | 480 | }
|
470 | 481 |
|
471 | 482 | public void testIndexRemoteModel() throws PrivilegedActionException {
|
|
0 commit comments