Skip to content

Commit 66d4cd4

Browse files
authored
Added register model group API for MLClient (opensearch-project#1493)
* Added register model group API for MLClient Signed-off-by: Owais Kazi <owaiskazi19@gmail.com> * Resolved formatting errors Signed-off-by: Owais Kazi <owaiskazi19@gmail.com> --------- Signed-off-by: Owais Kazi <owaiskazi19@gmail.com>
1 parent d265fc7 commit 66d4cd4

File tree

9 files changed

+412
-376
lines changed

9 files changed

+412
-376
lines changed

client/build.gradle

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ plugins {
99
id 'jacoco'
1010
id 'com.github.johnrengelman.shadow'
1111
id 'maven-publish'
12+
id 'com.diffplug.spotless' version '6.18.0'
1213
id 'signing'
1314
}
1415

@@ -20,6 +21,15 @@ dependencies {
2021

2122
}
2223

24+
spotless {
25+
java {
26+
removeUnusedImports()
27+
importOrder 'java', 'javax', 'org', 'com'
28+
29+
eclipse().configFile rootProject.file('.eclipseformat.xml')
30+
}
31+
}
32+
2333
jacocoTestReport {
2434
reports {
2535
xml.getRequired().set(true)

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

+24-7
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,26 @@
55

66
package org.opensearch.ml.client;
77

8+
import java.util.Map;
89

9-
import org.opensearch.common.action.ActionFuture;
10-
import org.opensearch.core.action.ActionListener;
1110
import org.opensearch.action.delete.DeleteResponse;
1211
import org.opensearch.action.search.SearchRequest;
1312
import org.opensearch.action.search.SearchResponse;
1413
import org.opensearch.action.support.PlainActionFuture;
14+
import org.opensearch.common.action.ActionFuture;
15+
import org.opensearch.core.action.ActionListener;
1516
import org.opensearch.ml.common.MLModel;
1617
import org.opensearch.ml.common.MLTask;
1718
import org.opensearch.ml.common.input.MLInput;
1819
import org.opensearch.ml.common.output.MLOutput;
1920
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
2021
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
2122
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
23+
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
24+
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
2225
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
2326
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
2427

25-
import java.util.Map;
26-
2728
/**
2829
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
2930
*/
@@ -84,7 +85,6 @@ default ActionFuture<MLOutput> train(MLInput mlInput, boolean asyncTask) {
8485
return actionFuture;
8586
}
8687

87-
8888
/**
8989
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
9090
* For more info on train model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#train-model
@@ -205,15 +205,13 @@ default ActionFuture<SearchResponse> searchModel(SearchRequest searchRequest) {
205205
return actionFuture;
206206
}
207207

208-
209208
/**
210209
* For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model
211210
* @param searchRequest searchRequest to search the ML Model
212211
* @param listener action listener
213212
*/
214213
void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener);
215214

216-
217215
/**
218216
* For more info on search task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-task
219217
* @param searchRequest searchRequest to search the ML Task
@@ -282,4 +280,23 @@ default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnecto
282280
}
283281

284282
void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);
283+
284+
/**
285+
* Register model group
286+
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
287+
* @param mlRegisterModelGroupInput model group input
288+
*/
289+
default ActionFuture<MLRegisterModelGroupResponse> registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput) {
290+
PlainActionFuture<MLRegisterModelGroupResponse> actionFuture = PlainActionFuture.newFuture();
291+
registerModelGroup(mlRegisterModelGroupInput, actionFuture);
292+
return actionFuture;
293+
}
294+
295+
/**
296+
* Register model group
297+
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
298+
* @param mlRegisterModelGroupInput model group input
299+
* @param listener a listener to be notified of the result
300+
*/
301+
void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener<MLRegisterModelGroupResponse> listener);
285302
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+80-78
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,36 @@
55

66
package org.opensearch.ml.client;
77

8-
import lombok.AccessLevel;
9-
import lombok.RequiredArgsConstructor;
10-
import lombok.experimental.FieldDefaults;
11-
import org.opensearch.action.index.IndexRequest;
12-
import org.opensearch.common.util.concurrent.ThreadContext;
13-
import org.opensearch.common.xcontent.XContentType;
14-
import org.opensearch.core.action.ActionListener;
15-
import org.opensearch.core.action.ActionResponse;
8+
import static org.opensearch.ml.common.input.Constants.ASYNC;
9+
import static org.opensearch.ml.common.input.Constants.MODELID;
10+
import static org.opensearch.ml.common.input.Constants.PREDICT;
11+
import static org.opensearch.ml.common.input.Constants.TRAIN;
12+
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;
13+
import static org.opensearch.ml.common.input.InputHelper.convertArgumentToMLParameter;
14+
import static org.opensearch.ml.common.input.InputHelper.getAction;
15+
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;
16+
17+
import java.util.Map;
18+
import java.util.function.Function;
19+
1620
import org.opensearch.action.delete.DeleteResponse;
1721
import org.opensearch.action.search.SearchRequest;
1822
import org.opensearch.action.search.SearchResponse;
1923
import org.opensearch.client.Client;
20-
import org.opensearch.core.xcontent.ToXContent;
21-
import org.opensearch.core.xcontent.XContentBuilder;
22-
import org.opensearch.ml.common.AccessMode;
24+
import org.opensearch.core.action.ActionListener;
25+
import org.opensearch.core.action.ActionResponse;
2326
import org.opensearch.ml.common.FunctionName;
2427
import org.opensearch.ml.common.MLModel;
25-
import org.opensearch.ml.common.MLModelGroup;
2628
import org.opensearch.ml.common.MLTask;
27-
import org.opensearch.ml.common.exception.MLException;
2829
import org.opensearch.ml.common.input.MLInput;
2930
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
30-
import org.opensearch.ml.common.model.MLModelConfig;
31-
import org.opensearch.ml.common.model.MLModelFormat;
32-
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
3331
import org.opensearch.ml.common.output.MLOutput;
3432
import org.opensearch.ml.common.transport.MLTaskResponse;
3533
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
3634
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
3735
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
3836
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
3937
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
40-
import org.opensearch.ml.common.transport.deploy.MLDeployModelInput;
4138
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
4239
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
4340
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
@@ -46,7 +43,10 @@
4643
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
4744
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
4845
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
49-
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
46+
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction;
47+
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
48+
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest;
49+
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
5050
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
5151
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
5252
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
@@ -63,20 +63,9 @@
6363
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
6464
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
6565

66-
import java.io.IOException;
67-
import java.time.Instant;
68-
import java.util.Map;
69-
import java.util.function.Function;
70-
71-
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
72-
import static org.opensearch.ml.common.input.Constants.ASYNC;
73-
import static org.opensearch.ml.common.input.Constants.MODELID;
74-
import static org.opensearch.ml.common.input.Constants.PREDICT;
75-
import static org.opensearch.ml.common.input.Constants.TRAIN;
76-
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;
77-
import static org.opensearch.ml.common.input.InputHelper.convertArgumentToMLParameter;
78-
import static org.opensearch.ml.common.input.InputHelper.getAction;
79-
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;
66+
import lombok.AccessLevel;
67+
import lombok.RequiredArgsConstructor;
68+
import lombok.experimental.FieldDefaults;
8069

8170
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
8271
@RequiredArgsConstructor
@@ -88,33 +77,32 @@ public class MachineLearningNodeClient implements MachineLearningClient {
8877
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
8978
validateMLInput(mlInput, true);
9079

91-
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
92-
.mlInput(mlInput)
93-
.modelId(modelId)
94-
.dispatchTask(true)
95-
.build();
80+
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
81+
.builder()
82+
.mlInput(mlInput)
83+
.modelId(modelId)
84+
.dispatchTask(true)
85+
.build();
9686
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
9787
}
9888

9989
@Override
10090
public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
10191
validateMLInput(mlInput, true);
102-
MLTrainingTaskRequest request = MLTrainingTaskRequest.builder()
103-
.mlInput(mlInput)
104-
.dispatchTask(true)
105-
.build();
92+
MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).dispatchTask(true).build();
10693

10794
client.execute(MLTrainAndPredictionTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener));
10895
}
10996

11097
@Override
11198
public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener) {
11299
validateMLInput(mlInput, true);
113-
MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder()
114-
.mlInput(mlInput)
115-
.async(asyncTask)
116-
.dispatchTask(true)
117-
.build();
100+
MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest
101+
.builder()
102+
.mlInput(mlInput)
103+
.async(asyncTask)
104+
.dispatchTask(true)
105+
.build();
118106

119107
client.execute(MLTrainingTaskAction.INSTANCE, trainingTaskRequest, getMlPredictionTaskResponseActionListener(listener));
120108
}
@@ -144,15 +132,13 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
144132
trainAndPredict(mlInput, listener);
145133
break;
146134
default:
147-
throw new IllegalArgumentException("Unsupported action.");
135+
throw new IllegalArgumentException("Unsupported action.");
148136
}
149137
}
150138

151139
@Override
152140
public void getModel(String modelId, ActionListener<MLModel> listener) {
153-
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder()
154-
.modelId(modelId)
155-
.build();
141+
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();
156142

157143
client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
158144
}
@@ -170,9 +156,7 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
170156

171157
@Override
172158
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
173-
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder()
174-
.modelId(modelId)
175-
.build();
159+
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
176160

177161
client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(deleteResponse -> {
178162
listener.onResponse(deleteResponse);
@@ -181,17 +165,26 @@ public void deleteModel(String modelId, ActionListener<DeleteResponse> listener)
181165

182166
@Override
183167
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
184-
client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
185-
listener.onResponse(searchResponse);
186-
}, listener::onFailure));
168+
client
169+
.execute(
170+
MLModelSearchAction.INSTANCE,
171+
searchRequest,
172+
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
173+
);
187174
}
188175

176+
@Override
177+
public void registerModelGroup(
178+
MLRegisterModelGroupInput mlRegisterModelGroupInput,
179+
ActionListener<MLRegisterModelGroupResponse> listener
180+
) {
181+
MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput);
182+
client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener);
183+
}
189184

190185
@Override
191186
public void getTask(String taskId, ActionListener<MLTask> listener) {
192-
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder()
193-
.taskId(taskId)
194-
.build();
187+
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
195188

196189
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(response -> {
197190
listener.onResponse(MLTaskGetResponse.fromActionResponse(response).getMlTask());
@@ -200,9 +193,7 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
200193

201194
@Override
202195
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
203-
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder()
204-
.taskId(taskId)
205-
.build();
196+
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();
206197

207198
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> {
208199
listener.onResponse(deleteResponse);
@@ -211,25 +202,34 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
211202

212203
@Override
213204
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
214-
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
215-
listener.onResponse(searchResponse);
216-
}, listener::onFailure));
205+
client
206+
.execute(
207+
MLTaskSearchAction.INSTANCE,
208+
searchRequest,
209+
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
210+
);
217211
}
218212

219213
@Override
220214
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
221215
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
222-
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
223-
listener.onFailure(e);
224-
}));
216+
client
217+
.execute(
218+
MLRegisterModelAction.INSTANCE,
219+
registerRequest,
220+
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
221+
);
225222
}
226223

227224
@Override
228225
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
229226
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
230-
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> {
231-
listener.onFailure(e);
232-
}));
227+
client
228+
.execute(
229+
MLDeployModelAction.INSTANCE,
230+
deployModelRequest,
231+
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
232+
);
233233
}
234234

235235
@Override
@@ -249,20 +249,22 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener
249249
return actionListener;
250250
}
251251

252-
private <T extends ActionResponse> ActionListener<T> wrapActionListener(final ActionListener<T> listener, final Function<ActionResponse, T> recreate) {
253-
ActionListener<T> actionListener = ActionListener.wrap(r-> {
254-
listener.onResponse(recreate.apply(r));;
255-
}, e->{
256-
listener.onFailure(e);
257-
});
252+
private <T extends ActionResponse> ActionListener<T> wrapActionListener(
253+
final ActionListener<T> listener,
254+
final Function<ActionResponse, T> recreate
255+
) {
256+
ActionListener<T> actionListener = ActionListener.wrap(r -> {
257+
listener.onResponse(recreate.apply(r));
258+
;
259+
}, e -> { listener.onFailure(e); });
258260
return actionListener;
259261
}
260262

261263
private void validateMLInput(MLInput mlInput, boolean requireInput) {
262264
if (mlInput == null) {
263265
throw new IllegalArgumentException("ML Input can't be null");
264266
}
265-
if(requireInput && mlInput.getInputDataset() == null) {
267+
if (requireInput && mlInput.getInputDataset() == null) {
266268
throw new IllegalArgumentException("input data set can't be null");
267269
}
268270
}

client/src/main/java/org/opensearch/ml/client/package-info.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.client;
6+
package org.opensearch.ml.client;

0 commit comments

Comments
 (0)