15
15
16
16
import java .time .Instant ;
17
17
import java .util .List ;
18
- import java .util .Locale ;
19
18
import java .util .Map ;
20
19
import java .util .regex .Pattern ;
21
20
import java .util .stream .Collectors ;
37
36
import org .opensearch .ml .common .transport .batch .MLBatchIngestionResponse ;
38
37
import org .opensearch .ml .engine .MLEngineClassLoader ;
39
38
import org .opensearch .ml .engine .ingest .Ingestable ;
39
+ import org .opensearch .ml .model .MLModelManager ;
40
40
import org .opensearch .ml .settings .MLFeatureEnabledSetting ;
41
41
import org .opensearch .ml .task .MLTaskManager ;
42
42
import org .opensearch .ml .utils .MLExceptionUtils ;
@@ -56,6 +56,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
56
56
public static final String SOURCE = "source" ;
57
57
TransportService transportService ;
58
58
MLTaskManager mlTaskManager ;
59
+ MLModelManager mlModelManager ;
59
60
private final Client client ;
60
61
private ThreadPool threadPool ;
61
62
private MLFeatureEnabledSetting mlFeatureEnabledSetting ;
@@ -67,13 +68,15 @@ public TransportBatchIngestionAction(
67
68
Client client ,
68
69
MLTaskManager mlTaskManager ,
69
70
ThreadPool threadPool ,
71
+ MLModelManager mlModelManager ,
70
72
MLFeatureEnabledSetting mlFeatureEnabledSetting
71
73
) {
72
74
super (MLBatchIngestionAction .NAME , transportService , actionFilters , MLBatchIngestionRequest ::new );
73
75
this .transportService = transportService ;
74
76
this .client = client ;
75
77
this .mlTaskManager = mlTaskManager ;
76
78
this .threadPool = threadPool ;
79
+ this .mlModelManager = mlModelManager ;
77
80
this .mlFeatureEnabledSetting = mlFeatureEnabledSetting ;
78
81
}
79
82
@@ -86,44 +89,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
86
89
throw new IllegalStateException (OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG );
87
90
}
88
91
validateBatchIngestInput (mlBatchIngestionInput );
89
- MLTask mlTask = MLTask
90
- .builder ()
91
- .async (true )
92
- .taskType (MLTaskType .BATCH_INGEST )
93
- .createTime (Instant .now ())
94
- .lastUpdateTime (Instant .now ())
95
- .state (MLTaskState .CREATED )
96
- .build ();
97
-
98
- mlTaskManager .createMLTask (mlTask , ActionListener .wrap (response -> {
99
- String taskId = response .getId ();
100
- try {
101
- mlTask .setTaskId (taskId );
102
- mlTaskManager .add (mlTask );
103
- listener .onResponse (new MLBatchIngestionResponse (taskId , MLTaskType .BATCH_INGEST , MLTaskState .CREATED .name ()));
104
- String ingestType = (String ) mlBatchIngestionInput .getDataSources ().get (TYPE );
105
- Ingestable ingestable = MLEngineClassLoader .initInstance (ingestType .toLowerCase (Locale .ROOT ), client , Client .class );
106
- threadPool .executor (INGEST_THREAD_POOL ).execute (() -> {
107
- executeWithErrorHandling (() -> {
108
- double successRate = ingestable .ingest (mlBatchIngestionInput );
109
- handleSuccessRate (successRate , taskId );
110
- }, taskId );
111
- });
112
- } catch (Exception ex ) {
113
- log .error ("Failed in batch ingestion" , ex );
114
- mlTaskManager
115
- .updateMLTask (
116
- taskId ,
117
- Map .of (STATE_FIELD , FAILED , ERROR_FIELD , MLExceptionUtils .getRootCauseMessage (ex )),
118
- TASK_SEMAPHORE_TIMEOUT ,
119
- true
92
+ if (mlBatchIngestionInput .getConnectorId () != null
93
+ && (mlBatchIngestionInput .getCredential () == null || mlBatchIngestionInput .getCredential ().isEmpty ())) {
94
+ mlModelManager .getConnectorCredential (mlBatchIngestionInput .getConnectorId (), ActionListener .wrap (credentialMap -> {
95
+ mlBatchIngestionInput .setCredential (credentialMap );
96
+ createMLTaskandExecute (mlBatchIngestionInput , listener );
97
+ }, e -> {
98
+ log .error (e .getMessage ());
99
+ listener
100
+ .onFailure (
101
+ new OpenSearchStatusException (
102
+ "Fail to fetch credentials from the connector in the batch ingestion input: " + e .getMessage (),
103
+ RestStatus .BAD_REQUEST
104
+ )
120
105
);
121
- listener .onFailure (ex );
122
- }
123
- }, exception -> {
124
- log .error ("Failed to create batch ingestion task" , exception );
125
- listener .onFailure (exception );
126
- }));
106
+ }));
107
+ } else {
108
+ createMLTaskandExecute (mlBatchIngestionInput , listener );
109
+ }
127
110
} catch (IllegalArgumentException e ) {
128
111
log .error (e .getMessage ());
129
112
listener
@@ -138,6 +121,47 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
138
121
}
139
122
}
140
123
124
+ protected void createMLTaskandExecute (MLBatchIngestionInput mlBatchIngestionInput , ActionListener <MLBatchIngestionResponse > listener ) {
125
+ MLTask mlTask = MLTask
126
+ .builder ()
127
+ .async (true )
128
+ .taskType (MLTaskType .BATCH_INGEST )
129
+ .createTime (Instant .now ())
130
+ .lastUpdateTime (Instant .now ())
131
+ .state (MLTaskState .CREATED )
132
+ .build ();
133
+
134
+ mlTaskManager .createMLTask (mlTask , ActionListener .wrap (response -> {
135
+ String taskId = response .getId ();
136
+ try {
137
+ mlTask .setTaskId (taskId );
138
+ mlTaskManager .add (mlTask );
139
+ listener .onResponse (new MLBatchIngestionResponse (taskId , MLTaskType .BATCH_INGEST , MLTaskState .CREATED .name ()));
140
+ String ingestType = (String ) mlBatchIngestionInput .getDataSources ().get (TYPE );
141
+ Ingestable ingestable = MLEngineClassLoader .initInstance (ingestType .toLowerCase (), client , Client .class );
142
+ threadPool .executor (INGEST_THREAD_POOL ).execute (() -> {
143
+ executeWithErrorHandling (() -> {
144
+ double successRate = ingestable .ingest (mlBatchIngestionInput );
145
+ handleSuccessRate (successRate , taskId );
146
+ }, taskId );
147
+ });
148
+ } catch (Exception ex ) {
149
+ log .error ("Failed in batch ingestion" , ex );
150
+ mlTaskManager
151
+ .updateMLTask (
152
+ taskId ,
153
+ Map .of (STATE_FIELD , FAILED , ERROR_FIELD , MLExceptionUtils .getRootCauseMessage (ex )),
154
+ TASK_SEMAPHORE_TIMEOUT ,
155
+ true
156
+ );
157
+ listener .onFailure (ex );
158
+ }
159
+ }, exception -> {
160
+ log .error ("Failed to create batch ingestion task" , exception );
161
+ listener .onFailure (exception );
162
+ }));
163
+ }
164
+
141
165
protected void executeWithErrorHandling (Runnable task , String taskId ) {
142
166
try {
143
167
task .run ();
@@ -190,6 +214,9 @@ private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInpu
190
214
|| mlBatchIngestionInput .getDataSources ().isEmpty ()) {
191
215
throw new IllegalArgumentException ("The batch ingest input data source cannot be null" );
192
216
}
217
+ if (mlBatchIngestionInput .getCredential () == null && mlBatchIngestionInput .getConnectorId () == null ) {
218
+ throw new IllegalArgumentException ("The batch ingest credential or connector_id cannot be null" );
219
+ }
193
220
Map <String , Object > dataSources = mlBatchIngestionInput .getDataSources ();
194
221
if (dataSources .get (TYPE ) == null || dataSources .get (SOURCE ) == null ) {
195
222
throw new IllegalArgumentException ("The batch ingest input data source is missing data type or source" );
0 commit comments