Skip to content

Commit 2211563

Browse files
committed
Change the circuit breaker (CB) limit exception status code to 429 and skip the CB check for remote models
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 70ea17f commit 2211563

File tree

8 files changed: +31 -21 lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+3 -3
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,12 @@
2020
import org.opensearch.common.xcontent.XContentFactory;
2121
import org.opensearch.commons.authuser.User;
2222
import org.opensearch.core.action.ActionListener;
23+
import org.opensearch.core.common.breaker.CircuitBreakingException;
2324
import org.opensearch.core.rest.RestStatus;
2425
import org.opensearch.core.xcontent.NamedXContentRegistry;
2526
import org.opensearch.core.xcontent.ToXContent;
2627
import org.opensearch.ml.common.FunctionName;
2728
import org.opensearch.ml.common.MLModel;
28-
import org.opensearch.ml.common.exception.MLLimitExceededException;
2929
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
3030
import org.opensearch.ml.common.input.MLInput;
3131
import org.opensearch.ml.common.transport.MLTaskResponse;
@@ -178,8 +178,8 @@ public void onResponse(MLModel mlModel) {
178178
);
179179
} else if (e instanceof MLResourceNotFoundException) {
180180
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
181-
} else if (e instanceof MLLimitExceededException) {
182-
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE));
181+
} else if (e instanceof CircuitBreakingException) {
182+
wrappedListener.onFailure(e);
183183
} else {
184184
wrappedListener
185185
.onFailure(

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+4 -1
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@
5959
import java.util.HashMap;
6060
import java.util.List;
6161
import java.util.Map;
62+
import java.util.Objects;
6263
import java.util.Optional;
6364
import java.util.Set;
6465
import java.util.concurrent.ConcurrentLinkedDeque;
@@ -838,7 +839,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
838839
* @param runningTaskLimit limit
839840
*/
840841
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
841-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
842+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
843+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
844+
}
842845
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
843846
}
844847

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1 -7
Original file line number | Diff line number | Diff line change
@@ -143,13 +143,7 @@ public void dispatchTask(
143143
if (clusterService.localNode().getId().equals(node.getId())) {
144144
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
145145
request.setDispatchTask(false);
146-
run(
147-
// This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
148-
functionName,
149-
request,
150-
transportService,
151-
listener
152-
);
146+
checkCBAndExecute(functionName, request, listener);
153147
} else {
154148
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
155149
request.setDispatchTask(false);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

+8 -2
Original file line number | Diff line number | Diff line change
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
91-
executeTask(request, listener);
90+
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}
9493
dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
129128
protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);
130129

131130
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131+
132+
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133+
if (functionName != FunctionName.REMOTE) {
134+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135+
}
136+
executeTask(request, listener);
137+
}
132138
}

plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java

+6 -2
Original file line number | Diff line number | Diff line change
@@ -18,12 +18,13 @@
1818
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
1919
import org.opensearch.common.xcontent.XContentHelper;
2020
import org.opensearch.common.xcontent.XContentType;
21+
import org.opensearch.core.common.breaker.CircuitBreaker;
22+
import org.opensearch.core.common.breaker.CircuitBreakingException;
2123
import org.opensearch.core.common.bytes.BytesReference;
2224
import org.opensearch.core.xcontent.NamedXContentRegistry;
2325
import org.opensearch.core.xcontent.XContentParser;
2426
import org.opensearch.ml.breaker.MLCircuitBreakerService;
2527
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
26-
import org.opensearch.ml.common.exception.MLLimitExceededException;
2728
import org.opensearch.ml.stats.MLNodeLevelStat;
2829
import org.opensearch.ml.stats.MLStats;
2930

@@ -92,7 +93,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
9293
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
9394
if (openCircuitBreaker != null) {
9495
mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
95-
throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
96+
throw new CircuitBreakingException(
97+
openCircuitBreaker.getName() + " is open, please check your resources!",
98+
CircuitBreaker.Durability.TRANSIENT
99+
);
96100
}
97101
}
98102
}

plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java

+4 -3
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@
3535
import org.opensearch.common.util.concurrent.ThreadContext;
3636
import org.opensearch.commons.authuser.User;
3737
import org.opensearch.core.action.ActionListener;
38+
import org.opensearch.core.common.breaker.CircuitBreaker;
39+
import org.opensearch.core.common.breaker.CircuitBreakingException;
3840
import org.opensearch.core.rest.RestStatus;
3941
import org.opensearch.core.xcontent.NamedXContentRegistry;
4042
import org.opensearch.ml.common.FunctionName;
@@ -43,7 +45,6 @@
4345
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
4446
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
4547
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
46-
import org.opensearch.ml.common.exception.MLLimitExceededException;
4748
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
4849
import org.opensearch.ml.common.input.MLInput;
4950
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
@@ -242,7 +243,7 @@ public void testPrediction_MLLimitExceededException() {
242243

243244
doAnswer(invocation -> {
244245
ActionListener<Boolean> listener = invocation.getArgument(3);
245-
listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!"));
246+
listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT));
246247
return null;
247248
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
248249

@@ -253,7 +254,7 @@ public void testPrediction_MLLimitExceededException() {
253254

254255
transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
255256

256-
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
257+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(CircuitBreakingException.class);
257258
verify(actionListener).onFailure(argumentCaptor.capture());
258259
assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
259260
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java

+1 -1
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ public void testRunWithMemoryCircuitBreaker() throws IOException {
5858
exception.getMessage(),
5959
allOf(
6060
containsString("Memory Circuit Breaker is open, please check your resources!"),
61-
containsString("m_l_limit_exceeded_exception")
61+
containsString("circuit_breaking_exception")
6262
)
6363
);
6464

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+4 -2
Original file line number | Diff line number | Diff line change
@@ -199,7 +199,7 @@ public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, In
199199
Response response = createConnector(completionModelConnectorEntity);
200200
Map responseMap = parseResponseToMap(response);
201201
String connectorId = (String) responseMap.get("connector_id");
202-
response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
202+
response = registerRemoteModelWithTTLAndSkipHeapMemCheck("openAI-GPT-3.5 completions", connectorId, 1);
203203
responseMap = parseResponseToMap(response);
204204
String modelId = (String) responseMap.get("model_id");
205205
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
@@ -814,11 +814,13 @@ public static Response registerRemoteModel(String name, String connectorId) thro
814814
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
815815
}
816816

817-
public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
817+
public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name, String connectorId, int ttl) throws IOException {
818818
String registerModelGroupEntity = "{\n"
819819
+ " \"name\": \"remote_model_group\",\n"
820820
+ " \"description\": \"This is an example description\"\n"
821821
+ "}";
822+
String updateJVMHeapThreshold = "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":0}}";
823+
TestHelper.makeRequest(client(), "PUT", "/_cluster/settings", null, TestHelper.toHttpEntity(updateJVMHeapThreshold), null);
822824
Response response = TestHelper
823825
.makeRequest(
824826
client(),

0 commit comments

Comments
 (0)