Skip to content

Commit a2f1cd3

Browse files
fix error message with unwrapping the root cause (opensearch-project#2458) (opensearch-project#2515)
Signed-off-by: Jing Zhang <jngz@amazon.com> (cherry picked from commit 8331fe6) Co-authored-by: Jing Zhang <jngz@amazon.com>
1 parent 20adeec commit a2f1cd3

File tree

4 files changed

+102
-37
lines changed

4 files changed

+102
-37
lines changed

plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java

+4
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,8 @@ public static void logException(String errorMessage, Exception e, Logger log) {
5959
log.error(errorMessage, e);
6060
}
6161
}
62+
63+
public static Throwable getRootCause(Throwable t) {
64+
return ExceptionUtils.getRootCause(t);
65+
}
6266
}

plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java

+2-14
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.utils.error;
77

88
import org.opensearch.OpenSearchException;
9+
import org.opensearch.ml.utils.MLExceptionUtils;
910

1011
import lombok.experimental.UtilityClass;
1112

@@ -23,22 +24,9 @@ public static ErrorMessage createErrorMessage(Throwable e, int status) {
2324
int st = status;
2425
if (t instanceof OpenSearchException) {
2526
st = ((OpenSearchException) t).status().getStatus();
26-
} else {
27-
t = unwrapCause(e);
2827
}
28+
t = MLExceptionUtils.getRootCause(t);
2929

3030
return new ErrorMessage(t, st);
3131
}
32-
33-
protected static Throwable unwrapCause(Throwable t) {
34-
Throwable result = t;
35-
if (result instanceof OpenSearchException) {
36-
return result;
37-
}
38-
if (result.getCause() == null) {
39-
return result;
40-
}
41-
result = unwrapCause(result.getCause());
42-
return result;
43-
}
4432
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java

+75
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.common.settings.Settings;
3131
import org.opensearch.core.action.ActionListener;
3232
import org.opensearch.core.common.Strings;
33+
import org.opensearch.core.rest.RestStatus;
3334
import org.opensearch.ml.common.FunctionName;
3435
import org.opensearch.ml.common.input.Input;
3536
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
@@ -44,6 +45,7 @@
4445
import org.opensearch.test.OpenSearchTestCase;
4546
import org.opensearch.threadpool.TestThreadPool;
4647
import org.opensearch.threadpool.ThreadPool;
48+
import org.opensearch.transport.RemoteTransportException;
4749

4850
public class RestMLExecuteActionTests extends OpenSearchTestCase {
4951

@@ -206,4 +208,77 @@ public void testPrepareRequest_disabled() {
206208
when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false);
207209
assertThrows(IllegalStateException.class, () -> restMLExecuteAction.handleRequest(request, channel, client));
208210
}
211+
212+
public void testPrepareRequestClientException() throws Exception {
213+
doAnswer(invocation -> {
214+
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
215+
actionListener.onFailure(new IllegalArgumentException("Illegal Argument Exception"));
216+
return null;
217+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
218+
doNothing().when(channel).sendResponse(any());
219+
RestRequest request = getLocalSampleCalculatorRestRequest();
220+
restMLExecuteAction.handleRequest(request, channel, client);
221+
222+
ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
223+
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
224+
Input input = argumentCaptor.getValue().getInput();
225+
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
226+
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
227+
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
228+
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
229+
assertEquals(RestStatus.BAD_REQUEST, response.status());
230+
String content = response.content().utf8ToString();
231+
String expectedError =
232+
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
233+
assertEquals(expectedError, response.content().utf8ToString());
234+
}
235+
236+
public void testPrepareRequestClientWrappedException() throws Exception {
237+
doAnswer(invocation -> {
238+
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
239+
actionListener
240+
.onFailure(
241+
new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception"))
242+
);
243+
return null;
244+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
245+
doNothing().when(channel).sendResponse(any());
246+
RestRequest request = getLocalSampleCalculatorRestRequest();
247+
restMLExecuteAction.handleRequest(request, channel, client);
248+
249+
ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
250+
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
251+
Input input = argumentCaptor.getValue().getInput();
252+
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
253+
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
254+
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
255+
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
256+
assertEquals(RestStatus.BAD_REQUEST, response.status());
257+
String expectedError =
258+
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
259+
assertEquals(expectedError, response.content().utf8ToString());
260+
}
261+
262+
public void testPrepareRequestSystemException() throws Exception {
263+
doAnswer(invocation -> {
264+
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
265+
actionListener.onFailure(new RuntimeException("System Exception"));
266+
return null;
267+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
268+
doNothing().when(channel).sendResponse(any());
269+
RestRequest request = getLocalSampleCalculatorRestRequest();
270+
restMLExecuteAction.handleRequest(request, channel, client);
271+
272+
ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
273+
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
274+
Input input = argumentCaptor.getValue().getInput();
275+
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
276+
ArgumentCaptor<RestResponse> restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class);
277+
verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture());
278+
BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue();
279+
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, response.status());
280+
String expectedError =
281+
"{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}";
282+
assertEquals(expectedError, response.content().utf8ToString());
283+
}
209284
}

plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java

+21-23
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,41 @@
55

66
package org.opensearch.ml.utils.error;
77

8-
import static org.junit.Assert.assertFalse;
8+
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertTrue;
1010

1111
import org.junit.Test;
1212
import org.opensearch.OpenSearchException;
1313
import org.opensearch.core.rest.RestStatus;
14+
import org.opensearch.transport.RemoteTransportException;
1415

1516
public class ErrorMessageFactoryTests {
1617

17-
private Throwable nonOpenSearchThrowable = new Throwable();
18-
private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable);
19-
20-
@Test
21-
public void openSearchExceptionShouldCreateEsErrorMessage() {
22-
Exception exception = new OpenSearchException(nonOpenSearchThrowable);
23-
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
24-
assertTrue(msg.exception instanceof OpenSearchException);
25-
}
26-
2718
@Test
28-
public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() {
29-
Exception exception = new Exception(nonOpenSearchThrowable);
30-
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
31-
assertFalse(msg.exception instanceof OpenSearchException);
19+
public void openSearchExceptionWithoutNestedException() {
20+
Throwable openSearchThrowable = new OpenSearchException("OpenSearch Exception");
21+
ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(openSearchThrowable, RestStatus.BAD_REQUEST.getStatus());
22+
assertTrue(errorMessage.exception instanceof OpenSearchException);
23+
assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus());
3224
}
3325

3426
@Test
35-
public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
36-
Exception exception = (Exception) openSearchThrowable;
37-
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
38-
assertTrue(msg.exception instanceof OpenSearchException);
27+
public void openSearchExceptionWithNestedException() {
28+
Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception");
29+
Throwable openSearchThrowable = new RemoteTransportException("Remote Transport Exception", nestedThrowable);
30+
ErrorMessage errorMessage = ErrorMessageFactory
31+
.createErrorMessage(openSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus());
32+
assertTrue(errorMessage.exception instanceof IllegalArgumentException);
33+
assertEquals(RestStatus.BAD_REQUEST.getStatus(), errorMessage.getStatus());
3934
}
4035

4136
@Test
42-
public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
43-
Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable)));
44-
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
45-
assertTrue(msg.exception instanceof OpenSearchException);
37+
public void nonOpenSearchExceptionWithNestedException() {
38+
Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception");
39+
Throwable nonOpenSearchThrowable = new Exception("Remote Transport Exception", nestedThrowable);
40+
ErrorMessage errorMessage = ErrorMessageFactory
41+
.createErrorMessage(nonOpenSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus());
42+
assertTrue(errorMessage.exception instanceof IllegalArgumentException);
43+
assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus());
4644
}
4745
}

0 commit comments

Comments
 (0)