Skip to content

Commit 1da79ce

Browse files
authored
send response in xcontent, if any exception, use plain text (opensearch-project#2858)
Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent 7ecff1a commit 1da79ce

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.opensearch.client.node.NodeClient;
2222
import org.opensearch.core.action.ActionListener;
2323
import org.opensearch.core.rest.RestStatus;
24+
import org.opensearch.core.xcontent.XContentBuilder;
2425
import org.opensearch.core.xcontent.XContentParser;
2526
import org.opensearch.ml.common.FunctionName;
2627
import org.opensearch.ml.common.input.Input;
@@ -132,7 +133,21 @@ private void sendResponse(RestChannel channel, MLExecuteTaskResponse response) t
132133

133134
private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
134135
ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(e, status.getStatus());
135-
channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), errorMessage.toString()));
136+
try {
137+
XContentBuilder builder = channel.newBuilder();
138+
builder.startObject();
139+
builder.field("status", errorMessage.getStatus());
140+
builder.startObject("error");
141+
builder.field("type", errorMessage.getType());
142+
builder.field("reason", errorMessage.getReason());
143+
builder.field("details", errorMessage.getDetails());
144+
builder.endObject();
145+
builder.endObject();
146+
channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), builder));
147+
} catch (Exception exception) {
148+
log.error("Failed to build xContent for an error response, so reply with a plain string.", exception);
149+
channel.sendResponse(new BytesRestResponse(RestStatus.fromCode(errorMessage.getStatus()), errorMessage.toString()));
150+
}
136151
}
137152

138153
private boolean isClientError(Exception e) {

plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java

+56
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.mockito.MockitoAnnotations;
2929
import org.opensearch.client.node.NodeClient;
3030
import org.opensearch.common.settings.Settings;
31+
import org.opensearch.common.xcontent.XContentFactory;
3132
import org.opensearch.core.action.ActionListener;
3233
import org.opensearch.core.common.Strings;
3334
import org.opensearch.core.rest.RestStatus;
@@ -281,4 +282,59 @@ public void testPrepareRequestSystemException() throws Exception {
281282
"{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}";
282283
assertEquals(expectedError, response.content().utf8ToString());
283284
}
285+
286+
public void testAgentExecutionResponseXContent() throws Exception {
287+
RestRequest request = getExecuteAgentRestRequest();
288+
doAnswer(invocation -> {
289+
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
290+
actionListener
291+
.onFailure(
292+
new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception"))
293+
);
294+
return null;
295+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
296+
doNothing().when(channel).sendResponse(any());
297+
when(channel.newBuilder()).thenReturn(XContentFactory.jsonBuilder());
298+
restMLExecuteAction.handleRequest(request, channel, client);
299+
300+
ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
301+
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
302+
Input input = argumentCaptor.getValue().getInput();
303+
assertEquals(FunctionName.AGENT, input.getFunctionName());
304+
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
305+
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
306+
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
307+
assertEquals(RestStatus.BAD_REQUEST, response.status());
308+
assertEquals("application/json; charset=UTF-8", response.contentType());
309+
String expectedError =
310+
"{\"status\":400,\"error\":{\"type\":\"IllegalArgumentException\",\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\"}}";
311+
assertEquals(expectedError, response.content().utf8ToString());
312+
}
313+
314+
public void testAgentExecutionResponsePlainText() throws Exception {
315+
RestRequest request = getExecuteAgentRestRequest();
316+
doAnswer(invocation -> {
317+
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
318+
actionListener
319+
.onFailure(
320+
new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception"))
321+
);
322+
return null;
323+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
324+
doNothing().when(channel).sendResponse(any());
325+
restMLExecuteAction.handleRequest(request, channel, client);
326+
327+
ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
328+
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
329+
Input input = argumentCaptor.getValue().getInput();
330+
assertEquals(FunctionName.AGENT, input.getFunctionName());
331+
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
332+
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
333+
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
334+
assertEquals(RestStatus.BAD_REQUEST, response.status());
335+
assertEquals("text/plain; charset=UTF-8", response.contentType());
336+
String expectedError =
337+
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
338+
assertEquals(expectedError, response.content().utf8ToString());
339+
}
284340
}

0 commit comments

Comments
 (0)