36
36
import org .opensearch .ml .common .transport .batch .MLBatchIngestionResponse ;
37
37
import org .opensearch .ml .engine .MLEngineClassLoader ;
38
38
import org .opensearch .ml .engine .ingest .Ingestable ;
39
+ import org .opensearch .ml .model .MLModelManager ;
39
40
import org .opensearch .ml .settings .MLFeatureEnabledSetting ;
40
41
import org .opensearch .ml .task .MLTaskManager ;
41
42
import org .opensearch .ml .utils .MLExceptionUtils ;
@@ -55,6 +56,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
55
56
public static final String SOURCE = "source" ;
56
57
TransportService transportService ;
57
58
MLTaskManager mlTaskManager ;
59
+ MLModelManager mlModelManager ;
58
60
private final Client client ;
59
61
private ThreadPool threadPool ;
60
62
private MLFeatureEnabledSetting mlFeatureEnabledSetting ;
@@ -66,13 +68,15 @@ public TransportBatchIngestionAction(
66
68
Client client ,
67
69
MLTaskManager mlTaskManager ,
68
70
ThreadPool threadPool ,
71
+ MLModelManager mlModelManager ,
69
72
MLFeatureEnabledSetting mlFeatureEnabledSetting
70
73
) {
71
74
super (MLBatchIngestionAction .NAME , transportService , actionFilters , MLBatchIngestionRequest ::new );
72
75
this .transportService = transportService ;
73
76
this .client = client ;
74
77
this .mlTaskManager = mlTaskManager ;
75
78
this .threadPool = threadPool ;
79
+ this .mlModelManager = mlModelManager ;
76
80
this .mlFeatureEnabledSetting = mlFeatureEnabledSetting ;
77
81
}
78
82
@@ -85,44 +89,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
85
89
throw new IllegalStateException (OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG );
86
90
}
87
91
validateBatchIngestInput (mlBatchIngestionInput );
88
- MLTask mlTask = MLTask
89
- .builder ()
90
- .async (true )
91
- .taskType (MLTaskType .BATCH_INGEST )
92
- .createTime (Instant .now ())
93
- .lastUpdateTime (Instant .now ())
94
- .state (MLTaskState .CREATED )
95
- .build ();
96
-
97
- mlTaskManager .createMLTask (mlTask , ActionListener .wrap (response -> {
98
- String taskId = response .getId ();
99
- try {
100
- mlTask .setTaskId (taskId );
101
- mlTaskManager .add (mlTask );
102
- listener .onResponse (new MLBatchIngestionResponse (taskId , MLTaskType .BATCH_INGEST , MLTaskState .CREATED .name ()));
103
- String ingestType = (String ) mlBatchIngestionInput .getDataSources ().get (TYPE );
104
- Ingestable ingestable = MLEngineClassLoader .initInstance (ingestType .toLowerCase (), client , Client .class );
105
- threadPool .executor (INGEST_THREAD_POOL ).execute (() -> {
106
- executeWithErrorHandling (() -> {
107
- double successRate = ingestable .ingest (mlBatchIngestionInput );
108
- handleSuccessRate (successRate , taskId );
109
- }, taskId );
110
- });
111
- } catch (Exception ex ) {
112
- log .error ("Failed in batch ingestion" , ex );
113
- mlTaskManager
114
- .updateMLTask (
115
- taskId ,
116
- Map .of (STATE_FIELD , FAILED , ERROR_FIELD , MLExceptionUtils .getRootCauseMessage (ex )),
117
- TASK_SEMAPHORE_TIMEOUT ,
118
- true
92
+
93
+ if (mlBatchIngestionInput .getConnectorId () != null && mlBatchIngestionInput .getCredential () == null ) {
94
+ mlModelManager .getConnectorCredential (mlBatchIngestionInput .getConnectorId (), ActionListener .wrap (credentialMap -> {
95
+ mlBatchIngestionInput .setCredential (credentialMap );
96
+ createMLTaskandExecute (mlBatchIngestionInput , listener );
97
+ }, e -> {
98
+ log .error (e .getMessage ());
99
+ listener
100
+ .onFailure (
101
+ new OpenSearchStatusException (
102
+ "Fail to fetch credentials from the connector in the batch ingestion input: " + e .getMessage (),
103
+ RestStatus .BAD_REQUEST
104
+ )
119
105
);
120
- listener .onFailure (ex );
121
- }
122
- }, exception -> {
123
- log .error ("Failed to create batch ingestion task" , exception );
124
- listener .onFailure (exception );
125
- }));
106
+ }));
107
+ }
108
+
109
+ createMLTaskandExecute (mlBatchIngestionInput , listener );
126
110
} catch (IllegalArgumentException e ) {
127
111
log .error (e .getMessage ());
128
112
listener
@@ -137,6 +121,47 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
137
121
}
138
122
}
139
123
124
+ protected void createMLTaskandExecute (MLBatchIngestionInput mlBatchIngestionInput , ActionListener <MLBatchIngestionResponse > listener ) {
125
+ MLTask mlTask = MLTask
126
+ .builder ()
127
+ .async (true )
128
+ .taskType (MLTaskType .BATCH_INGEST )
129
+ .createTime (Instant .now ())
130
+ .lastUpdateTime (Instant .now ())
131
+ .state (MLTaskState .CREATED )
132
+ .build ();
133
+
134
+ mlTaskManager .createMLTask (mlTask , ActionListener .wrap (response -> {
135
+ String taskId = response .getId ();
136
+ try {
137
+ mlTask .setTaskId (taskId );
138
+ mlTaskManager .add (mlTask );
139
+ listener .onResponse (new MLBatchIngestionResponse (taskId , MLTaskType .BATCH_INGEST , MLTaskState .CREATED .name ()));
140
+ String ingestType = (String ) mlBatchIngestionInput .getDataSources ().get (TYPE );
141
+ Ingestable ingestable = MLEngineClassLoader .initInstance (ingestType .toLowerCase (), client , Client .class );
142
+ threadPool .executor (INGEST_THREAD_POOL ).execute (() -> {
143
+ executeWithErrorHandling (() -> {
144
+ double successRate = ingestable .ingest (mlBatchIngestionInput );
145
+ handleSuccessRate (successRate , taskId );
146
+ }, taskId );
147
+ });
148
+ } catch (Exception ex ) {
149
+ log .error ("Failed in batch ingestion" , ex );
150
+ mlTaskManager
151
+ .updateMLTask (
152
+ taskId ,
153
+ Map .of (STATE_FIELD , FAILED , ERROR_FIELD , MLExceptionUtils .getRootCauseMessage (ex )),
154
+ TASK_SEMAPHORE_TIMEOUT ,
155
+ true
156
+ );
157
+ listener .onFailure (ex );
158
+ }
159
+ }, exception -> {
160
+ log .error ("Failed to create batch ingestion task" , exception );
161
+ listener .onFailure (exception );
162
+ }));
163
+ }
164
+
140
165
protected void executeWithErrorHandling (Runnable task , String taskId ) {
141
166
try {
142
167
task .run ();
@@ -189,6 +214,9 @@ private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInpu
189
214
|| mlBatchIngestionInput .getDataSources ().isEmpty ()) {
190
215
throw new IllegalArgumentException ("The batch ingest input data source cannot be null" );
191
216
}
217
+ if (mlBatchIngestionInput .getCredential () == null && mlBatchIngestionInput .getConnectorId () == null ) {
218
+ throw new IllegalArgumentException ("The batch ingest credential or connector_id cannot be null" );
219
+ }
192
220
Map <String , Object > dataSources = mlBatchIngestionInput .getDataSources ();
193
221
if (dataSources .get (TYPE ) == null || dataSources .get (SOURCE ) == null ) {
194
222
throw new IllegalArgumentException ("The batch ingest input data source is missing data type or source" );
0 commit comments