5
5
6
6
package org .opensearch .ml .client ;
7
7
8
- import lombok .AccessLevel ;
9
- import lombok .RequiredArgsConstructor ;
10
- import lombok .experimental .FieldDefaults ;
11
- import org .opensearch .action .index .IndexRequest ;
12
- import org .opensearch .common .util .concurrent .ThreadContext ;
13
- import org .opensearch .common .xcontent .XContentType ;
14
- import org .opensearch .core .action .ActionListener ;
15
- import org .opensearch .core .action .ActionResponse ;
8
+ import static org .opensearch .ml .common .input .Constants .ASYNC ;
9
+ import static org .opensearch .ml .common .input .Constants .MODELID ;
10
+ import static org .opensearch .ml .common .input .Constants .PREDICT ;
11
+ import static org .opensearch .ml .common .input .Constants .TRAIN ;
12
+ import static org .opensearch .ml .common .input .Constants .TRAINANDPREDICT ;
13
+ import static org .opensearch .ml .common .input .InputHelper .convertArgumentToMLParameter ;
14
+ import static org .opensearch .ml .common .input .InputHelper .getAction ;
15
+ import static org .opensearch .ml .common .input .InputHelper .getFunctionName ;
16
+
17
+ import java .util .Map ;
18
+ import java .util .function .Function ;
19
+
16
20
import org .opensearch .action .delete .DeleteResponse ;
17
21
import org .opensearch .action .search .SearchRequest ;
18
22
import org .opensearch .action .search .SearchResponse ;
19
23
import org .opensearch .client .Client ;
20
- import org .opensearch .core .xcontent .ToXContent ;
21
- import org .opensearch .core .xcontent .XContentBuilder ;
22
- import org .opensearch .ml .common .AccessMode ;
24
+ import org .opensearch .core .action .ActionListener ;
25
+ import org .opensearch .core .action .ActionResponse ;
23
26
import org .opensearch .ml .common .FunctionName ;
24
27
import org .opensearch .ml .common .MLModel ;
25
- import org .opensearch .ml .common .MLModelGroup ;
26
28
import org .opensearch .ml .common .MLTask ;
27
- import org .opensearch .ml .common .exception .MLException ;
28
29
import org .opensearch .ml .common .input .MLInput ;
29
30
import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
30
- import org .opensearch .ml .common .model .MLModelConfig ;
31
- import org .opensearch .ml .common .model .MLModelFormat ;
32
- import org .opensearch .ml .common .model .MetricsCorrelationModelConfig ;
33
31
import org .opensearch .ml .common .output .MLOutput ;
34
32
import org .opensearch .ml .common .transport .MLTaskResponse ;
35
33
import org .opensearch .ml .common .transport .connector .MLCreateConnectorAction ;
36
34
import org .opensearch .ml .common .transport .connector .MLCreateConnectorInput ;
37
35
import org .opensearch .ml .common .transport .connector .MLCreateConnectorRequest ;
38
36
import org .opensearch .ml .common .transport .connector .MLCreateConnectorResponse ;
39
37
import org .opensearch .ml .common .transport .deploy .MLDeployModelAction ;
40
- import org .opensearch .ml .common .transport .deploy .MLDeployModelInput ;
41
38
import org .opensearch .ml .common .transport .deploy .MLDeployModelRequest ;
42
39
import org .opensearch .ml .common .transport .deploy .MLDeployModelResponse ;
43
40
import org .opensearch .ml .common .transport .model .MLModelDeleteAction ;
46
43
import org .opensearch .ml .common .transport .model .MLModelGetRequest ;
47
44
import org .opensearch .ml .common .transport .model .MLModelGetResponse ;
48
45
import org .opensearch .ml .common .transport .model .MLModelSearchAction ;
49
- import org .opensearch .ml .common .transport .model_group .MLModelGroupSearchAction ;
46
+ import org .opensearch .ml .common .transport .model_group .MLRegisterModelGroupAction ;
47
+ import org .opensearch .ml .common .transport .model_group .MLRegisterModelGroupInput ;
48
+ import org .opensearch .ml .common .transport .model_group .MLRegisterModelGroupRequest ;
49
+ import org .opensearch .ml .common .transport .model_group .MLRegisterModelGroupResponse ;
50
50
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
51
51
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskRequest ;
52
52
import org .opensearch .ml .common .transport .register .MLRegisterModelAction ;
63
63
import org .opensearch .ml .common .transport .training .MLTrainingTaskRequest ;
64
64
import org .opensearch .ml .common .transport .trainpredict .MLTrainAndPredictionTaskAction ;
65
65
66
- import java .io .IOException ;
67
- import java .time .Instant ;
68
- import java .util .Map ;
69
- import java .util .function .Function ;
70
-
71
- import static org .opensearch .ml .common .CommonValue .ML_MODEL_GROUP_INDEX ;
72
- import static org .opensearch .ml .common .input .Constants .ASYNC ;
73
- import static org .opensearch .ml .common .input .Constants .MODELID ;
74
- import static org .opensearch .ml .common .input .Constants .PREDICT ;
75
- import static org .opensearch .ml .common .input .Constants .TRAIN ;
76
- import static org .opensearch .ml .common .input .Constants .TRAINANDPREDICT ;
77
- import static org .opensearch .ml .common .input .InputHelper .convertArgumentToMLParameter ;
78
- import static org .opensearch .ml .common .input .InputHelper .getAction ;
79
- import static org .opensearch .ml .common .input .InputHelper .getFunctionName ;
66
+ import lombok .AccessLevel ;
67
+ import lombok .RequiredArgsConstructor ;
68
+ import lombok .experimental .FieldDefaults ;
80
69
81
70
@ FieldDefaults (makeFinal = true , level = AccessLevel .PRIVATE )
82
71
@ RequiredArgsConstructor
@@ -88,33 +77,32 @@ public class MachineLearningNodeClient implements MachineLearningClient {
88
77
public void predict (String modelId , MLInput mlInput , ActionListener <MLOutput > listener ) {
89
78
validateMLInput (mlInput , true );
90
79
91
- MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest .builder ()
92
- .mlInput (mlInput )
93
- .modelId (modelId )
94
- .dispatchTask (true )
95
- .build ();
80
+ MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
81
+ .builder ()
82
+ .mlInput (mlInput )
83
+ .modelId (modelId )
84
+ .dispatchTask (true )
85
+ .build ();
96
86
client .execute (MLPredictionTaskAction .INSTANCE , predictionRequest , getMlPredictionTaskResponseActionListener (listener ));
97
87
}
98
88
99
89
@ Override
100
90
public void trainAndPredict (MLInput mlInput , ActionListener <MLOutput > listener ) {
101
91
validateMLInput (mlInput , true );
102
- MLTrainingTaskRequest request = MLTrainingTaskRequest .builder ()
103
- .mlInput (mlInput )
104
- .dispatchTask (true )
105
- .build ();
92
+ MLTrainingTaskRequest request = MLTrainingTaskRequest .builder ().mlInput (mlInput ).dispatchTask (true ).build ();
106
93
107
94
client .execute (MLTrainAndPredictionTaskAction .INSTANCE , request , getMlPredictionTaskResponseActionListener (listener ));
108
95
}
109
96
110
97
@ Override
111
98
public void train (MLInput mlInput , boolean asyncTask , ActionListener <MLOutput > listener ) {
112
99
validateMLInput (mlInput , true );
113
- MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest .builder ()
114
- .mlInput (mlInput )
115
- .async (asyncTask )
116
- .dispatchTask (true )
117
- .build ();
100
+ MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest
101
+ .builder ()
102
+ .mlInput (mlInput )
103
+ .async (asyncTask )
104
+ .dispatchTask (true )
105
+ .build ();
118
106
119
107
client .execute (MLTrainingTaskAction .INSTANCE , trainingTaskRequest , getMlPredictionTaskResponseActionListener (listener ));
120
108
}
@@ -144,15 +132,13 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
144
132
trainAndPredict (mlInput , listener );
145
133
break ;
146
134
default :
147
- throw new IllegalArgumentException ("Unsupported action." );
135
+ throw new IllegalArgumentException ("Unsupported action." );
148
136
}
149
137
}
150
138
151
139
@ Override
152
140
public void getModel (String modelId , ActionListener <MLModel > listener ) {
153
- MLModelGetRequest mlModelGetRequest = MLModelGetRequest .builder ()
154
- .modelId (modelId )
155
- .build ();
141
+ MLModelGetRequest mlModelGetRequest = MLModelGetRequest .builder ().modelId (modelId ).build ();
156
142
157
143
client .execute (MLModelGetAction .INSTANCE , mlModelGetRequest , getMlGetModelResponseActionListener (listener ));
158
144
}
@@ -170,9 +156,7 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
170
156
171
157
@ Override
172
158
public void deleteModel (String modelId , ActionListener <DeleteResponse > listener ) {
173
- MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest .builder ()
174
- .modelId (modelId )
175
- .build ();
159
+ MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest .builder ().modelId (modelId ).build ();
176
160
177
161
client .execute (MLModelDeleteAction .INSTANCE , mlModelDeleteRequest , ActionListener .wrap (deleteResponse -> {
178
162
listener .onResponse (deleteResponse );
@@ -181,17 +165,26 @@ public void deleteModel(String modelId, ActionListener<DeleteResponse> listener)
181
165
182
166
@ Override
183
167
public void searchModel (SearchRequest searchRequest , ActionListener <SearchResponse > listener ) {
184
- client .execute (MLModelSearchAction .INSTANCE , searchRequest , ActionListener .wrap (searchResponse -> {
185
- listener .onResponse (searchResponse );
186
- }, listener ::onFailure ));
168
+ client
169
+ .execute (
170
+ MLModelSearchAction .INSTANCE ,
171
+ searchRequest ,
172
+ ActionListener .wrap (searchResponse -> { listener .onResponse (searchResponse ); }, listener ::onFailure )
173
+ );
187
174
}
188
175
176
+ @ Override
177
+ public void registerModelGroup (
178
+ MLRegisterModelGroupInput mlRegisterModelGroupInput ,
179
+ ActionListener <MLRegisterModelGroupResponse > listener
180
+ ) {
181
+ MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest (mlRegisterModelGroupInput );
182
+ client .execute (MLRegisterModelGroupAction .INSTANCE , mlRegisterModelGroupRequest , listener );
183
+ }
189
184
190
185
@ Override
191
186
public void getTask (String taskId , ActionListener <MLTask > listener ) {
192
- MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest .builder ()
193
- .taskId (taskId )
194
- .build ();
187
+ MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest .builder ().taskId (taskId ).build ();
195
188
196
189
client .execute (MLTaskGetAction .INSTANCE , mlTaskGetRequest , ActionListener .wrap (response -> {
197
190
listener .onResponse (MLTaskGetResponse .fromActionResponse (response ).getMlTask ());
@@ -200,9 +193,7 @@ public void getTask(String taskId, ActionListener<MLTask> listener) {
200
193
201
194
@ Override
202
195
public void deleteTask (String taskId , ActionListener <DeleteResponse > listener ) {
203
- MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest .builder ()
204
- .taskId (taskId )
205
- .build ();
196
+ MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest .builder ().taskId (taskId ).build ();
206
197
207
198
client .execute (MLTaskDeleteAction .INSTANCE , mlTaskDeleteRequest , ActionListener .wrap (deleteResponse -> {
208
199
listener .onResponse (deleteResponse );
@@ -211,25 +202,34 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
211
202
212
203
@ Override
213
204
public void searchTask (SearchRequest searchRequest , ActionListener <SearchResponse > listener ) {
214
- client .execute (MLTaskSearchAction .INSTANCE , searchRequest , ActionListener .wrap (searchResponse -> {
215
- listener .onResponse (searchResponse );
216
- }, listener ::onFailure ));
205
+ client
206
+ .execute (
207
+ MLTaskSearchAction .INSTANCE ,
208
+ searchRequest ,
209
+ ActionListener .wrap (searchResponse -> { listener .onResponse (searchResponse ); }, listener ::onFailure )
210
+ );
217
211
}
218
212
219
213
@ Override
220
214
public void register (MLRegisterModelInput mlInput , ActionListener <MLRegisterModelResponse > listener ) {
221
215
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest (mlInput );
222
- client .execute (MLRegisterModelAction .INSTANCE , registerRequest , ActionListener .wrap (listener ::onResponse , e -> {
223
- listener .onFailure (e );
224
- }));
216
+ client
217
+ .execute (
218
+ MLRegisterModelAction .INSTANCE ,
219
+ registerRequest ,
220
+ ActionListener .wrap (listener ::onResponse , e -> { listener .onFailure (e ); })
221
+ );
225
222
}
226
223
227
224
@ Override
228
225
public void deploy (String modelId , ActionListener <MLDeployModelResponse > listener ) {
229
226
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest (modelId , false );
230
- client .execute (MLDeployModelAction .INSTANCE , deployModelRequest , ActionListener .wrap (listener ::onResponse , e -> {
231
- listener .onFailure (e );
232
- }));
227
+ client
228
+ .execute (
229
+ MLDeployModelAction .INSTANCE ,
230
+ deployModelRequest ,
231
+ ActionListener .wrap (listener ::onResponse , e -> { listener .onFailure (e ); })
232
+ );
233
233
}
234
234
235
235
@ Override
@@ -249,20 +249,22 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener
249
249
return actionListener ;
250
250
}
251
251
252
- private <T extends ActionResponse > ActionListener <T > wrapActionListener (final ActionListener <T > listener , final Function <ActionResponse , T > recreate ) {
253
- ActionListener <T > actionListener = ActionListener .wrap (r -> {
254
- listener .onResponse (recreate .apply (r ));;
255
- }, e ->{
256
- listener .onFailure (e );
257
- });
252
+ private <T extends ActionResponse > ActionListener <T > wrapActionListener (
253
+ final ActionListener <T > listener ,
254
+ final Function <ActionResponse , T > recreate
255
+ ) {
256
+ ActionListener <T > actionListener = ActionListener .wrap (r -> {
257
+ listener .onResponse (recreate .apply (r ));
258
+ ;
259
+ }, e -> { listener .onFailure (e ); });
258
260
return actionListener ;
259
261
}
260
262
261
263
private void validateMLInput (MLInput mlInput , boolean requireInput ) {
262
264
if (mlInput == null ) {
263
265
throw new IllegalArgumentException ("ML Input can't be null" );
264
266
}
265
- if (requireInput && mlInput .getInputDataset () == null ) {
267
+ if (requireInput && mlInput .getInputDataset () == null ) {
266
268
throw new IllegalArgumentException ("input data set can't be null" );
267
269
}
268
270
}
0 commit comments