|
80 | 80 | import org.opensearch.common.settings.Settings;
|
81 | 81 | import org.opensearch.common.util.concurrent.ThreadContext;
|
82 | 82 | import org.opensearch.core.action.ActionListener;
|
| 83 | +import org.opensearch.core.common.breaker.CircuitBreaker; |
| 84 | +import org.opensearch.core.common.breaker.CircuitBreakingException; |
83 | 85 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
84 | 86 | import org.opensearch.ml.breaker.MLCircuitBreakerService;
|
85 | 87 | import org.opensearch.ml.breaker.MemoryCircuitBreaker;
|
@@ -322,7 +324,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
|
322 | 324 | when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
|
323 | 325 | when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
|
324 | 326 | when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
|
325 |
| - expectedEx.expect(MLException.class); |
| 327 | + expectedEx.expect(CircuitBreakingException.class); |
326 | 328 | expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
|
327 | 329 | modelManager.registerMLModel(registerModelInput, mlTask);
|
328 | 330 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
|
@@ -451,21 +453,32 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
|
451 | 453 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
|
452 | 454 | }
|
453 | 455 |
|
454 |
| - public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() { |
| 456 | + public void testRegisterMLRemoteModel_SkipMemoryCBOpen() { |
455 | 457 | ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
|
456 |
| - MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class)); |
457 |
| - String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!"; |
458 |
| - when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage)); |
| 458 | + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
| 459 | + when(mlCircuitBreakerService.checkOpenCB()) |
| 460 | + .thenThrow( |
| 461 | + new CircuitBreakingException( |
| 462 | + "Memory Circuit Breaker is open, please check your resources!", |
| 463 | + CircuitBreaker.Durability.TRANSIENT |
| 464 | + ) |
| 465 | + ); |
| 466 | + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); |
| 467 | + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); |
459 | 468 |
|
460 | 469 | MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
|
461 | 470 | MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
|
| 471 | + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); |
| 472 | + doAnswer(invocation -> { |
| 473 | + ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1]; |
| 474 | + indexResponseActionListener.onResponse(indexResponse); |
| 475 | + return null; |
| 476 | + }).when(client).index(any(), any()); |
| 477 | + when(indexResponse.getId()).thenReturn("mockIndexId"); |
462 | 478 | modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
|
463 | 479 |
|
464 |
| - ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
465 |
| - verify(listener, times(1)).onFailure(argCaptor.capture()); |
466 |
| - Exception e = argCaptor.getValue(); |
467 |
| - assertTrue(e instanceof MLLimitExceededException); |
468 |
| - assertEquals(memCBIsOpenMessage, e.getMessage()); |
| 480 | + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); |
| 481 | + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); |
469 | 482 | }
|
470 | 483 |
|
471 | 484 | public void testIndexRemoteModel() throws PrivilegedActionException {
|
|
0 commit comments