|
28 | 28 | import org.mockito.MockitoAnnotations;
|
29 | 29 | import org.opensearch.client.node.NodeClient;
|
30 | 30 | import org.opensearch.common.settings.Settings;
|
| 31 | +import org.opensearch.common.xcontent.XContentFactory; |
31 | 32 | import org.opensearch.core.action.ActionListener;
|
32 | 33 | import org.opensearch.core.common.Strings;
|
33 | 34 | import org.opensearch.core.rest.RestStatus;
|
@@ -281,4 +282,59 @@ public void testPrepareRequestSystemException() throws Exception {
|
281 | 282 | "{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}";
|
282 | 283 | assertEquals(expectedError, response.content().utf8ToString());
|
283 | 284 | }
|
| 285 | + |
| 286 | + public void testAgentExecutionResponseXContent() throws Exception { |
| 287 | + RestRequest request = getExecuteAgentRestRequest(); |
| 288 | + doAnswer(invocation -> { |
| 289 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 290 | + actionListener |
| 291 | + .onFailure( |
| 292 | + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) |
| 293 | + ); |
| 294 | + return null; |
| 295 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 296 | + doNothing().when(channel).sendResponse(any()); |
| 297 | + when(channel.newBuilder()).thenReturn(XContentFactory.jsonBuilder()); |
| 298 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 299 | + |
| 300 | + ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); |
| 301 | + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); |
| 302 | + Input input = argumentCaptor.getValue().getInput(); |
| 303 | + assertEquals(FunctionName.AGENT, input.getFunctionName()); |
| 304 | + ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 305 | + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); |
| 306 | + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); |
| 307 | + assertEquals(RestStatus.BAD_REQUEST, response.status()); |
| 308 | + assertEquals("application/json; charset=UTF-8", response.contentType()); |
| 309 | + String expectedError = |
| 310 | + "{\"status\":400,\"error\":{\"type\":\"IllegalArgumentException\",\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\"}}"; |
| 311 | + assertEquals(expectedError, response.content().utf8ToString()); |
| 312 | + } |
| 313 | + |
| 314 | + public void testAgentExecutionResponsePlainText() throws Exception { |
| 315 | + RestRequest request = getExecuteAgentRestRequest(); |
| 316 | + doAnswer(invocation -> { |
| 317 | + ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2); |
| 318 | + actionListener |
| 319 | + .onFailure( |
| 320 | + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) |
| 321 | + ); |
| 322 | + return null; |
| 323 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); |
| 324 | + doNothing().when(channel).sendResponse(any()); |
| 325 | + restMLExecuteAction.handleRequest(request, channel, client); |
| 326 | + |
| 327 | + ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); |
| 328 | + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); |
| 329 | + Input input = argumentCaptor.getValue().getInput(); |
| 330 | + assertEquals(FunctionName.AGENT, input.getFunctionName()); |
| 331 | + ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); |
| 332 | + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); |
| 333 | + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); |
| 334 | + assertEquals(RestStatus.BAD_REQUEST, response.status()); |
| 335 | + assertEquals("text/plain; charset=UTF-8", response.contentType()); |
| 336 | + String expectedError = |
| 337 | + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; |
| 338 | + assertEquals(expectedError, response.content().utf8ToString()); |
| 339 | + } |
284 | 340 | }
|
0 commit comments