Skip to content

Commit b591c2c

Browse files
opensearch-trigger-bot[bot]dhrubo-os
andauthoredApr 23, 2024
not sending failure message when model index isn't present (opensearch-project#2351) (opensearch-project#2353)
* not sending failure message when model index isn't present Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * making profile api experience same Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * add unit test Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * applying spotless Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit be05dfc) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 83acad3 commit b591c2c

File tree

3 files changed

+142
-9
lines changed

3 files changed

+142
-9
lines changed
 

‎plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,15 @@ public void onResponse(SearchResponse searchResponse) {
154154

155155
@Override
156156
public void onFailure(Exception e) {
157-
onFailed(channel, "Searching model wasn't successful", e);
157+
try {
158+
builder.startObject();
159+
builder.endObject();
160+
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
161+
} catch (IOException ex) {
162+
String errorMessage = "Failed to get ML node level profile";
163+
log.error(errorMessage, e);
164+
onFailed(channel, errorMessage, e);
165+
}
158166
}
159167

160168
}, threadContext::restore));

‎plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ public void onResponse(SearchResponse searchResponse) {
176176

177177
@Override
178178
public void onFailure(Exception e) {
179-
onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, "Searching model wasn't successful", e);
179+
try {
180+
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
181+
} catch (IOException ex) {
182+
onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to retrieve Cluster level metrics", e);
183+
}
180184
}
181185
}, threadContext::restore));
182186

‎plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java

+128-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.rest;
77

88
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.ArgumentMatchers.argThat;
910
import static org.mockito.ArgumentMatchers.eq;
1011
import static org.mockito.Mockito.doAnswer;
1112
import static org.mockito.Mockito.spy;
@@ -51,6 +52,7 @@
5152
import org.opensearch.core.common.Strings;
5253
import org.opensearch.core.common.bytes.BytesReference;
5354
import org.opensearch.core.common.transport.TransportAddress;
55+
import org.opensearch.core.rest.RestStatus;
5456
import org.opensearch.core.xcontent.NamedXContentRegistry;
5557
import org.opensearch.core.xcontent.XContentBuilder;
5658
import org.opensearch.ml.action.profile.MLProfileAction;
@@ -67,6 +69,7 @@
6769
import org.opensearch.ml.profile.MLModelProfile;
6870
import org.opensearch.ml.profile.MLPredictRequestStats;
6971
import org.opensearch.ml.profile.MLProfileInput;
72+
import org.opensearch.rest.BytesRestResponse;
7073
import org.opensearch.rest.RestChannel;
7174
import org.opensearch.rest.RestHandler;
7275
import org.opensearch.rest.RestRequest;
@@ -151,13 +154,6 @@ public void setup() throws IOException {
151154
testState = setupTestClusterState();
152155
when(clusterService.state()).thenReturn(testState);
153156

154-
doAnswer(invocation -> {
155-
ActionListener<SearchResponse> listener = invocation.getArgument(1);
156-
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
157-
listener.onResponse(response);
158-
return null;
159-
}).when(client).search(any(SearchRequest.class), any());
160-
161157
doAnswer(invocation -> {
162158
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
163159
Map<String, MLTask> nodeTasks = new HashMap<>();
@@ -207,6 +203,13 @@ public void testRoutes() {
207203
}
208204

209205
public void test_PrepareRequest_TaskRequest() throws Exception {
206+
doAnswer(invocation -> {
207+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
208+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
209+
listener.onResponse(response);
210+
return null;
211+
}).when(client).search(any(SearchRequest.class), any());
212+
210213
RestRequest request = getRestRequest();
211214
profileAction.handleRequest(request, channel, client);
212215

@@ -218,6 +221,13 @@ public void test_PrepareRequest_TaskRequest() throws Exception {
218221
}
219222

220223
public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception {
224+
doAnswer(invocation -> {
225+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
226+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
227+
listener.onResponse(response);
228+
return null;
229+
}).when(client).search(any(SearchRequest.class), any());
230+
221231
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/_plugins/_ml/profile/tasks").build();
222232
profileAction.handleRequest(request, channel, client);
223233

@@ -228,6 +238,13 @@ public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception {
228238
}
229239

230240
public void test_PrepareRequest_ModelRequest() throws Exception {
241+
doAnswer(invocation -> {
242+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
243+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
244+
listener.onResponse(response);
245+
return null;
246+
}).when(client).search(any(SearchRequest.class), any());
247+
231248
RestRequest request = getModelRestRequest();
232249
profileAction.handleRequest(request, channel, client);
233250

@@ -239,6 +256,13 @@ public void test_PrepareRequest_ModelRequest() throws Exception {
239256
}
240257

241258
public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception {
259+
doAnswer(invocation -> {
260+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
261+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
262+
listener.onResponse(response);
263+
return null;
264+
}).when(client).search(any(SearchRequest.class), any());
265+
242266
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/_plugins/_ml/profile/models").build();
243267
profileAction.handleRequest(request, channel, client);
244268

@@ -249,6 +273,12 @@ public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception {
249273
}
250274

251275
public void test_PrepareRequest_EmptyNodeProfile() throws Exception {
276+
doAnswer(invocation -> {
277+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
278+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
279+
listener.onResponse(response);
280+
return null;
281+
}).when(client).search(any(SearchRequest.class), any());
252282
doAnswer(invocation -> {
253283
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
254284
MLProfileResponse profileResponse = new MLProfileResponse(clusterName, new ArrayList<>(), new ArrayList<>());
@@ -267,6 +297,13 @@ public void test_PrepareRequest_EmptyNodeProfile() throws Exception {
267297
}
268298

269299
public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception {
300+
doAnswer(invocation -> {
301+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
302+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
303+
listener.onResponse(response);
304+
return null;
305+
}).when(client).search(any(SearchRequest.class), any());
306+
270307
doAnswer(invocation -> {
271308
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
272309
Map<String, MLTask> nodeTasks = new HashMap<>();
@@ -288,6 +325,13 @@ public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception {
288325
}
289326

290327
public void test_PrepareRequest_WithRequestContent() throws Exception {
328+
doAnswer(invocation -> {
329+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
330+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
331+
listener.onResponse(response);
332+
return null;
333+
}).when(client).search(any(SearchRequest.class), any());
334+
291335
MLProfileInput mlProfileInput = new MLProfileInput();
292336
RestRequest request = getProfileRestRequest(mlProfileInput);
293337
profileAction.handleRequest(request, channel, client);
@@ -296,6 +340,13 @@ public void test_PrepareRequest_WithRequestContent() throws Exception {
296340
}
297341

298342
public void test_PrepareRequest_Failure() throws Exception {
343+
doAnswer(invocation -> {
344+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
345+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
346+
listener.onResponse(response);
347+
return null;
348+
}).when(client).search(any(SearchRequest.class), any());
349+
299350
doAnswer(invocation -> {
300351
ActionListener<MLProfileResponse> actionListener = invocation.getArgument(2);
301352
actionListener.onFailure(new RuntimeException("test failure"));
@@ -308,14 +359,84 @@ public void test_PrepareRequest_Failure() throws Exception {
308359
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
309360
}
310361

362+
public void test_Search_Failure() throws Exception {
363+
// Setup to simulate a search failure
364+
doAnswer(invocation -> {
365+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
366+
listener.onFailure(new Exception("Mocking Exception")); // Trigger failure
367+
return null;
368+
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));
369+
370+
// Create a RestRequest instance for testing
371+
RestRequest request = getRestRequest(); // Ensure this method correctly initializes a RestRequest
372+
373+
// Handle the request with the expectation of handling a failure
374+
profileAction.handleRequest(request, channel, client);
375+
376+
// Verification that the search method was called exactly once
377+
verify(client, times(1)).search(any(SearchRequest.class), any(ActionListener.class));
378+
379+
// Capturing the response sent to the channel
380+
ArgumentCaptor<BytesRestResponse> responseCaptor = ArgumentCaptor.forClass(BytesRestResponse.class);
381+
verify(channel).sendResponse(responseCaptor.capture());
382+
383+
// Check the response status code to see if it correctly reflects the error
384+
BytesRestResponse response = responseCaptor.getValue();
385+
assertEquals(RestStatus.OK, response.status());
386+
assertTrue(response.content().utf8ToString().contains("{}"));
387+
}
388+
311389
public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception {
390+
doAnswer(invocation -> {
391+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
392+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
393+
listener.onResponse(response);
394+
return null;
395+
}).when(client).search(any(SearchRequest.class), any());
312396
MLProfileInput mlProfileInput = new MLProfileInput();
313397
RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "model"));
314398
profileAction.handleRequest(request, channel, client);
315399
ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
316400
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
317401
}
318402

403+
// public void testNodeViewOutput() throws Exception {
404+
// // Assuming setup for non-empty node responses as done in the initial setup
405+
// MLProfileInput mlProfileInput = new MLProfileInput();
406+
// RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "node"));
407+
// profileAction.handleRequest(request, channel, client);
408+
//
409+
// ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
410+
// verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
411+
//
412+
// // Verify that the response is correctly formed for the node view
413+
// verify(channel).sendResponse(argThat(response -> {
414+
// // Ensure the response content matches expected node view structure
415+
// String content = response.content().utf8ToString();
416+
// return content.contains("\"node\":") && !content.contains("\"models\":");
417+
// }));
418+
// }
419+
420+
public void testBackendFailureHandling() throws Exception {
421+
doAnswer(invocation -> {
422+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
423+
SearchResponse response = createSearchModelResponse(); // Prepare your mocked response here
424+
listener.onResponse(response);
425+
return null;
426+
}).when(client).search(any(SearchRequest.class), any());
427+
428+
doAnswer(invocation -> {
429+
ActionListener<MLProfileResponse> listener = invocation.getArgument(2);
430+
listener.onFailure(new RuntimeException("Simulated backend failure"));
431+
return null;
432+
}).when(client).execute(eq(MLProfileAction.INSTANCE), any(MLProfileRequest.class), any(ActionListener.class));
433+
434+
RestRequest request = getRestRequest();
435+
profileAction.handleRequest(request, channel, client);
436+
437+
verify(channel).sendResponse(argThat(response -> response.status() == RestStatus.INTERNAL_SERVER_ERROR));
438+
}
439+
319440
private SearchResponse createSearchModelResponse() throws IOException {
320441
XContentBuilder content = builder();
321442
content.startObject();

0 commit comments

Comments
 (0)