20
20
import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX ;
21
21
import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX ;
22
22
import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_JOB_STATUS_FIELD ;
23
+ import static org .opensearch .ml .utils .MLExceptionUtils .BATCH_INFERENCE_DISABLED_ERR_MSG ;
23
24
import static org .opensearch .ml .utils .MLExceptionUtils .logException ;
24
25
import static org .opensearch .ml .utils .MLNodeUtils .createXContentParserFromRegistry ;
25
26
68
69
import org .opensearch .ml .engine .algorithms .remote .RemoteConnectorExecutor ;
69
70
import org .opensearch .ml .engine .encryptor .EncryptorImpl ;
70
71
import org .opensearch .ml .helper .ConnectorAccessControlHelper ;
71
- import org .opensearch .ml .model .MLModelCacheHelper ;
72
72
import org .opensearch .ml .model .MLModelManager ;
73
+ import org .opensearch .ml .settings .MLFeatureEnabledSetting ;
73
74
import org .opensearch .ml .task .MLTaskManager ;
74
75
import org .opensearch .script .ScriptService ;
75
76
import org .opensearch .tasks .Task ;
@@ -91,7 +92,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
91
92
MLModelManager mlModelManager ;
92
93
93
94
MLTaskManager mlTaskManager ;
94
- MLModelCacheHelper modelCacheHelper ;
95
+ private MLFeatureEnabledSetting mlFeatureEnabledSetting ;
95
96
96
97
volatile List <String > remoteJobStatusFields ;
97
98
volatile Pattern remoteJobCompletedStatusRegexPattern ;
@@ -111,6 +112,7 @@ public GetTaskTransportAction(
111
112
EncryptorImpl encryptor ,
112
113
MLTaskManager mlTaskManager ,
113
114
MLModelManager mlModelManager ,
115
+ MLFeatureEnabledSetting mlFeatureEnabledSetting ,
114
116
Settings settings
115
117
) {
116
118
super (MLTaskGetAction .NAME , transportService , actionFilters , MLTaskGetRequest ::new );
@@ -122,6 +124,7 @@ public GetTaskTransportAction(
122
124
this .encryptor = encryptor ;
123
125
this .mlTaskManager = mlTaskManager ;
124
126
this .mlModelManager = mlModelManager ;
127
+ this .mlFeatureEnabledSetting = mlFeatureEnabledSetting ;
125
128
126
129
remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD .get (settings );
127
130
clusterService .getClusterSettings ().addSettingsUpdateConsumer (ML_COMMONS_REMOTE_JOB_STATUS_FIELD , it -> remoteJobStatusFields = it );
@@ -178,6 +181,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
178
181
MLTask mlTask = MLTask .parse (parser );
179
182
180
183
// check if function is remote and task is of type batch prediction
184
+ if (mlTask .getTaskType () == MLTaskType .BATCH_PREDICTION
185
+ && !mlFeatureEnabledSetting .isOfflineBatchInferenceEnabled ()) {
186
+ throw new IllegalStateException (BATCH_INFERENCE_DISABLED_ERR_MSG );
187
+ }
181
188
if (mlTask .getTaskType () == MLTaskType .BATCH_PREDICTION && mlTask .getFunctionName () == FunctionName .REMOTE ) {
182
189
processRemoteBatchPrediction (mlTask , taskId , actionListener );
183
190
} else {
0 commit comments