Skip to content

Commit 4ea2cdd

Browse files
committed
adding multi-tenancy + sdk client related changes to model, model group and connector update (#3399)
* adding multi-tenancy + sdk client related changes to model, model group and connector update Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed more comments + refactored few codes Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent eb225b9 commit 4ea2cdd

File tree

92 files changed

+5372
-1960
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+5372
-1960
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+26-13
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,18 @@ default ActionFuture<MLModel> getModel(String modelId) {
142142
* @param modelId id of the model
143143
* @param listener action listener
144144
*/
145-
void getModel(String modelId, ActionListener<MLModel> listener);
145+
default void getModel(String modelId, ActionListener<MLModel> listener) {
146+
getModel(modelId, null, listener);
147+
}
148+
149+
/**
150+
* Get MLModel and return model in listener
151+
* For more info on get model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-model-information
152+
* @param modelId id of the model
153+
* @param tenantId id of the tenant
154+
* @param listener action listener
155+
*/
156+
void getModel(String modelId, String tenantId, ActionListener<MLModel> listener);
146157

147158
/**
148159
* Get MLTask and return ActionFuture.
@@ -182,7 +193,18 @@ default ActionFuture<DeleteResponse> deleteModel(String modelId) {
182193
* @param modelId id of the model
183194
* @param listener action listener
184195
*/
185-
void deleteModel(String modelId, ActionListener<DeleteResponse> listener);
196+
default void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
197+
deleteModel(modelId, null, listener);
198+
}
199+
200+
/**
201+
* Delete MLModel
202+
* For more info on delete model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#delete-model
203+
* @param modelId id of the model
204+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
205+
* @param listener action listener
206+
*/
207+
void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener);
186208

187209
/**
188210
* Delete the task with taskId.
@@ -323,19 +345,10 @@ default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
323345
return actionFuture;
324346
}
325347

326-
/**
327-
* Delete connector for remote model
328-
* @param connectorId The id of the connector to delete
329-
* @return the result future
330-
*/
331-
default ActionFuture<DeleteResponse> deleteConnector(String connectorId, String tenantId) {
332-
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
333-
deleteConnector(connectorId, tenantId, actionFuture);
334-
return actionFuture;
348+
default void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
349+
deleteConnector(connectorId, null, listener);
335350
}
336351

337-
void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);
338-
339352
void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener);
340353

341354
/**

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+4-15
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
164164
}
165165

166166
@Override
167-
public void getModel(String modelId, ActionListener<MLModel> listener) {
168-
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();
167+
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
168+
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).tenantId(tenantId).build();
169169

170170
client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
171171
}
@@ -178,8 +178,8 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
178178
}
179179

180180
@Override
181-
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
182-
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
181+
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
182+
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).tenantId(tenantId).build();
183183

184184
client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
185185
}
@@ -259,17 +259,6 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
259259
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
260260
}
261261

262-
@Override
263-
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
264-
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
265-
client
266-
.execute(
267-
MLConnectorDeleteAction.INSTANCE,
268-
connectorDeleteRequest,
269-
ActionListener.wrap(listener::onResponse, listener::onFailure)
270-
);
271-
}
272-
273262
@Override
274263
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
275264
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId);

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+34
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ public class MachineLearningClientTest {
7676
@Mock
7777
ActionListener<MLOutput> dataFrameActionListener;
7878

79+
@Mock
80+
ActionListener<MLModel> mlModelActionListener;
81+
7982
@Mock
8083
DeleteResponse deleteResponse;
8184

@@ -166,11 +169,21 @@ public void getModel(String modelId, ActionListener<MLModel> listener) {
166169
listener.onResponse(mlModel);
167170
}
168171

172+
@Override
173+
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
174+
listener.onResponse(mlModel);
175+
}
176+
169177
@Override
170178
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
171179
listener.onResponse(deleteResponse);
172180
}
173181

182+
@Override
183+
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
184+
listener.onResponse(deleteResponse);
185+
}
186+
174187
@Override
175188
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
176189
listener.onResponse(searchResponse);
@@ -352,6 +365,22 @@ public void getModel() {
352365
assertEquals(mlModel, machineLearningClient.getModel("modelId").actionGet());
353366
}
354367

368+
@Test
369+
public void getModelActionListener() {
370+
ArgumentCaptor<MLModel> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLModel.class);
371+
machineLearningClient.getModel("modelId", mlModelActionListener);
372+
verify(mlModelActionListener).onResponse(dataFrameArgumentCaptor.capture());
373+
assertEquals(mlModel, dataFrameArgumentCaptor.getValue());
374+
assertEquals(mlModel.getTenantId(), dataFrameArgumentCaptor.getValue().getTenantId());
375+
}
376+
377+
@Test
378+
public void undeploy_WithSpecificNodes() {
379+
String[] modelIds = new String[] { "model1", "model2" };
380+
String[] nodeIds = new String[] { "node1", "node2" };
381+
assertEquals(undeployModelsResponse, machineLearningClient.undeploy(modelIds, nodeIds).actionGet());
382+
}
383+
355384
@Test
356385
public void deleteModel() {
357386
assertEquals(deleteResponse, machineLearningClient.deleteModel("modelId").actionGet());
@@ -362,6 +391,11 @@ public void searchModel() {
362391
assertEquals(searchResponse, machineLearningClient.searchModel(new SearchRequest()).actionGet());
363392
}
364393

394+
@Test
395+
public void deleteConnector_WithTenantId() {
396+
assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet());
397+
}
398+
365399
@Test
366400
public void registerModelGroup() {
367401
List<String> backendRoles = Arrays.asList("IT", "HR");

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+199
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertFalse;
1010
import static org.junit.Assert.assertTrue;
11+
import static org.junit.Assert.fail;
1112
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
1213
import static org.mockito.ArgumentMatchers.any;
1314
import static org.mockito.ArgumentMatchers.eq;
@@ -325,6 +326,64 @@ public void train() {
325326
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
326327
}
327328

329+
@Test
330+
public void getModel_withTenantId() {
331+
String modelContent = "test content";
332+
String tenantId = "tenantId";
333+
doAnswer(invocation -> {
334+
ActionListener<MLModelGetResponse> actionListener = invocation.getArgument(2);
335+
MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build();
336+
MLModelGetResponse output = MLModelGetResponse.builder().mlModel(mlModel).build();
337+
actionListener.onResponse(output);
338+
return null;
339+
}).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any());
340+
341+
ArgumentCaptor<MLModel> argumentCaptor = ArgumentCaptor.forClass(MLModel.class);
342+
machineLearningNodeClient.getModel("modelId", tenantId, getModelActionListener);
343+
344+
verify(client).execute(eq(MLModelGetAction.INSTANCE), isA(MLModelGetRequest.class), any());
345+
verify(getModelActionListener).onResponse(argumentCaptor.capture());
346+
assertEquals(FunctionName.KMEANS, argumentCaptor.getValue().getAlgorithm());
347+
assertEquals(modelContent, argumentCaptor.getValue().getContent());
348+
}
349+
350+
@Test
351+
public void undeployModels_withNullNodeIds() {
352+
doAnswer(invocation -> {
353+
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
354+
MLUndeployModelsResponse output = new MLUndeployModelsResponse(
355+
new MLUndeployModelNodesResponse(ClusterName.DEFAULT, Collections.emptyList(), Collections.emptyList())
356+
);
357+
actionListener.onResponse(output);
358+
return null;
359+
}).when(client).execute(eq(MLUndeployModelsAction.INSTANCE), any(), any());
360+
361+
machineLearningNodeClient.undeploy(new String[] { "model1" }, null, undeployModelsActionListener);
362+
verify(client).execute(eq(MLUndeployModelsAction.INSTANCE), isA(MLUndeployModelsRequest.class), any());
363+
}
364+
365+
@Test
366+
public void createConnector_withValidInput() {
367+
doAnswer(invocation -> {
368+
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
369+
MLCreateConnectorResponse output = new MLCreateConnectorResponse("connectorId");
370+
actionListener.onResponse(output);
371+
return null;
372+
}).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any());
373+
374+
MLCreateConnectorInput input = MLCreateConnectorInput
375+
.builder()
376+
.name("testConnector")
377+
.protocol("http")
378+
.version("1")
379+
.credential(Map.of("TEST_CREDENTIAL_KEY", "TEST_CREDENTIAL_VALUE"))
380+
.parameters(Map.of("endpoint", "https://example.com"))
381+
.build();
382+
383+
machineLearningNodeClient.createConnector(input, createConnectorActionListener);
384+
verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any());
385+
}
386+
328387
@Test
329388
public void registerModelGroup_withValidInput() {
330389
doAnswer(invocation -> {
@@ -346,6 +405,146 @@ public void registerModelGroup_withValidInput() {
346405
verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any());
347406
}
348407

408+
@Test
409+
public void listTools_withValidRequest() {
410+
doAnswer(invocation -> {
411+
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
412+
MLToolsListResponse output = MLToolsListResponse
413+
.builder()
414+
.toolMetadata(
415+
Arrays
416+
.asList(
417+
ToolMetadata.builder().name("tool1").description("description1").build(),
418+
ToolMetadata.builder().name("tool2").description("description2").build()
419+
)
420+
)
421+
.build();
422+
actionListener.onResponse(output);
423+
return null;
424+
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());
425+
426+
machineLearningNodeClient.listTools(listToolsActionListener);
427+
verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
428+
}
429+
430+
@Test
431+
public void listTools_withEmptyResponse() {
432+
doAnswer(invocation -> {
433+
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
434+
MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(Collections.emptyList()).build();
435+
actionListener.onResponse(output);
436+
return null;
437+
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());
438+
439+
ArgumentCaptor<List<ToolMetadata>> argumentCaptor = ArgumentCaptor.forClass(List.class);
440+
machineLearningNodeClient.listTools(listToolsActionListener);
441+
442+
verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
443+
verify(listToolsActionListener).onResponse(argumentCaptor.capture());
444+
445+
List<ToolMetadata> capturedTools = argumentCaptor.getValue();
446+
assertTrue(capturedTools.isEmpty());
447+
}
448+
449+
@Test
450+
public void getTool_withValidToolName() {
451+
doAnswer(invocation -> {
452+
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
453+
MLToolGetResponse output = MLToolGetResponse
454+
.builder()
455+
.toolMetadata(ToolMetadata.builder().name("tool1").description("description1").build())
456+
.build();
457+
actionListener.onResponse(output);
458+
return null;
459+
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());
460+
461+
machineLearningNodeClient.getTool("tool1", getToolActionListener);
462+
verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
463+
}
464+
465+
@Test
466+
public void getTool_withValidRequest() {
467+
ToolMetadata toolMetadata = ToolMetadata
468+
.builder()
469+
.name("MathTool")
470+
.description("Use this tool to calculate any math problem.")
471+
.build();
472+
473+
doAnswer(invocation -> {
474+
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
475+
MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build();
476+
actionListener.onResponse(output);
477+
return null;
478+
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());
479+
480+
ArgumentCaptor<ToolMetadata> argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class);
481+
machineLearningNodeClient.getTool("MathTool", getToolActionListener);
482+
483+
verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
484+
verify(getToolActionListener).onResponse(argumentCaptor.capture());
485+
486+
ToolMetadata capturedTool = argumentCaptor.getValue();
487+
assertEquals("MathTool", capturedTool.getName());
488+
assertEquals("Use this tool to calculate any math problem.", capturedTool.getDescription());
489+
}
490+
491+
@Test
492+
public void getTool_withFailureResponse() {
493+
doAnswer(invocation -> {
494+
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
495+
actionListener.onFailure(new RuntimeException("Test exception"));
496+
return null;
497+
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());
498+
499+
machineLearningNodeClient.getTool("MathTool", new ActionListener<>() {
500+
@Override
501+
public void onResponse(ToolMetadata toolMetadata) {
502+
fail("Expected failure but got response");
503+
}
504+
505+
@Override
506+
public void onFailure(Exception e) {
507+
assertEquals("Test exception", e.getMessage());
508+
}
509+
});
510+
511+
verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
512+
}
513+
514+
@Test
515+
public void train_withAsync() {
516+
doAnswer(invocation -> {
517+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
518+
MLTrainingOutput output = MLTrainingOutput.builder().status("InProgress").modelId("modelId").build();
519+
actionListener.onResponse(MLTaskResponse.builder().output(output).build());
520+
return null;
521+
}).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());
522+
523+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
524+
machineLearningNodeClient.train(mlInput, true, trainingActionListener);
525+
verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any());
526+
}
527+
528+
@Test
529+
public void deleteModel_withTenantId() {
530+
String modelId = "testModelId";
531+
String tenantId = "tenantId";
532+
doAnswer(invocation -> {
533+
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
534+
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
535+
DeleteResponse output = new DeleteResponse(shardId, modelId, 1, 1, 1, true);
536+
actionListener.onResponse(output);
537+
return null;
538+
}).when(client).execute(eq(MLModelDeleteAction.INSTANCE), any(), any());
539+
540+
ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
541+
machineLearningNodeClient.deleteModel(modelId, tenantId, deleteModelActionListener);
542+
543+
verify(client).execute(eq(MLModelDeleteAction.INSTANCE), isA(MLModelDeleteRequest.class), any());
544+
verify(deleteModelActionListener).onResponse(argumentCaptor.capture());
545+
assertEquals(modelId, argumentCaptor.getValue().getId());
546+
}
547+
349548
@Test
350549
public void train_Exception_WithNullDataSet() {
351550
exceptionRule.expect(IllegalArgumentException.class);

0 commit comments

Comments
 (0)