6
6
package org .opensearch .ml .rest ;
7
7
8
8
import static org .mockito .ArgumentMatchers .any ;
9
+ import static org .mockito .ArgumentMatchers .argThat ;
9
10
import static org .mockito .ArgumentMatchers .eq ;
10
11
import static org .mockito .Mockito .doAnswer ;
11
12
import static org .mockito .Mockito .spy ;
51
52
import org .opensearch .core .common .Strings ;
52
53
import org .opensearch .core .common .bytes .BytesReference ;
53
54
import org .opensearch .core .common .transport .TransportAddress ;
55
+ import org .opensearch .core .rest .RestStatus ;
54
56
import org .opensearch .core .xcontent .NamedXContentRegistry ;
55
57
import org .opensearch .core .xcontent .XContentBuilder ;
56
58
import org .opensearch .ml .action .profile .MLProfileAction ;
67
69
import org .opensearch .ml .profile .MLModelProfile ;
68
70
import org .opensearch .ml .profile .MLPredictRequestStats ;
69
71
import org .opensearch .ml .profile .MLProfileInput ;
72
+ import org .opensearch .rest .BytesRestResponse ;
70
73
import org .opensearch .rest .RestChannel ;
71
74
import org .opensearch .rest .RestHandler ;
72
75
import org .opensearch .rest .RestRequest ;
@@ -151,13 +154,6 @@ public void setup() throws IOException {
151
154
testState = setupTestClusterState ();
152
155
when (clusterService .state ()).thenReturn (testState );
153
156
154
- doAnswer (invocation -> {
155
- ActionListener <SearchResponse > listener = invocation .getArgument (1 );
156
- SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
157
- listener .onResponse (response );
158
- return null ;
159
- }).when (client ).search (any (SearchRequest .class ), any ());
160
-
161
157
doAnswer (invocation -> {
162
158
ActionListener <MLProfileResponse > actionListener = invocation .getArgument (2 );
163
159
Map <String , MLTask > nodeTasks = new HashMap <>();
@@ -207,6 +203,13 @@ public void testRoutes() {
207
203
}
208
204
209
205
public void test_PrepareRequest_TaskRequest () throws Exception {
206
+ doAnswer (invocation -> {
207
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
208
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
209
+ listener .onResponse (response );
210
+ return null ;
211
+ }).when (client ).search (any (SearchRequest .class ), any ());
212
+
210
213
RestRequest request = getRestRequest ();
211
214
profileAction .handleRequest (request , channel , client );
212
215
@@ -218,6 +221,13 @@ public void test_PrepareRequest_TaskRequest() throws Exception {
218
221
}
219
222
220
223
public void test_PrepareRequest_TaskRequestWithNoTaskIds () throws Exception {
224
+ doAnswer (invocation -> {
225
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
226
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
227
+ listener .onResponse (response );
228
+ return null ;
229
+ }).when (client ).search (any (SearchRequest .class ), any ());
230
+
221
231
RestRequest request = new FakeRestRequest .Builder (NamedXContentRegistry .EMPTY ).withPath ("/_plugins/_ml/profile/tasks" ).build ();
222
232
profileAction .handleRequest (request , channel , client );
223
233
@@ -228,6 +238,13 @@ public void test_PrepareRequest_TaskRequestWithNoTaskIds() throws Exception {
228
238
}
229
239
230
240
public void test_PrepareRequest_ModelRequest () throws Exception {
241
+ doAnswer (invocation -> {
242
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
243
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
244
+ listener .onResponse (response );
245
+ return null ;
246
+ }).when (client ).search (any (SearchRequest .class ), any ());
247
+
231
248
RestRequest request = getModelRestRequest ();
232
249
profileAction .handleRequest (request , channel , client );
233
250
@@ -239,6 +256,13 @@ public void test_PrepareRequest_ModelRequest() throws Exception {
239
256
}
240
257
241
258
public void test_PrepareRequest_TaskRequestWithNoModelIds () throws Exception {
259
+ doAnswer (invocation -> {
260
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
261
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
262
+ listener .onResponse (response );
263
+ return null ;
264
+ }).when (client ).search (any (SearchRequest .class ), any ());
265
+
242
266
RestRequest request = new FakeRestRequest .Builder (NamedXContentRegistry .EMPTY ).withPath ("/_plugins/_ml/profile/models" ).build ();
243
267
profileAction .handleRequest (request , channel , client );
244
268
@@ -249,6 +273,12 @@ public void test_PrepareRequest_TaskRequestWithNoModelIds() throws Exception {
249
273
}
250
274
251
275
public void test_PrepareRequest_EmptyNodeProfile () throws Exception {
276
+ doAnswer (invocation -> {
277
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
278
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
279
+ listener .onResponse (response );
280
+ return null ;
281
+ }).when (client ).search (any (SearchRequest .class ), any ());
252
282
doAnswer (invocation -> {
253
283
ActionListener <MLProfileResponse > actionListener = invocation .getArgument (2 );
254
284
MLProfileResponse profileResponse = new MLProfileResponse (clusterName , new ArrayList <>(), new ArrayList <>());
@@ -267,6 +297,13 @@ public void test_PrepareRequest_EmptyNodeProfile() throws Exception {
267
297
}
268
298
269
299
public void test_PrepareRequest_EmptyNodeTasksSize () throws Exception {
300
+ doAnswer (invocation -> {
301
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
302
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
303
+ listener .onResponse (response );
304
+ return null ;
305
+ }).when (client ).search (any (SearchRequest .class ), any ());
306
+
270
307
doAnswer (invocation -> {
271
308
ActionListener <MLProfileResponse > actionListener = invocation .getArgument (2 );
272
309
Map <String , MLTask > nodeTasks = new HashMap <>();
@@ -288,6 +325,13 @@ public void test_PrepareRequest_EmptyNodeTasksSize() throws Exception {
288
325
}
289
326
290
327
public void test_PrepareRequest_WithRequestContent () throws Exception {
328
+ doAnswer (invocation -> {
329
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
330
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
331
+ listener .onResponse (response );
332
+ return null ;
333
+ }).when (client ).search (any (SearchRequest .class ), any ());
334
+
291
335
MLProfileInput mlProfileInput = new MLProfileInput ();
292
336
RestRequest request = getProfileRestRequest (mlProfileInput );
293
337
profileAction .handleRequest (request , channel , client );
@@ -296,6 +340,13 @@ public void test_PrepareRequest_WithRequestContent() throws Exception {
296
340
}
297
341
298
342
public void test_PrepareRequest_Failure () throws Exception {
343
+ doAnswer (invocation -> {
344
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
345
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
346
+ listener .onResponse (response );
347
+ return null ;
348
+ }).when (client ).search (any (SearchRequest .class ), any ());
349
+
299
350
doAnswer (invocation -> {
300
351
ActionListener <MLProfileResponse > actionListener = invocation .getArgument (2 );
301
352
actionListener .onFailure (new RuntimeException ("test failure" ));
@@ -308,14 +359,84 @@ public void test_PrepareRequest_Failure() throws Exception {
308
359
verify (client , times (1 )).execute (eq (MLProfileAction .INSTANCE ), argumentCaptor .capture (), any ());
309
360
}
310
361
362
+ public void test_Search_Failure () throws Exception {
363
+ // Setup to simulate a search failure
364
+ doAnswer (invocation -> {
365
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
366
+ listener .onFailure (new Exception ("Mocking Exception" )); // Trigger failure
367
+ return null ;
368
+ }).when (client ).search (any (SearchRequest .class ), any (ActionListener .class ));
369
+
370
+ // Create a RestRequest instance for testing
371
+ RestRequest request = getRestRequest (); // Ensure this method correctly initializes a RestRequest
372
+
373
+ // Handle the request with the expectation of handling a failure
374
+ profileAction .handleRequest (request , channel , client );
375
+
376
+ // Verification that the search method was called exactly once
377
+ verify (client , times (1 )).search (any (SearchRequest .class ), any (ActionListener .class ));
378
+
379
+ // Capturing the response sent to the channel
380
+ ArgumentCaptor <BytesRestResponse > responseCaptor = ArgumentCaptor .forClass (BytesRestResponse .class );
381
+ verify (channel ).sendResponse (responseCaptor .capture ());
382
+
383
+ // Check the response status code to see if it correctly reflects the error
384
+ BytesRestResponse response = responseCaptor .getValue ();
385
+ assertEquals (RestStatus .OK , response .status ());
386
+ assertTrue (response .content ().utf8ToString ().contains ("{}" ));
387
+ }
388
+
311
389
public void test_WhenViewIsModel_ReturnModelViewResult () throws Exception {
390
+ doAnswer (invocation -> {
391
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
392
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
393
+ listener .onResponse (response );
394
+ return null ;
395
+ }).when (client ).search (any (SearchRequest .class ), any ());
312
396
MLProfileInput mlProfileInput = new MLProfileInput ();
313
397
RestRequest request = getProfileRestRequestWithQueryParams (mlProfileInput , ImmutableMap .of ("view" , "model" ));
314
398
profileAction .handleRequest (request , channel , client );
315
399
ArgumentCaptor <MLProfileRequest > argumentCaptor = ArgumentCaptor .forClass (MLProfileRequest .class );
316
400
verify (client , times (1 )).execute (eq (MLProfileAction .INSTANCE ), argumentCaptor .capture (), any ());
317
401
}
318
402
403
+ // public void testNodeViewOutput() throws Exception {
404
+ // // Assuming setup for non-empty node responses as done in the initial setup
405
+ // MLProfileInput mlProfileInput = new MLProfileInput();
406
+ // RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "node"));
407
+ // profileAction.handleRequest(request, channel, client);
408
+ //
409
+ // ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
410
+ // verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
411
+ //
412
+ // // Verify that the response is correctly formed for the node view
413
+ // verify(channel).sendResponse(argThat(response -> {
414
+ // // Ensure the response content matches expected node view structure
415
+ // String content = response.content().utf8ToString();
416
+ // return content.contains("\"node\":") && !content.contains("\"models\":");
417
+ // }));
418
+ // }
419
+
420
+ public void testBackendFailureHandling () throws Exception {
421
+ doAnswer (invocation -> {
422
+ ActionListener <SearchResponse > listener = invocation .getArgument (1 );
423
+ SearchResponse response = createSearchModelResponse (); // Prepare your mocked response here
424
+ listener .onResponse (response );
425
+ return null ;
426
+ }).when (client ).search (any (SearchRequest .class ), any ());
427
+
428
+ doAnswer (invocation -> {
429
+ ActionListener <MLProfileResponse > listener = invocation .getArgument (2 );
430
+ listener .onFailure (new RuntimeException ("Simulated backend failure" ));
431
+ return null ;
432
+ }).when (client ).execute (eq (MLProfileAction .INSTANCE ), any (MLProfileRequest .class ), any (ActionListener .class ));
433
+
434
+ RestRequest request = getRestRequest ();
435
+ profileAction .handleRequest (request , channel , client );
436
+
437
+ verify (channel ).sendResponse (argThat (response -> response .status () == RestStatus .INTERNAL_SERVER_ERROR ));
438
+ }
439
+
319
440
private SearchResponse createSearchModelResponse () throws IOException {
320
441
XContentBuilder content = builder ();
321
442
content .startObject ();
0 commit comments