Skip to content

Commit 1c8cdfe

Browse files
committed
[Backport 2.19] exclude remote models in circuit breaker checks and fix memory CB bugs (opensearch-project#2713)
* exclude remote models in circuit breaker checks and fix memory CB bugs Signed-off-by: Xun Zhang <xunzh@amazon.com> * use static max heap threshold 100 Signed-off-by: Xun Zhang <xunzh@amazon.com> * fix issues after backport in 2.11 Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 5ecd0a0 commit 1c8cdfe

File tree

12 files changed

+102
-32
lines changed

12 files changed

+102
-32
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public void test_validateIp_validIp_noException() throws UnknownHostException {
3636
@Test
3737
public void test_validateIp_invalidIp_throwException() throws UnknownHostException {
3838
expectedException.expect(UnknownHostException.class);
39-
MLHttpClientFactory.validateIp("www.zaniu.com");
39+
MLHttpClientFactory.validateIp("www.zanniu.com");
4040
}
4141

4242
@Test

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+25-11
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55

66
package org.opensearch.ml.action.prediction;
77

8+
import org.opensearch.OpenSearchStatusException;
89
import org.opensearch.action.ActionListener;
910
import org.opensearch.action.ActionRequest;
1011
import org.opensearch.action.support.ActionFilters;
1112
import org.opensearch.action.support.HandledTransportAction;
1213
import org.opensearch.client.Client;
13-
import org.opensearch.cluster.service.ClusterService;
14+
import org.opensearch.common.breaker.CircuitBreakingException;
1415
import org.opensearch.common.inject.Inject;
1516
import org.opensearch.common.util.concurrent.ThreadContext;
1617
import org.opensearch.commons.authuser.User;
17-
import org.opensearch.core.xcontent.NamedXContentRegistry;
1818
import org.opensearch.ml.common.FunctionName;
1919
import org.opensearch.ml.common.MLModel;
20+
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
2021
import org.opensearch.ml.common.exception.MLValidationException;
2122
import org.opensearch.ml.common.transport.MLTaskResponse;
2223
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -27,6 +28,7 @@
2728
import org.opensearch.ml.task.MLPredictTaskRunner;
2829
import org.opensearch.ml.task.MLTaskRunner;
2930
import org.opensearch.ml.utils.RestActionUtils;
31+
import org.opensearch.rest.RestStatus;
3032
import org.opensearch.tasks.Task;
3133
import org.opensearch.transport.TransportService;
3234

@@ -43,10 +45,6 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action
4345

4446
Client client;
4547

46-
ClusterService clusterService;
47-
48-
NamedXContentRegistry xContentRegistry;
49-
5048
MLModelManager mlModelManager;
5149

5250
ModelAccessControlHelper modelAccessControlHelper;
@@ -57,19 +55,15 @@ public TransportPredictionTaskAction(
5755
ActionFilters actionFilters,
5856
MLPredictTaskRunner mlPredictTaskRunner,
5957
MLModelCacheHelper modelCacheHelper,
60-
ClusterService clusterService,
6158
Client client,
62-
NamedXContentRegistry xContentRegistry,
6359
MLModelManager mlModelManager,
6460
ModelAccessControlHelper modelAccessControlHelper
6561
) {
6662
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
6763
this.mlPredictTaskRunner = mlPredictTaskRunner;
6864
this.transportService = transportService;
6965
this.modelCacheHelper = modelCacheHelper;
70-
this.clusterService = clusterService;
7166
this.client = client;
72-
this.xContentRegistry = xContentRegistry;
7367
this.mlModelManager = mlModelManager;
7468
this.modelAccessControlHelper = modelAccessControlHelper;
7569
}
@@ -108,7 +102,27 @@ public void onResponse(MLModel mlModel) {
108102
}
109103
}, e -> {
110104
log.error("Failed to Validate Access for ModelId " + modelId, e);
111-
wrappedListener.onFailure(e);
105+
if (e instanceof OpenSearchStatusException) {
106+
wrappedListener
107+
.onFailure(
108+
new OpenSearchStatusException(
109+
e.getMessage(),
110+
RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus())
111+
)
112+
);
113+
} else if (e instanceof MLResourceNotFoundException) {
114+
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
115+
} else if (e instanceof CircuitBreakingException) {
116+
wrappedListener.onFailure(e);
117+
} else {
118+
wrappedListener
119+
.onFailure(
120+
new OpenSearchStatusException(
121+
"Failed to Validate Access for ModelId " + modelId,
122+
RestStatus.FORBIDDEN
123+
)
124+
);
125+
}
112126
}));
113127
}
114128

plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ public class MemoryCircuitBreaker extends ThresholdCircuitBreaker<Short> {
1818
// TODO: make this value configurable as cluster setting
1919
private static final String ML_MEMORY_CB = "Memory Circuit Breaker";
2020
public static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85;
21+
public static final short JVM_HEAP_MAX_THRESHOLD = 100; // when threshold is 100, this CB check is ignored
2122
private final JvmService jvmService;
22-
private volatile Integer jvmHeapMemThreshold = 85;
2323

2424
public MemoryCircuitBreaker(JvmService jvmService) {
2525
super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD);
@@ -34,8 +34,9 @@ public MemoryCircuitBreaker(short threshold, JvmService jvmService) {
3434
public MemoryCircuitBreaker(Settings settings, ClusterService clusterService, JvmService jvmService) {
3535
super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD);
3636
this.jvmService = jvmService;
37-
this.jvmHeapMemThreshold = ML_COMMONS_JVM_HEAP_MEM_THRESHOLD.get(settings);
38-
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, it -> jvmHeapMemThreshold = it);
37+
clusterService
38+
.getClusterSettings()
39+
.addSettingsUpdateConsumer(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, it -> super.setThreshold(it.shortValue()));
3940
}
4041

4142
@Override
@@ -45,6 +46,6 @@ public String getName() {
4546

4647
@Override
4748
public boolean isOpen() {
48-
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
49+
return getThreshold() < JVM_HEAP_MAX_THRESHOLD && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
4950
}
5051
}

plugin/src/main/java/org/opensearch/ml/breaker/ThresholdCircuitBreaker.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
package org.opensearch.ml.breaker;
77

8+
import lombok.Data;
9+
810
/**
911
* An abstract class for all breakers with threshold.
1012
* @param <T> data type of threshold
1113
*/
14+
@Data
1215
public abstract class ThresholdCircuitBreaker<T> implements CircuitBreaker {
1316

1417
private T threshold;
@@ -17,10 +20,6 @@ public ThresholdCircuitBreaker(T threshold) {
1720
this.threshold = threshold;
1821
}
1922

20-
public T getThreshold() {
21-
return threshold;
22-
}
23-
2423
@Override
2524
public abstract boolean isOpen();
2625
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import java.util.HashMap;
5353
import java.util.List;
5454
import java.util.Map;
55+
import java.util.Objects;
5556
import java.util.Optional;
5657
import java.util.Set;
5758
import java.util.concurrent.ConcurrentLinkedDeque;
@@ -619,7 +620,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
619620
* @param runningTaskLimit limit
620621
*/
621622
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
622-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
623+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
624+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
625+
}
623626
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
624627
}
625628

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void dispatchTask(
124124
if (clusterService.localNode().getId().equals(node.getId())) {
125125
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
126126
request.setDispatchTask(false);
127-
executeTask(request, listener);
127+
checkCBAndExecute(functionName, request, listener);
128128
} else {
129129
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
130130
request.setDispatchTask(false);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
91-
executeTask(request, listener);
90+
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}
9493
dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
129128
protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);
130129

131130
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131+
132+
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133+
if (functionName != FunctionName.REMOTE) {
134+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135+
}
136+
executeTask(request, listener);
137+
}
132138
}

plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import java.util.function.Function;
1414

1515
import org.opensearch.cluster.node.DiscoveryNode;
16+
import org.opensearch.common.breaker.CircuitBreaker;
17+
import org.opensearch.common.breaker.CircuitBreakingException;
1618
import org.opensearch.common.bytes.BytesReference;
1719
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
1820
import org.opensearch.common.xcontent.XContentHelper;
@@ -21,7 +23,6 @@
2123
import org.opensearch.core.xcontent.XContentParser;
2224
import org.opensearch.ml.breaker.MLCircuitBreakerService;
2325
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
24-
import org.opensearch.ml.common.exception.MLLimitExceededException;
2526
import org.opensearch.ml.stats.MLNodeLevelStat;
2627
import org.opensearch.ml.stats.MLStats;
2728

@@ -60,7 +61,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
6061
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
6162
if (openCircuitBreaker != null) {
6263
mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
63-
throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
64+
throw new CircuitBreakingException(
65+
openCircuitBreaker.getName() + " is open, please check your resources!",
66+
CircuitBreaker.Durability.TRANSIENT
67+
);
6468
}
6569
}
6670
}

plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java

+42
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
package org.opensearch.ml.breaker;
77

88
import static org.mockito.Mockito.when;
9+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;
910

1011
import org.junit.Assert;
1112
import org.junit.Before;
1213
import org.junit.Test;
1314
import org.mockito.Mock;
1415
import org.mockito.MockitoAnnotations;
16+
import org.opensearch.cluster.service.ClusterService;
17+
import org.opensearch.common.settings.ClusterSettings;
18+
import org.opensearch.common.settings.Settings;
1519
import org.opensearch.monitor.jvm.JvmService;
1620
import org.opensearch.monitor.jvm.JvmStats;
1721

@@ -26,6 +30,9 @@ public class MemoryCircuitBreakerTests {
2630
@Mock
2731
JvmStats.Mem mem;
2832

33+
@Mock
34+
ClusterService clusterService;
35+
2936
@Before
3037
public void setup() {
3138
MockitoAnnotations.openMocks(this);
@@ -60,4 +67,39 @@ public void testIsOpen_CustomThreshold_ExceedMemoryThreshold() {
6067
when(mem.getHeapUsedPercent()).thenReturn((short) 95);
6168
Assert.assertTrue(breaker.isOpen());
6269
}
70+
71+
@Test
72+
public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
73+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
74+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
75+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
76+
77+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
78+
79+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
80+
Assert.assertTrue(breaker.isOpen());
81+
82+
Settings.Builder newSettingsBuilder = Settings.builder();
83+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 95);
84+
settingsService.applySettings(newSettingsBuilder.build());
85+
Assert.assertFalse(breaker.isOpen());
86+
}
87+
88+
@Test
89+
public void testIsOpen_DisableMemoryCB() {
90+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
91+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
92+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
93+
94+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
95+
96+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
97+
Assert.assertTrue(breaker.isOpen());
98+
99+
when(mem.getHeapUsedPercent()).thenReturn((short) 100);
100+
Settings.Builder newSettingsBuilder = Settings.builder();
101+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
102+
settingsService.applySettings(newSettingsBuilder.build());
103+
Assert.assertFalse(breaker.isOpen());
104+
}
63105
}

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.opensearch.action.update.UpdateResponse;
7474
import org.opensearch.client.Client;
7575
import org.opensearch.cluster.service.ClusterService;
76+
import org.opensearch.common.breaker.CircuitBreakingException;
7677
import org.opensearch.common.settings.ClusterSettings;
7778
import org.opensearch.common.settings.Settings;
7879
import org.opensearch.common.util.concurrent.ThreadContext;
@@ -311,7 +312,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
311312
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
312313
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
313314
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
314-
expectedEx.expect(MLException.class);
315+
expectedEx.expect(CircuitBreakingException.class);
315316
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
316317
modelManager.registerMLModel(registerModelInput, mlTask);
317318
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
3838
+ " \"content_type\": \"application/json\",\n"
3939
+ " \"max_tokens\": 7,\n"
4040
+ " \"temperature\": 0,\n"
41-
+ " \"model\": \"text-davinci-003\"\n"
41+
+ " \"model\": \"davinci-002\"\n"
4242
+ " },\n"
4343
+ " \"credential\": {\n"
4444
+ " \"openAI_key\": \""
@@ -250,6 +250,7 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
250250
assertNotNull(responseMap);
251251
}
252252

253+
@Ignore
253254
public void testOpenAIEditsModel() throws IOException, InterruptedException {
254255
// Skip test if key is null
255256
if (OPENAI_KEY == null) {

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import org.opensearch.ml.common.MLTask;
3535
import org.opensearch.ml.common.MLTaskState;
3636
import org.opensearch.ml.common.MLTaskType;
37-
import org.opensearch.ml.common.exception.MLLimitExceededException;
3837
import org.opensearch.ml.common.transport.MLTaskRequest;
3938
import org.opensearch.ml.stats.MLNodeLevelStat;
4039
import org.opensearch.ml.stats.MLStat;
@@ -132,15 +131,15 @@ public void testHandleAsyncMLTaskComplete_SyncTask() {
132131
verify(mlTaskManager, never()).updateMLTask(eq(syncMlTask.getTaskId()), any(), anyLong(), anyBoolean());
133132
}
134133

135-
public void testRun_CircuitBreakerOpen() {
134+
public void testRemoteInferenceRun_CircuitBreakerNotOpen() {
136135
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
137136
when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker");
138137
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
139138
TransportService transportService = mock(TransportService.class);
140139
ActionListener listener = mock(ActionListener.class);
141140
MLTaskRequest request = new MLTaskRequest(false);
142-
expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener));
141+
mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener);
143142
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
144-
assertEquals(1L, value.longValue());
143+
assertEquals(0L, value.longValue());
145144
}
146145
}

0 commit comments

Comments
 (0)