31
31
import org .opensearch .ml .common .MLModel ;
32
32
import org .opensearch .ml .common .MLTask ;
33
33
import org .opensearch .ml .common .MLTaskState ;
34
+ import org .opensearch .ml .common .AccessMode ;
34
35
import org .opensearch .ml .common .MLTaskType ;
35
36
import org .opensearch .ml .common .dataframe .DataFrame ;
36
37
import org .opensearch .ml .common .dataset .MLInputDataset ;
42
43
import org .opensearch .ml .common .output .MLPredictionOutput ;
43
44
import org .opensearch .ml .common .output .MLTrainingOutput ;
44
45
import org .opensearch .ml .common .transport .MLTaskResponse ;
46
+ import org .opensearch .ml .common .transport .connector .MLCreateConnectorAction ;
47
+ import org .opensearch .ml .common .transport .connector .MLCreateConnectorInput ;
48
+ import org .opensearch .ml .common .transport .connector .MLCreateConnectorRequest ;
49
+ import org .opensearch .ml .common .transport .connector .MLCreateConnectorResponse ;
45
50
import org .opensearch .ml .common .transport .deploy .MLDeployModelAction ;
46
51
import org .opensearch .ml .common .transport .deploy .MLDeployModelRequest ;
47
52
import org .opensearch .ml .common .transport .deploy .MLDeployModelResponse ;
77
82
import java .util .Collections ;
78
83
import java .util .HashMap ;
79
84
import java .util .Map ;
85
+ import java .util .Arrays ;
86
+ import java .util .List ;
80
87
81
88
import static org .junit .Assert .assertEquals ;
82
89
import static org .mockito .Answers .RETURNS_DEEP_STUBS ;
@@ -121,10 +128,13 @@ public class MachineLearningNodeClientTest {
121
128
ActionListener <SearchResponse > searchTaskActionListener ;
122
129
123
130
@ Mock
124
- ActionListener <MLRegisterModelResponse > RegisterModelActionListener ;
131
+ ActionListener <MLRegisterModelResponse > registerModelActionListener ;
125
132
126
133
@ Mock
127
- ActionListener <MLDeployModelResponse > DeployModelActionListener ;
134
+ ActionListener <MLDeployModelResponse > deployModelActionListener ;
135
+
136
+ @ Mock
137
+ ActionListener <MLCreateConnectorResponse > createConnectorActionListener ;
128
138
129
139
@ InjectMocks
130
140
MachineLearningNodeClient machineLearningNodeClient ;
@@ -601,10 +611,10 @@ public void register() {
601
611
.deployModel (true )
602
612
.modelNodeIds (new String []{"modelNodeIds" })
603
613
.build ();
604
- machineLearningNodeClient .register (mlInput , RegisterModelActionListener );
614
+ machineLearningNodeClient .register (mlInput , registerModelActionListener );
605
615
606
616
verify (client ).execute (eq (MLRegisterModelAction .INSTANCE ), isA (MLRegisterModelRequest .class ), any ());
607
- verify (RegisterModelActionListener ).onResponse (argumentCaptor .capture ());
617
+ verify (registerModelActionListener ).onResponse (argumentCaptor .capture ());
608
618
assertEquals (taskId , (argumentCaptor .getValue ()).getTaskId ());
609
619
assertEquals (status , (argumentCaptor .getValue ()).getStatus ());
610
620
}
@@ -615,7 +625,6 @@ public void deploy() {
615
625
String status = MLTaskState .CREATED .name ();
616
626
MLTaskType mlTaskType = MLTaskType .DEPLOY_MODEL ;
617
627
String modelId = "modelId" ;
618
- FunctionName functionName = FunctionName .KMEANS ;
619
628
doAnswer (invocation -> {
620
629
ActionListener <MLDeployModelResponse > actionListener = invocation .getArgument (2 );
621
630
MLDeployModelResponse output = new MLDeployModelResponse (taskId , mlTaskType , status );
@@ -624,14 +633,55 @@ public void deploy() {
624
633
}).when (client ).execute (eq (MLDeployModelAction .INSTANCE ), any (), any ());
625
634
626
635
ArgumentCaptor <MLDeployModelResponse > argumentCaptor = ArgumentCaptor .forClass (MLDeployModelResponse .class );
627
- machineLearningNodeClient .deploy (modelId , DeployModelActionListener );
636
+ machineLearningNodeClient .deploy (modelId , deployModelActionListener );
628
637
629
638
verify (client ).execute (eq (MLDeployModelAction .INSTANCE ), isA (MLDeployModelRequest .class ), any ());
630
- verify (DeployModelActionListener ).onResponse (argumentCaptor .capture ());
639
+ verify (deployModelActionListener ).onResponse (argumentCaptor .capture ());
631
640
assertEquals (taskId , (argumentCaptor .getValue ()).getTaskId ());
632
641
assertEquals (status , (argumentCaptor .getValue ()).getStatus ());
633
642
}
634
643
644
+ @ Test
645
+ public void createConnector () {
646
+
647
+
648
+ String connectorId = "connectorId" ;
649
+
650
+ doAnswer (invocation -> {
651
+ ActionListener <MLCreateConnectorResponse > actionListener = invocation .getArgument (2 );
652
+ MLCreateConnectorResponse output = new MLCreateConnectorResponse (connectorId );
653
+ actionListener .onResponse (output );
654
+ return null ;
655
+ }).when (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), any (), any ());
656
+
657
+ ArgumentCaptor <MLCreateConnectorResponse > argumentCaptor = ArgumentCaptor .forClass (MLCreateConnectorResponse .class );
658
+
659
+ Map <String , String > params = Map .ofEntries (Map .entry ("endpoint" , "endpoint" ), Map .entry ("temp" , "7" ));
660
+ Map <String , String > credentials = Map .ofEntries (Map .entry ("key1" , "value1" ), Map .entry ("key2" , "value2" ));
661
+ List <String > backendRoles = Arrays .asList ("IT" , "HR" );
662
+
663
+ MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput .builder ()
664
+ .name ("test" )
665
+ .description ("description" )
666
+ .version ("testModelVersion" )
667
+ .protocol ("testProtocol" )
668
+ .parameters (params )
669
+ .credential (credentials )
670
+ .actions (null )
671
+ .backendRoles (backendRoles )
672
+ .addAllBackendRoles (false )
673
+ .access (AccessMode .from ("private" ))
674
+ .dryRun (false )
675
+ .build ();
676
+
677
+ machineLearningNodeClient .createConnector (mlCreateConnectorInput , createConnectorActionListener );
678
+
679
+ verify (client ).execute (eq (MLCreateConnectorAction .INSTANCE ), isA (MLCreateConnectorRequest .class ), any ());
680
+ verify (createConnectorActionListener ).onResponse (argumentCaptor .capture ());
681
+ assertEquals (connectorId , (argumentCaptor .getValue ()).getConnectorId ());
682
+
683
+ }
684
+
635
685
private SearchResponse createSearchResponse (ToXContentObject o ) throws IOException {
636
686
XContentBuilder content = o .toXContent (XContentFactory .jsonBuilder (), ToXContent .EMPTY_PARAMS );
637
687
0 commit comments