Skip to content

Commit d265fc7

Browse files
authored
Added create connector API for MLClient (opensearch-project#1437)
* Adds create connector API for MLClient Signed-off-by: Owais Kazi <owaiskazi19@gmail.com> * Addressed PR comments Signed-off-by: Owais Kazi <owaiskazi19@gmail.com> * Addressed PR Comments Signed-off-by: Owais Kazi <owaiskazi19@gmail.com> --------- Signed-off-by: Owais Kazi <owaiskazi19@gmail.com>
1 parent 89f9b85 commit d265fc7

File tree

4 files changed

+115
-7
lines changed

4 files changed

+115
-7
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+15
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.opensearch.ml.common.MLTask;
1717
import org.opensearch.ml.common.input.MLInput;
1818
import org.opensearch.ml.common.output.MLOutput;
19+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
20+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
1921
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
2022
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
2123
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
@@ -267,4 +269,17 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
267269
* @param listener a listener to be notified of the result
268270
*/
269271
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);
272+
273+
/**
274+
* Create connector for remote model
275+
* @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/
276+
* @return the result future
277+
*/
278+
default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnectorInput mlCreateConnectorInput) {
279+
PlainActionFuture<MLCreateConnectorResponse> actionFuture = PlainActionFuture.newFuture();
280+
createConnector(mlCreateConnectorInput, actionFuture);
281+
return actionFuture;
282+
}
283+
284+
void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);
270285
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
3333
import org.opensearch.ml.common.output.MLOutput;
3434
import org.opensearch.ml.common.transport.MLTaskResponse;
35+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
36+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
37+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
38+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
3539
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
3640
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
3741
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
@@ -228,6 +232,12 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
228232
}));
229233
}
230234

235+
@Override
236+
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
237+
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
238+
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener);
239+
}
240+
231241
private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
232242
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
233243
listener.onResponse(predictionResponse.getOutput());

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+33
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.action.delete.DeleteResponse;
1616
import org.opensearch.action.search.SearchRequest;
1717
import org.opensearch.action.search.SearchResponse;
18+
import org.opensearch.ml.common.AccessMode;
1819
import org.opensearch.ml.common.dataframe.DataFrame;
1920
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
2021
import org.opensearch.ml.common.input.MLInput;
@@ -27,6 +28,8 @@
2728
import org.opensearch.ml.common.output.MLOutput;
2829
import org.opensearch.ml.common.MLTask;
2930
import org.opensearch.ml.common.output.MLTrainingOutput;
31+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
32+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
3033
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
3134
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
3235
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
@@ -74,6 +77,9 @@ public class MachineLearningClientTest {
7477
@Mock
7578
MLDeployModelResponse deployModelResponse;
7679

80+
@Mock
81+
MLCreateConnectorResponse createConnectorResponse;
82+
7783
private String modekId = "test_model_id";
7884
private MLModel mlModel;
7985
private MLTask mlTask;
@@ -158,6 +164,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
158164
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
159165
listener.onResponse(deployModelResponse);
160166
}
167+
168+
@Override
169+
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
170+
listener.onResponse(createConnectorResponse);
171+
}
161172
};
162173
}
163174

@@ -304,4 +315,26 @@ public void register() {
304315
public void deploy() {
305316
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
306317
}
318+
319+
@Test
320+
public void createConnector() {
321+
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
322+
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "key1"), Map.entry("key2", "key2"));
323+
324+
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder()
325+
.name("test")
326+
.description("description")
327+
.version("testModelVersion")
328+
.protocol("testProtocol")
329+
.parameters(params)
330+
.credential(credentials)
331+
.actions(null)
332+
.backendRoles(null)
333+
.addAllBackendRoles(false)
334+
.access(AccessMode.from("private"))
335+
.dryRun(false)
336+
.build();
337+
338+
assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet());
339+
}
307340
}

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+57-7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.ml.common.MLModel;
3232
import org.opensearch.ml.common.MLTask;
3333
import org.opensearch.ml.common.MLTaskState;
34+
import org.opensearch.ml.common.AccessMode;
3435
import org.opensearch.ml.common.MLTaskType;
3536
import org.opensearch.ml.common.dataframe.DataFrame;
3637
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -42,6 +43,10 @@
4243
import org.opensearch.ml.common.output.MLPredictionOutput;
4344
import org.opensearch.ml.common.output.MLTrainingOutput;
4445
import org.opensearch.ml.common.transport.MLTaskResponse;
46+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
47+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
48+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
49+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
4550
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
4651
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
4752
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
@@ -77,6 +82,8 @@
7782
import java.util.Collections;
7883
import java.util.HashMap;
7984
import java.util.Map;
85+
import java.util.Arrays;
86+
import java.util.List;
8087

8188
import static org.junit.Assert.assertEquals;
8289
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
@@ -121,10 +128,13 @@ public class MachineLearningNodeClientTest {
121128
ActionListener<SearchResponse> searchTaskActionListener;
122129

123130
@Mock
124-
ActionListener<MLRegisterModelResponse> RegisterModelActionListener;
131+
ActionListener<MLRegisterModelResponse> registerModelActionListener;
125132

126133
@Mock
127-
ActionListener<MLDeployModelResponse> DeployModelActionListener;
134+
ActionListener<MLDeployModelResponse> deployModelActionListener;
135+
136+
@Mock
137+
ActionListener<MLCreateConnectorResponse> createConnectorActionListener;
128138

129139
@InjectMocks
130140
MachineLearningNodeClient machineLearningNodeClient;
@@ -601,10 +611,10 @@ public void register() {
601611
.deployModel(true)
602612
.modelNodeIds(new String[]{"modelNodeIds" })
603613
.build();
604-
machineLearningNodeClient.register(mlInput, RegisterModelActionListener);
614+
machineLearningNodeClient.register(mlInput, registerModelActionListener);
605615

606616
verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any());
607-
verify(RegisterModelActionListener).onResponse(argumentCaptor.capture());
617+
verify(registerModelActionListener).onResponse(argumentCaptor.capture());
608618
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
609619
assertEquals(status, (argumentCaptor.getValue()).getStatus());
610620
}
@@ -615,7 +625,6 @@ public void deploy() {
615625
String status = MLTaskState.CREATED.name();
616626
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;
617627
String modelId = "modelId";
618-
FunctionName functionName = FunctionName.KMEANS;
619628
doAnswer(invocation -> {
620629
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
621630
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
@@ -624,14 +633,55 @@ public void deploy() {
624633
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());
625634

626635
ArgumentCaptor<MLDeployModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class);
627-
machineLearningNodeClient.deploy(modelId, DeployModelActionListener);
636+
machineLearningNodeClient.deploy(modelId, deployModelActionListener);
628637

629638
verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any());
630-
verify(DeployModelActionListener).onResponse(argumentCaptor.capture());
639+
verify(deployModelActionListener).onResponse(argumentCaptor.capture());
631640
assertEquals(taskId, (argumentCaptor.getValue()).getTaskId());
632641
assertEquals(status, (argumentCaptor.getValue()).getStatus());
633642
}
634643

644+
@Test
645+
public void createConnector() {
646+
647+
648+
String connectorId = "connectorId";
649+
650+
doAnswer(invocation -> {
651+
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
652+
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId);
653+
actionListener.onResponse(output);
654+
return null;
655+
}).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any());
656+
657+
ArgumentCaptor<MLCreateConnectorResponse> argumentCaptor = ArgumentCaptor.forClass(MLCreateConnectorResponse.class);
658+
659+
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
660+
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2"));
661+
List<String> backendRoles = Arrays.asList("IT", "HR");
662+
663+
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder()
664+
.name("test")
665+
.description("description")
666+
.version("testModelVersion")
667+
.protocol("testProtocol")
668+
.parameters(params)
669+
.credential(credentials)
670+
.actions(null)
671+
.backendRoles(backendRoles)
672+
.addAllBackendRoles(false)
673+
.access(AccessMode.from("private"))
674+
.dryRun(false)
675+
.build();
676+
677+
machineLearningNodeClient.createConnector(mlCreateConnectorInput, createConnectorActionListener);
678+
679+
verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any());
680+
verify(createConnectorActionListener).onResponse(argumentCaptor.capture());
681+
assertEquals(connectorId, (argumentCaptor.getValue()).getConnectorId());
682+
683+
}
684+
635685
private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
636686
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
637687

0 commit comments

Comments
 (0)