Skip to content

Commit 1693596

Browse files
apply multi-tenancy and sdk client in Connector (Create + Get + Delete) (opensearch-project#3382) (opensearch-project#3385)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit bcb00d1) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 7f42329 commit 1693596

28 files changed

+1484
-500
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+13
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,21 @@ default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
323323
return actionFuture;
324324
}
325325

326+
/**
327+
* Delete connector for remote model
328+
* @param connectorId The id of the connector to delete
329+
* @return the result future
330+
*/
331+
default ActionFuture<DeleteResponse> deleteConnector(String connectorId, String tenantId) {
332+
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
333+
deleteConnector(connectorId, tenantId, actionFuture);
334+
return actionFuture;
335+
}
336+
326337
void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);
327338

339+
void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener);
340+
328341
/**
329342
* Register model group
330343
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+37-88
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
146146
mlInput.setParameters(mlAlgoParams);
147147
switch (action) {
148148
case TRAIN:
149-
boolean asyncTask = args.containsKey(ASYNC) ? (boolean) args.get(ASYNC) : false;
149+
boolean asyncTask = args.containsKey(ASYNC) && (boolean) args.get(ASYNC);
150150
train(mlInput, asyncTask, listener);
151151
break;
152152
case PREDICT:
@@ -174,30 +174,19 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
174174
ActionListener<MLModelGetResponse> internalListener = ActionListener.wrap(predictionResponse -> {
175175
listener.onResponse(predictionResponse.getMlModel());
176176
}, listener::onFailure);
177-
ActionListener<MLModelGetResponse> actionListener = wrapActionListener(internalListener, res -> {
178-
MLModelGetResponse getResponse = MLModelGetResponse.fromActionResponse(res);
179-
return getResponse;
180-
});
181-
return actionListener;
177+
return wrapActionListener(internalListener, MLModelGetResponse::fromActionResponse);
182178
}
183179

184180
@Override
185181
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
186182
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
187183

188-
client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(deleteResponse -> {
189-
listener.onResponse(deleteResponse);
190-
}, listener::onFailure));
184+
client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
191185
}
192186

193187
@Override
194188
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
195-
client
196-
.execute(
197-
MLModelSearchAction.INSTANCE,
198-
searchRequest,
199-
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
200-
);
189+
client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
201190
}
202191

203192
@Override
@@ -238,19 +227,12 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
238227
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
239228
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();
240229

241-
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> {
242-
listener.onResponse(deleteResponse);
243-
}, listener::onFailure));
230+
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
244231
}
245232

246233
@Override
247234
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
248-
client
249-
.execute(
250-
MLTaskSearchAction.INSTANCE,
251-
searchRequest,
252-
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
253-
);
235+
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
254236
}
255237

256238
@Override
@@ -280,9 +262,23 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
280262
@Override
281263
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
282264
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
283-
client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> {
284-
listener.onResponse(deleteResponse);
285-
}, listener::onFailure));
265+
client
266+
.execute(
267+
MLConnectorDeleteAction.INSTANCE,
268+
connectorDeleteRequest,
269+
ActionListener.wrap(listener::onResponse, listener::onFailure)
270+
);
271+
}
272+
273+
@Override
274+
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
275+
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId);
276+
client
277+
.execute(
278+
MLConnectorDeleteAction.INSTANCE,
279+
connectorDeleteRequest,
280+
ActionListener.wrap(listener::onResponse, listener::onFailure)
281+
);
286282
}
287283

288284
@Override
@@ -294,9 +290,7 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
294290
@Override
295291
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
296292
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
297-
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> {
298-
listener.onResponse(deleteResponse);
299-
}, listener::onFailure));
293+
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
300294
}
301295

302296
@Override
@@ -324,123 +318,78 @@ private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener
324318
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
325319
listener.onResponse(mlModelListResponse.getToolMetadataList());
326320
}, listener::onFailure);
327-
ActionListener<MLToolsListResponse> actionListener = wrapActionListener(internalListener, res -> {
328-
MLToolsListResponse getResponse = MLToolsListResponse.fromActionResponse(res);
329-
return getResponse;
330-
});
331-
return actionListener;
321+
return wrapActionListener(internalListener, MLToolsListResponse::fromActionResponse);
332322
}
333323

334324
private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(ActionListener<ToolMetadata> listener) {
335325
ActionListener<MLToolGetResponse> internalListener = ActionListener.wrap(mlModelGetResponse -> {
336326
listener.onResponse(mlModelGetResponse.getToolMetadata());
337327
}, listener::onFailure);
338-
ActionListener<MLToolGetResponse> actionListener = wrapActionListener(internalListener, res -> {
339-
MLToolGetResponse getResponse = MLToolGetResponse.fromActionResponse(res);
340-
return getResponse;
341-
});
342-
return actionListener;
328+
return wrapActionListener(internalListener, MLToolGetResponse::fromActionResponse);
343329
}
344330

345331
private ActionListener<MLConfigGetResponse> getMlGetConfigResponseActionListener(ActionListener<MLConfig> listener) {
346332
ActionListener<MLConfigGetResponse> internalListener = ActionListener.wrap(mlConfigGetResponse -> {
347333
listener.onResponse(mlConfigGetResponse.getMlConfig());
348334
}, listener::onFailure);
349-
ActionListener<MLConfigGetResponse> actionListener = wrapActionListener(internalListener, res -> {
350-
MLConfigGetResponse getResponse = MLConfigGetResponse.fromActionResponse(res);
351-
return getResponse;
352-
});
353-
return actionListener;
335+
return wrapActionListener(internalListener, MLConfigGetResponse::fromActionResponse);
354336
}
355337

356338
private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
357339
ActionListener<MLRegisterAgentResponse> listener
358340
) {
359-
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
360-
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
361-
return mlRegisterAgentResponse;
362-
});
363-
return actionListener;
341+
return wrapActionListener(listener, MLRegisterAgentResponse::fromActionResponse);
364342
}
365343

366344
private ActionListener<MLTaskGetResponse> getMLTaskResponseActionListener(ActionListener<MLTask> listener) {
367345
ActionListener<MLTaskGetResponse> internalListener = ActionListener
368346
.wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure);
369-
ActionListener<MLTaskGetResponse> actionListener = wrapActionListener(internalListener, response -> {
370-
MLTaskGetResponse getResponse = MLTaskGetResponse.fromActionResponse(response);
371-
return getResponse;
372-
});
373-
return actionListener;
347+
return wrapActionListener(internalListener, MLTaskGetResponse::fromActionResponse);
374348
}
375349

376350
private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionListener(ActionListener<MLDeployModelResponse> listener) {
377-
ActionListener<MLDeployModelResponse> actionListener = wrapActionListener(listener, response -> {
378-
MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response);
379-
return deployModelResponse;
380-
});
381-
return actionListener;
351+
return wrapActionListener(listener, MLDeployModelResponse::fromActionResponse);
382352
}
383353

384354
private ActionListener<MLUndeployModelsResponse> getMlUndeployModelsResponseActionListener(
385355
ActionListener<MLUndeployModelsResponse> listener
386356
) {
387-
ActionListener<MLUndeployModelsResponse> actionListener = wrapActionListener(listener, response -> {
388-
MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response);
389-
return deployModelResponse;
390-
});
391-
return actionListener;
357+
return wrapActionListener(listener, MLUndeployModelsResponse::fromActionResponse);
392358
}
393359

394360
private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
395361
ActionListener<MLCreateConnectorResponse> listener
396362
) {
397-
ActionListener<MLCreateConnectorResponse> actionListener = wrapActionListener(listener, response -> {
398-
MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response);
399-
return createConnectorResponse;
400-
});
401-
return actionListener;
363+
return wrapActionListener(listener, MLCreateConnectorResponse::fromActionResponse);
402364
}
403365

404366
private ActionListener<MLRegisterModelGroupResponse> getMlRegisterModelGroupResponseActionListener(
405367
ActionListener<MLRegisterModelGroupResponse> listener
406368
) {
407-
ActionListener<MLRegisterModelGroupResponse> actionListener = wrapActionListener(listener, response -> {
408-
MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response);
409-
return registerModelGroupResponse;
410-
});
411-
return actionListener;
369+
return wrapActionListener(listener, MLRegisterModelGroupResponse::fromActionResponse);
412370
}
413371

414372
private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
415373
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
416374
listener.onResponse(predictionResponse.getOutput());
417375
}, listener::onFailure);
418-
ActionListener<MLTaskResponse> actionListener = wrapActionListener(internalListener, res -> {
419-
MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res);
420-
return predictionResponse;
421-
});
422-
return actionListener;
376+
return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
423377
}
424378

425379
private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseActionListener(
426380
ActionListener<MLRegisterModelResponse> listener
427381
) {
428-
ActionListener<MLRegisterModelResponse> actionListener = wrapActionListener(listener, res -> {
429-
MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res);
430-
return registerModelResponse;
431-
});
432-
return actionListener;
382+
return wrapActionListener(listener, MLRegisterModelResponse::fromActionResponse);
433383
}
434384

435385
private <T extends ActionResponse> ActionListener<T> wrapActionListener(
436386
final ActionListener<T> listener,
437387
final Function<ActionResponse, T> recreate
438388
) {
439-
ActionListener<T> actionListener = ActionListener.wrap(r -> {
389+
return ActionListener.wrap(r -> {
440390
listener.onResponse(recreate.apply(r));
441391
;
442-
}, e -> { listener.onFailure(e); });
443-
return actionListener;
392+
}, listener::onFailure);
444393
}
445394

446395
private void validateMLInput(MLInput mlInput, boolean requireInput) {

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

+5
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ public void execute(FunctionName name, Input input, ActionListener<MLExecuteTask
216216
listener.onResponse(mlExecuteTaskResponse);
217217
}
218218

219+
@Override
220+
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
221+
listener.onResponse(deleteResponse);
222+
}
223+
219224
@Override
220225
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
221226
listener.onResponse(deleteResponse);

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

+78
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1818
import static org.opensearch.ml.common.input.Constants.ACTION;
1919
import static org.opensearch.ml.common.input.Constants.ALGORITHM;
20+
import static org.opensearch.ml.common.input.Constants.ASYNC;
2021
import static org.opensearch.ml.common.input.Constants.KMEANS;
2122
import static org.opensearch.ml.common.input.Constants.MODELID;
2223
import static org.opensearch.ml.common.input.Constants.PREDICT;
@@ -251,6 +252,42 @@ public void predict() {
251252
assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult());
252253
}
253254

255+
@Test
256+
public void execute_train_asyncTask() {
257+
String modelId = "test_model_id";
258+
String status = "InProgress";
259+
doAnswer(invocation -> {
260+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
261+
MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build();
262+
actionListener.onResponse(MLTaskResponse.builder().output(output).build());
263+
return null;
264+
}).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());
265+
266+
ArgumentCaptor<MLOutput> argumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
267+
Map<String, Object> args = new HashMap<>();
268+
args.put(ACTION, TRAIN);
269+
args.put(ALGORITHM, KMEANS);
270+
args.put(ASYNC, true);
271+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build();
272+
machineLearningNodeClient.run(mlInput, args, trainingActionListener);
273+
274+
verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any());
275+
verify(trainingActionListener).onResponse(argumentCaptor.capture());
276+
assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId());
277+
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
278+
}
279+
280+
@Test
281+
public void execute_predict_missing_modelId() {
282+
exceptionRule.expect(IllegalArgumentException.class);
283+
exceptionRule.expectMessage("The model ID is required for prediction.");
284+
Map<String, Object> args = new HashMap<>();
285+
args.put(ACTION, PREDICT);
286+
args.put(ALGORITHM, KMEANS);
287+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build();
288+
machineLearningNodeClient.run(mlInput, args, dataFrameActionListener);
289+
}
290+
254291
@Test
255292
public void predict_Exception_WithNullAlgorithm() {
256293
exceptionRule.expect(IllegalArgumentException.class);
@@ -288,6 +325,27 @@ public void train() {
288325
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
289326
}
290327

328+
@Test
329+
public void registerModelGroup_withValidInput() {
330+
doAnswer(invocation -> {
331+
ActionListener<MLRegisterModelGroupResponse> actionListener = invocation.getArgument(2);
332+
MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse("groupId", "created");
333+
actionListener.onResponse(output);
334+
return null;
335+
}).when(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), any(), any());
336+
337+
MLRegisterModelGroupInput input = MLRegisterModelGroupInput
338+
.builder()
339+
.name("test")
340+
.description("description")
341+
.backendRoles(Arrays.asList("role1", "role2"))
342+
.modelAccessMode(AccessMode.PUBLIC)
343+
.build();
344+
345+
machineLearningNodeClient.registerModelGroup(input, registerModelGroupResponseActionListener);
346+
verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any());
347+
}
348+
291349
@Test
292350
public void train_Exception_WithNullDataSet() {
293351
exceptionRule.expect(IllegalArgumentException.class);
@@ -499,6 +557,26 @@ public void getModel() {
499557
assertEquals(modelContent, argumentCaptor.getValue().getContent());
500558
}
501559

560+
@Test
561+
public void deleteConnector_withTenantId() {
562+
String connectorId = "connectorId";
563+
String tenantId = "tenantId";
564+
doAnswer(invocation -> {
565+
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
566+
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
567+
DeleteResponse output = new DeleteResponse(shardId, connectorId, 1, 1, 1, true);
568+
actionListener.onResponse(output);
569+
return null;
570+
}).when(client).execute(eq(MLConnectorDeleteAction.INSTANCE), any(), any());
571+
572+
ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
573+
machineLearningNodeClient.deleteConnector(connectorId, tenantId, deleteConnectorActionListener);
574+
575+
verify(client).execute(eq(MLConnectorDeleteAction.INSTANCE), isA(MLConnectorDeleteRequest.class), any());
576+
verify(deleteConnectorActionListener).onResponse(argumentCaptor.capture());
577+
assertEquals(connectorId, (argumentCaptor.getValue()).getId());
578+
}
579+
502580
@Test
503581
public void deleteModel() {
504582
String modelId = "testModelId";

common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ public AwsConnector(
4141
List<String> backendRoles,
4242
AccessMode accessMode,
4343
User owner,
44-
ConnectorClientConfig connectorClientConfig
44+
ConnectorClientConfig connectorClientConfig,
45+
String tenantId
4546
) {
4647
super(
4748
name,
@@ -54,7 +55,8 @@ public AwsConnector(
5455
backendRoles,
5556
accessMode,
5657
owner,
57-
connectorClientConfig
58+
connectorClientConfig,
59+
tenantId
5860
);
5961
validate();
6062
}

0 commit comments

Comments
 (0)