11
11
import static org .opensearch .ml .common .MLTask .REMOTE_JOB_FIELD ;
12
12
import static org .opensearch .ml .common .MLTask .STATE_FIELD ;
13
13
import static org .opensearch .ml .common .MLTaskState .CANCELLED ;
14
+ import static org .opensearch .ml .common .MLTaskState .CANCELLING ;
14
15
import static org .opensearch .ml .common .MLTaskState .COMPLETED ;
16
+ import static org .opensearch .ml .common .MLTaskState .EXPIRED ;
15
17
import static org .opensearch .ml .common .connector .ConnectorAction .ActionType .BATCH_PREDICT_STATUS ;
18
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX ;
19
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX ;
20
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX ;
21
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX ;
22
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_FIELD ;
16
23
import static org .opensearch .ml .utils .MLExceptionUtils .logException ;
17
24
import static org .opensearch .ml .utils .MLNodeUtils .createXContentParserFromRegistry ;
18
25
19
26
import java .util .HashMap ;
27
+ import java .util .List ;
20
28
import java .util .Map ;
21
29
import java .util .Optional ;
30
+ import java .util .function .Consumer ;
31
+ import java .util .regex .Matcher ;
32
+ import java .util .regex .Pattern ;
22
33
23
34
import org .opensearch .OpenSearchException ;
24
35
import org .opensearch .OpenSearchStatusException ;
30
41
import org .opensearch .client .Client ;
31
42
import org .opensearch .cluster .service .ClusterService ;
32
43
import org .opensearch .common .inject .Inject ;
44
+ import org .opensearch .common .settings .Setting ;
45
+ import org .opensearch .common .settings .Settings ;
33
46
import org .opensearch .common .util .concurrent .ThreadContext ;
34
47
import org .opensearch .core .action .ActionListener ;
35
48
import org .opensearch .core .rest .RestStatus ;
@@ -80,6 +93,12 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
80
93
MLTaskManager mlTaskManager ;
81
94
MLModelCacheHelper modelCacheHelper ;
82
95
96
+ volatile List <String > remoteJobStatusFields ;
97
+ volatile Pattern remoteJobCompletedStatusRegexPattern ;
98
+ volatile Pattern remoteJobCancelledStatusRegexPattern ;
99
+ volatile Pattern remoteJobCancellingStatusRegexPattern ;
100
+ volatile Pattern remoteJobExpiredStatusRegexPattern ;
101
+
83
102
@ Inject
84
103
public GetTaskTransportAction (
85
104
TransportService transportService ,
@@ -91,7 +110,8 @@ public GetTaskTransportAction(
91
110
ConnectorAccessControlHelper connectorAccessControlHelper ,
92
111
EncryptorImpl encryptor ,
93
112
MLTaskManager mlTaskManager ,
94
- MLModelManager mlModelManager
113
+ MLModelManager mlModelManager ,
114
+ Settings settings
95
115
) {
96
116
super (MLTaskGetAction .NAME , transportService , actionFilters , MLTaskGetRequest ::new );
97
117
this .client = client ;
@@ -102,6 +122,44 @@ public GetTaskTransportAction(
102
122
this .encryptor = encryptor ;
103
123
this .mlTaskManager = mlTaskManager ;
104
124
this .mlModelManager = mlModelManager ;
125
+
126
+ remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD .get (settings );
127
+ clusterService .getClusterSettings ().addSettingsUpdateConsumer (ML_COMMONS_REMOTE_JOB_STATUS_FIELD , it -> remoteJobStatusFields = it );
128
+ initializeRegexPattern (
129
+ ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX ,
130
+ settings ,
131
+ clusterService ,
132
+ (regex ) -> remoteJobCompletedStatusRegexPattern = Pattern .compile (regex , Pattern .CASE_INSENSITIVE )
133
+ );
134
+ initializeRegexPattern (
135
+ ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX ,
136
+ settings ,
137
+ clusterService ,
138
+ (regex ) -> remoteJobCancelledStatusRegexPattern = Pattern .compile (regex , Pattern .CASE_INSENSITIVE )
139
+ );
140
+ initializeRegexPattern (
141
+ ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX ,
142
+ settings ,
143
+ clusterService ,
144
+ (regex ) -> remoteJobCancellingStatusRegexPattern = Pattern .compile (regex , Pattern .CASE_INSENSITIVE )
145
+ );
146
+ initializeRegexPattern (
147
+ ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX ,
148
+ settings ,
149
+ clusterService ,
150
+ (regex ) -> remoteJobExpiredStatusRegexPattern = Pattern .compile (regex , Pattern .CASE_INSENSITIVE )
151
+ );
152
+ }
153
+
154
+ private void initializeRegexPattern (
155
+ Setting <String > setting ,
156
+ Settings settings ,
157
+ ClusterService clusterService ,
158
+ Consumer <String > patternInitializer
159
+ ) {
160
+ String regex = setting .get (settings );
161
+ patternInitializer .accept (regex );
162
+ clusterService .getClusterSettings ().addSettingsUpdateConsumer (setting , it -> patternInitializer .accept (it ));
105
163
}
106
164
107
165
@ Override
@@ -210,7 +268,7 @@ private void executeConnector(
210
268
MLInput mlInput ,
211
269
String taskId ,
212
270
MLTask mlTask ,
213
- Map <String , Object > transformJob ,
271
+ Map <String , Object > remoteJob ,
214
272
ActionListener <MLTaskGetResponse > actionListener
215
273
) {
216
274
if (connectorAccessControlHelper .validateConnectorAccess (client , connector )) {
@@ -222,15 +280,15 @@ private void executeConnector(
222
280
connectorExecutor .setClient (client );
223
281
connectorExecutor .setXContentRegistry (xContentRegistry );
224
282
connectorExecutor .executeAction (BATCH_PREDICT_STATUS .name (), mlInput , ActionListener .wrap (taskResponse -> {
225
- processTaskResponse (mlTask , taskId , taskResponse , transformJob , actionListener );
283
+ processTaskResponse (mlTask , taskId , taskResponse , remoteJob , actionListener );
226
284
}, e -> { actionListener .onFailure (e ); }));
227
285
} else {
228
286
actionListener
229
287
.onFailure (new OpenSearchStatusException ("You don't have permission to access this connector" , RestStatus .FORBIDDEN ));
230
288
}
231
289
}
232
290
233
- private void processTaskResponse (
291
+ protected void processTaskResponse (
234
292
MLTask mlTask ,
235
293
String taskId ,
236
294
MLTaskResponse taskResponse ,
@@ -248,15 +306,11 @@ private void processTaskResponse(
248
306
Map <String , Object > updatedTask = new HashMap <>();
249
307
updatedTask .put (REMOTE_JOB_FIELD , remoteJob );
250
308
251
- if ((remoteJob .containsKey ("status" ) && remoteJob .get ("status" ).equals ("completed" ))
252
- || (remoteJob .containsKey ("TransformJobStatus" ) && remoteJob .get ("TransformJobStatus" ).equals ("Completed" ))) {
253
- updatedTask .put (STATE_FIELD , COMPLETED );
254
- mlTask .setState (COMPLETED );
255
-
256
- } else if ((remoteJob .containsKey ("status" ) && remoteJob .get ("status" ).equals ("cancelled" ))
257
- || (remoteJob .containsKey ("TransformJobStatus" ) && remoteJob .get ("TransformJobStatus" ).equals ("Stopped" ))) {
258
- updatedTask .put (STATE_FIELD , CANCELLED );
259
- mlTask .setState (CANCELLED );
309
+ for (String statusField : remoteJobStatusFields ) {
310
+ String statusValue = String .valueOf (remoteJob .get (statusField ));
311
+ if (remoteJob .containsKey (statusField )) {
312
+ updateTaskState (updatedTask , mlTask , statusValue );
313
+ }
260
314
}
261
315
mlTaskManager .updateMLTaskDirectly (taskId , updatedTask , ActionListener .wrap (response -> {
262
316
actionListener .onResponse (MLTaskGetResponse .builder ().mlTask (mlTask ).build ());
@@ -280,4 +334,25 @@ private void processTaskResponse(
280
334
log .error ("Unable to fetch status for ml task " , e );
281
335
}
282
336
}
337
+
338
+ private void updateTaskState (Map <String , Object > updatedTask , MLTask mlTask , String statusValue ) {
339
+ if (matchesPattern (remoteJobCancellingStatusRegexPattern , statusValue )) {
340
+ updatedTask .put (STATE_FIELD , CANCELLING );
341
+ mlTask .setState (CANCELLING );
342
+ } else if (matchesPattern (remoteJobCancelledStatusRegexPattern , statusValue )) {
343
+ updatedTask .put (STATE_FIELD , CANCELLED );
344
+ mlTask .setState (CANCELLED );
345
+ } else if (matchesPattern (remoteJobCompletedStatusRegexPattern , statusValue )) {
346
+ updatedTask .put (STATE_FIELD , COMPLETED );
347
+ mlTask .setState (COMPLETED );
348
+ } else if (matchesPattern (remoteJobExpiredStatusRegexPattern , statusValue )) {
349
+ updatedTask .put (STATE_FIELD , EXPIRED );
350
+ mlTask .setState (EXPIRED );
351
+ }
352
+ }
353
+
354
+ private boolean matchesPattern (Pattern pattern , String input ) {
355
+ Matcher matcher = pattern .matcher (input );
356
+ return matcher .find ();
357
+ }
283
358
}
0 commit comments