8
8
import static org .junit .Assert .assertEquals ;
9
9
import static org .junit .Assert .assertFalse ;
10
10
import static org .junit .Assert .assertTrue ;
11
+ import static org .junit .Assert .fail ;
11
12
import static org .mockito .Answers .RETURNS_DEEP_STUBS ;
12
13
import static org .mockito .ArgumentMatchers .any ;
13
14
import static org .mockito .ArgumentMatchers .eq ;
@@ -325,6 +326,64 @@ public void train() {
325
326
assertEquals (status , ((MLTrainingOutput ) argumentCaptor .getValue ()).getStatus ());
326
327
}
327
328
329
+ @ Test
330
+ public void getModel_withTenantId () {
331
+ String modelContent = "test content" ;
332
+ String tenantId = "tenantId" ;
333
+ doAnswer (invocation -> {
334
+ ActionListener <MLModelGetResponse > actionListener = invocation .getArgument (2 );
335
+ MLModel mlModel = MLModel .builder ().algorithm (FunctionName .KMEANS ).name ("test" ).content (modelContent ).build ();
336
+ MLModelGetResponse output = MLModelGetResponse .builder ().mlModel (mlModel ).build ();
337
+ actionListener .onResponse (output );
338
+ return null ;
339
+ }).when (client ).execute (eq (MLModelGetAction .INSTANCE ), any (), any ());
340
+
341
+ ArgumentCaptor <MLModel > argumentCaptor = ArgumentCaptor .forClass (MLModel .class );
342
+ machineLearningNodeClient .getModel ("modelId" , tenantId , getModelActionListener );
343
+
344
+ verify (client ).execute (eq (MLModelGetAction .INSTANCE ), isA (MLModelGetRequest .class ), any ());
345
+ verify (getModelActionListener ).onResponse (argumentCaptor .capture ());
346
+ assertEquals (FunctionName .KMEANS , argumentCaptor .getValue ().getAlgorithm ());
347
+ assertEquals (modelContent , argumentCaptor .getValue ().getContent ());
348
+ }
349
+
350
+ @ Test
351
+ public void undeployModels_withNullNodeIds () {
352
+ doAnswer (invocation -> {
353
+ ActionListener <MLUndeployModelsResponse > actionListener = invocation .getArgument (2 );
354
+ MLUndeployModelsResponse output = new MLUndeployModelsResponse (
355
+ new MLUndeployModelNodesResponse (ClusterName .DEFAULT , Collections .emptyList (), Collections .emptyList ())
356
+ );
357
+ actionListener .onResponse (output );
358
+ return null ;
359
+ }).when (client ).execute (eq (MLUndeployModelsAction .INSTANCE ), any (), any ());
360
+
361
+ machineLearningNodeClient .undeploy (new String [] { "model1" }, null , undeployModelsActionListener );
362
+ verify (client ).execute (eq (MLUndeployModelsAction .INSTANCE ), isA (MLUndeployModelsRequest .class ), any ());
363
+ }
364
+
365
+ @ Test
366
+ public void createConnector_withValidInput () {
367
+ doAnswer (invocation -> {
368
+ ActionListener <MLCreateConnectorResponse > actionListener = invocation .getArgument (2 );
369
+ MLCreateConnectorResponse output = new MLCreateConnectorResponse ("connectorId" );
370
+ actionListener .onResponse (output );
371
+ return null ;
372
+ }).when (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), any (), any ());
373
+
374
+ MLCreateConnectorInput input = MLCreateConnectorInput
375
+ .builder ()
376
+ .name ("testConnector" )
377
+ .protocol ("http" )
378
+ .version ("1" )
379
+ .credential (Map .of ("TEST_CREDENTIAL_KEY" , "TEST_CREDENTIAL_VALUE" ))
380
+ .parameters (Map .of ("endpoint" , "https://example.com" ))
381
+ .build ();
382
+
383
+ machineLearningNodeClient .createConnector (input , createConnectorActionListener );
384
+ verify (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), isA (MLCreateConnectorRequest .class ), any ());
385
+ }
386
+
328
387
@ Test
329
388
public void registerModelGroup_withValidInput () {
330
389
doAnswer (invocation -> {
@@ -346,6 +405,146 @@ public void registerModelGroup_withValidInput() {
346
405
verify (client ).execute (eq (MLRegisterModelGroupAction .INSTANCE ), isA (MLRegisterModelGroupRequest .class ), any ());
347
406
}
348
407
408
+ @ Test
409
+ public void listTools_withValidRequest () {
410
+ doAnswer (invocation -> {
411
+ ActionListener <MLToolsListResponse > actionListener = invocation .getArgument (2 );
412
+ MLToolsListResponse output = MLToolsListResponse
413
+ .builder ()
414
+ .toolMetadata (
415
+ Arrays
416
+ .asList (
417
+ ToolMetadata .builder ().name ("tool1" ).description ("description1" ).build (),
418
+ ToolMetadata .builder ().name ("tool2" ).description ("description2" ).build ()
419
+ )
420
+ )
421
+ .build ();
422
+ actionListener .onResponse (output );
423
+ return null ;
424
+ }).when (client ).execute (eq (MLListToolsAction .INSTANCE ), any (), any ());
425
+
426
+ machineLearningNodeClient .listTools (listToolsActionListener );
427
+ verify (client ).execute (eq (MLListToolsAction .INSTANCE ), isA (MLToolsListRequest .class ), any ());
428
+ }
429
+
430
+ @ Test
431
+ public void listTools_withEmptyResponse () {
432
+ doAnswer (invocation -> {
433
+ ActionListener <MLToolsListResponse > actionListener = invocation .getArgument (2 );
434
+ MLToolsListResponse output = MLToolsListResponse .builder ().toolMetadata (Collections .emptyList ()).build ();
435
+ actionListener .onResponse (output );
436
+ return null ;
437
+ }).when (client ).execute (eq (MLListToolsAction .INSTANCE ), any (), any ());
438
+
439
+ ArgumentCaptor <List <ToolMetadata >> argumentCaptor = ArgumentCaptor .forClass (List .class );
440
+ machineLearningNodeClient .listTools (listToolsActionListener );
441
+
442
+ verify (client ).execute (eq (MLListToolsAction .INSTANCE ), isA (MLToolsListRequest .class ), any ());
443
+ verify (listToolsActionListener ).onResponse (argumentCaptor .capture ());
444
+
445
+ List <ToolMetadata > capturedTools = argumentCaptor .getValue ();
446
+ assertTrue (capturedTools .isEmpty ());
447
+ }
448
+
449
+ @ Test
450
+ public void getTool_withValidToolName () {
451
+ doAnswer (invocation -> {
452
+ ActionListener <MLToolGetResponse > actionListener = invocation .getArgument (2 );
453
+ MLToolGetResponse output = MLToolGetResponse
454
+ .builder ()
455
+ .toolMetadata (ToolMetadata .builder ().name ("tool1" ).description ("description1" ).build ())
456
+ .build ();
457
+ actionListener .onResponse (output );
458
+ return null ;
459
+ }).when (client ).execute (eq (MLGetToolAction .INSTANCE ), any (), any ());
460
+
461
+ machineLearningNodeClient .getTool ("tool1" , getToolActionListener );
462
+ verify (client ).execute (eq (MLGetToolAction .INSTANCE ), isA (MLToolGetRequest .class ), any ());
463
+ }
464
+
465
+ @ Test
466
+ public void getTool_withValidRequest () {
467
+ ToolMetadata toolMetadata = ToolMetadata
468
+ .builder ()
469
+ .name ("MathTool" )
470
+ .description ("Use this tool to calculate any math problem." )
471
+ .build ();
472
+
473
+ doAnswer (invocation -> {
474
+ ActionListener <MLToolGetResponse > actionListener = invocation .getArgument (2 );
475
+ MLToolGetResponse output = MLToolGetResponse .builder ().toolMetadata (toolMetadata ).build ();
476
+ actionListener .onResponse (output );
477
+ return null ;
478
+ }).when (client ).execute (eq (MLGetToolAction .INSTANCE ), any (), any ());
479
+
480
+ ArgumentCaptor <ToolMetadata > argumentCaptor = ArgumentCaptor .forClass (ToolMetadata .class );
481
+ machineLearningNodeClient .getTool ("MathTool" , getToolActionListener );
482
+
483
+ verify (client ).execute (eq (MLGetToolAction .INSTANCE ), isA (MLToolGetRequest .class ), any ());
484
+ verify (getToolActionListener ).onResponse (argumentCaptor .capture ());
485
+
486
+ ToolMetadata capturedTool = argumentCaptor .getValue ();
487
+ assertEquals ("MathTool" , capturedTool .getName ());
488
+ assertEquals ("Use this tool to calculate any math problem." , capturedTool .getDescription ());
489
+ }
490
+
491
+ @ Test
492
+ public void getTool_withFailureResponse () {
493
+ doAnswer (invocation -> {
494
+ ActionListener <MLToolGetResponse > actionListener = invocation .getArgument (2 );
495
+ actionListener .onFailure (new RuntimeException ("Test exception" ));
496
+ return null ;
497
+ }).when (client ).execute (eq (MLGetToolAction .INSTANCE ), any (), any ());
498
+
499
+ machineLearningNodeClient .getTool ("MathTool" , new ActionListener <>() {
500
+ @ Override
501
+ public void onResponse (ToolMetadata toolMetadata ) {
502
+ fail ("Expected failure but got response" );
503
+ }
504
+
505
+ @ Override
506
+ public void onFailure (Exception e ) {
507
+ assertEquals ("Test exception" , e .getMessage ());
508
+ }
509
+ });
510
+
511
+ verify (client ).execute (eq (MLGetToolAction .INSTANCE ), isA (MLToolGetRequest .class ), any ());
512
+ }
513
+
514
+ @ Test
515
+ public void train_withAsync () {
516
+ doAnswer (invocation -> {
517
+ ActionListener <MLTaskResponse > actionListener = invocation .getArgument (2 );
518
+ MLTrainingOutput output = MLTrainingOutput .builder ().status ("InProgress" ).modelId ("modelId" ).build ();
519
+ actionListener .onResponse (MLTaskResponse .builder ().output (output ).build ());
520
+ return null ;
521
+ }).when (client ).execute (eq (MLTrainingTaskAction .INSTANCE ), any (), any ());
522
+
523
+ MLInput mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).inputDataset (input ).build ();
524
+ machineLearningNodeClient .train (mlInput , true , trainingActionListener );
525
+ verify (client ).execute (eq (MLTrainingTaskAction .INSTANCE ), isA (MLTrainingTaskRequest .class ), any ());
526
+ }
527
+
528
+ @ Test
529
+ public void deleteModel_withTenantId () {
530
+ String modelId = "testModelId" ;
531
+ String tenantId = "tenantId" ;
532
+ doAnswer (invocation -> {
533
+ ActionListener <DeleteResponse > actionListener = invocation .getArgument (2 );
534
+ ShardId shardId = new ShardId (new Index ("indexName" , "uuid" ), 1 );
535
+ DeleteResponse output = new DeleteResponse (shardId , modelId , 1 , 1 , 1 , true );
536
+ actionListener .onResponse (output );
537
+ return null ;
538
+ }).when (client ).execute (eq (MLModelDeleteAction .INSTANCE ), any (), any ());
539
+
540
+ ArgumentCaptor <DeleteResponse > argumentCaptor = ArgumentCaptor .forClass (DeleteResponse .class );
541
+ machineLearningNodeClient .deleteModel (modelId , tenantId , deleteModelActionListener );
542
+
543
+ verify (client ).execute (eq (MLModelDeleteAction .INSTANCE ), isA (MLModelDeleteRequest .class ), any ());
544
+ verify (deleteModelActionListener ).onResponse (argumentCaptor .capture ());
545
+ assertEquals (modelId , argumentCaptor .getValue ().getId ());
546
+ }
547
+
349
548
@ Test
350
549
public void train_Exception_WithNullDataSet () {
351
550
exceptionRule .expect (IllegalArgumentException .class );
0 commit comments