|
30 | 30 | import org.opensearch.common.settings.Settings;
|
31 | 31 | import org.opensearch.core.action.ActionListener;
|
32 | 32 | import org.opensearch.core.common.Strings;
|
| 33 | +import org.opensearch.core.rest.RestStatus; |
33 | 34 | import org.opensearch.ml.common.FunctionName;
|
34 | 35 | import org.opensearch.ml.common.input.Input;
|
35 | 36 | import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
|
|
44 | 45 | import org.opensearch.test.OpenSearchTestCase;
|
45 | 46 | import org.opensearch.threadpool.TestThreadPool;
|
46 | 47 | import org.opensearch.threadpool.ThreadPool;
|
| 48 | +import org.opensearch.transport.RemoteTransportException; |
47 | 49 |
|
48 | 50 | public class RestMLExecuteActionTests extends OpenSearchTestCase {
|
49 | 51 |
|
@@ -206,4 +208,77 @@ public void testPrepareRequest_disabled() {
|
206 | 208 | when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false);
|
207 | 209 | assertThrows(IllegalStateException.class, () -> restMLExecuteAction.handleRequest(request, channel, client));
|
208 | 210 | }
|
| 211 | + |
| 212 | + public void testPrepareRequestClientException() throws Exception { |
| 213 | + doAnswer(invocation -> { |
| 214 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 215 | + actionListener.onFailure(new IllegalArgumentException("Illegal Argument Exception")); |
| 216 | + return null; |
| 217 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 218 | + doNothing().when(channel).sendResponse(any()); |
| 219 | + RestRequest request = getLocalSampleCalculatorRestRequest(); |
| 220 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 221 | + |
| 222 | + ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); |
| 223 | + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); |
| 224 | + Input input = argumentCaptor.getValue().getInput(); |
| 225 | + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); |
| 226 | + ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 227 | + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); |
| 228 | + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); |
| 229 | + assertEquals(RestStatus.BAD_REQUEST, response.status()); |
| 230 | + String content = response.content().utf8ToString(); |
| 231 | + String expectedError = |
| 232 | + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; |
| 233 | + assertEquals(expectedError, response.content().utf8ToString()); |
| 234 | + } |
| 235 | + |
| 236 | + public void testPrepareRequestClientWrappedException() throws Exception { |
| 237 | + doAnswer(invocation -> { |
| 238 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 239 | + actionListener |
| 240 | + .onFailure( |
| 241 | + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) |
| 242 | + ); |
| 243 | + return null; |
| 244 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 245 | + doNothing().when(channel).sendResponse(any()); |
| 246 | + RestRequest request = getLocalSampleCalculatorRestRequest(); |
| 247 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 248 | + |
| 249 | + ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); |
| 250 | + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); |
| 251 | + Input input = argumentCaptor.getValue().getInput(); |
| 252 | + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); |
| 253 | + ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 254 | + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); |
| 255 | + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); |
| 256 | + assertEquals(RestStatus.BAD_REQUEST, response.status()); |
| 257 | + String expectedError = |
| 258 | + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; |
| 259 | + assertEquals(expectedError, response.content().utf8ToString()); |
| 260 | + } |
| 261 | + |
| 262 | + public void testPrepareRequestSystemException() throws Exception { |
| 263 | + doAnswer(invocation -> { |
| 264 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 265 | + actionListener.onFailure(new RuntimeException("System Exception")); |
| 266 | + return null; |
| 267 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 268 | + doNothing().when(channel).sendResponse(any()); |
| 269 | + RestRequest request = getLocalSampleCalculatorRestRequest(); |
| 270 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 271 | + |
| 272 | + ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); |
| 273 | + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); |
| 274 | + Input input = argumentCaptor.getValue().getInput(); |
| 275 | + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); |
| 276 | + ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 277 | + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); |
| 278 | + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); |
| 279 | + assertEquals(RestStatus.INTERNAL_SERVER_ERROR, response.status()); |
| 280 | + String expectedError = |
| 281 | + "{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}"; |
| 282 | + assertEquals(expectedError, response.content().utf8ToString()); |
| 283 | + } |
209 | 284 | }
|
0 commit comments