|
43 | 43 | import org.opensearch.ml.common.dataframe.DataFrameBuilder;
|
44 | 44 | import org.opensearch.ml.common.dataset.DataFrameInputDataset;
|
45 | 45 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
|
| 46 | +import org.opensearch.ml.common.exception.MLLimitExceededException; |
46 | 47 | import org.opensearch.ml.common.exception.MLResourceNotFoundException;
|
47 | 48 | import org.opensearch.ml.common.input.MLInput;
|
48 | 49 | import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
|
@@ -235,6 +236,28 @@ public void testPrediction_MLResourceNotFoundException() {
|
235 | 236 | assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
|
236 | 237 | }
|
237 | 238 |
|
| 239 | + public void testPrediction_MLLimitExceededException() { |
| 240 | + when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); |
| 241 | + when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); |
| 242 | + |
| 243 | + doAnswer(invocation -> { |
| 244 | + ActionListener<Boolean> listener = invocation.getArgument(3); |
| 245 | + listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!")); |
| 246 | + return null; |
| 247 | + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); |
| 248 | + |
| 249 | + doAnswer(invocation -> { |
| 250 | + ((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null); |
| 251 | + return null; |
| 252 | + }).when(mlPredictTaskRunner).run(any(), any(), any(), any()); |
| 253 | + |
| 254 | + transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener); |
| 255 | + |
| 256 | + ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); |
| 257 | + verify(actionListener).onFailure(argumentCaptor.capture()); |
| 258 | + assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage()); |
| 259 | + } |
| 260 | + |
238 | 261 | public void testValidateInputSchemaSuccess() {
|
239 | 262 | RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
|
240 | 263 | .builder()
|
|
0 commit comments