Skip to content

Commit ebf2b75

Browse files
committed
use connector credential in offline batch ingestion
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent ff6fe67 commit ebf2b75

File tree

6 files changed

+105
-45
lines changed

6 files changed

+105
-45
lines changed

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+2
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,6 @@ default void validateConnectorURL(List<String> urlRegexes) {
184184
}
185185

186186
Map<String, String> getDecryptedHeaders();
187+
188+
Map<String, String> getDecryptedCredential();
187189
}

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+18-7
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,37 @@
2020

2121
import lombok.Builder;
2222
import lombok.Getter;
23+
import lombok.Setter;
2324

2425
/**
2526
* ML batch ingestion data: index, field mapping and input and out files.
2627
*/
28+
@Getter
2729
public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2830

2931
public static final String INDEX_NAME_FIELD = "index_name";
3032
public static final String FIELD_MAP_FIELD = "field_map";
3133
public static final String INGEST_FIELDS = "ingest_fields";
3234
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
3335
public static final String DATA_SOURCE_FIELD = "data_source";
36+
public static final String CONNECTOR_ID_FIELD = "connector_id";
3437

35-
@Getter
3638
private String indexName;
37-
@Getter
3839
private Map<String, Object> fieldMapping;
39-
@Getter
4040
private String[] ingestFields;
41-
@Getter
4241
private Map<String, Object> dataSources;
43-
@Getter
42+
@Setter
4443
private Map<String, String> credential;
44+
private String connectorId;
4545

4646
@Builder(toBuilder = true)
4747
public MLBatchIngestionInput(
4848
String indexName,
4949
Map<String, Object> fieldMapping,
5050
String[] ingestFields,
5151
Map<String, Object> dataSources,
52-
Map<String, String> credential
52+
Map<String, String> credential,
53+
String connectorId
5354
) {
5455
if (indexName == null) {
5556
throw new IllegalArgumentException(
@@ -66,6 +67,7 @@ public MLBatchIngestionInput(
6667
this.ingestFields = ingestFields;
6768
this.dataSources = dataSources;
6869
this.credential = credential;
70+
this.connectorId = connectorId;
6971
}
7072

7173
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
@@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
7476
String[] ingestFields = null;
7577
Map<String, Object> dataSources = null;
7678
Map<String, String> credential = new HashMap<>();
79+
String connectorId = null;
7780

7881
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7982
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
9396
case CONNECTOR_CREDENTIAL_FIELD:
9497
credential = parser.mapStrings();
9598
break;
99+
case CONNECTOR_ID_FIELD:
100+
connectorId = parser.text();
101+
break;
96102
case DATA_SOURCE_FIELD:
97103
dataSources = parser.map();
98104
break;
@@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
101107
break;
102108
}
103109
}
104-
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
110+
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId);
105111
}
106112

107113
@Override
@@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
119125
if (credential != null) {
120126
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
121127
}
128+
if (connectorId != null) {
129+
builder.field(CONNECTOR_ID_FIELD, connectorId);
130+
}
122131
if (dataSources != null) {
123132
builder.field(DATA_SOURCE_FIELD, dataSources);
124133
}
@@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException {
147156
} else {
148157
output.writeBoolean(false);
149158
}
159+
output.writeOptionalString(connectorId);
150160
if (dataSources != null) {
151161
output.writeBoolean(true);
152162
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
@@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
166176
if (input.readBoolean()) {
167177
credential = input.readMap(s -> s.readString(), s -> s.readString());
168178
}
179+
this.connectorId = input.readOptionalString();
169180
if (input.readBoolean()) {
170181
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
171182
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+8
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55

66
package org.opensearch.ml.engine;
77

8+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9+
810
import java.nio.file.Path;
911
import java.util.Locale;
1012
import java.util.Map;
1113

1214
import org.opensearch.core.action.ActionListener;
1315
import org.opensearch.ml.common.FunctionName;
1416
import org.opensearch.ml.common.MLModel;
17+
import org.opensearch.ml.common.connector.Connector;
1518
import org.opensearch.ml.common.dataframe.DataFrame;
1619
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
1720
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -120,6 +123,11 @@ public MLModel train(Input input) {
120123
return trainable.train(mlInput);
121124
}
122125

126+
public Map<String, String> getConnectorCredential(Connector connector) {
127+
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
128+
return connector.getDecryptedCredential();
129+
}
130+
123131
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
124132
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
125133
predictable.initModel(mlModel, params, encryptor);

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+65-38
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import java.time.Instant;
1717
import java.util.List;
18-
import java.util.Locale;
1918
import java.util.Map;
2019
import java.util.regex.Pattern;
2120
import java.util.stream.Collectors;
@@ -37,6 +36,7 @@
3736
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
3837
import org.opensearch.ml.engine.MLEngineClassLoader;
3938
import org.opensearch.ml.engine.ingest.Ingestable;
39+
import org.opensearch.ml.model.MLModelManager;
4040
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
4141
import org.opensearch.ml.task.MLTaskManager;
4242
import org.opensearch.ml.utils.MLExceptionUtils;
@@ -56,6 +56,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
5656
public static final String SOURCE = "source";
5757
TransportService transportService;
5858
MLTaskManager mlTaskManager;
59+
MLModelManager mlModelManager;
5960
private final Client client;
6061
private ThreadPool threadPool;
6162
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
@@ -67,13 +68,15 @@ public TransportBatchIngestionAction(
6768
Client client,
6869
MLTaskManager mlTaskManager,
6970
ThreadPool threadPool,
71+
MLModelManager mlModelManager,
7072
MLFeatureEnabledSetting mlFeatureEnabledSetting
7173
) {
7274
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
7375
this.transportService = transportService;
7476
this.client = client;
7577
this.mlTaskManager = mlTaskManager;
7678
this.threadPool = threadPool;
79+
this.mlModelManager = mlModelManager;
7780
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
7881
}
7982

@@ -86,44 +89,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
8689
throw new IllegalStateException(OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
8790
}
8891
validateBatchIngestInput(mlBatchIngestionInput);
89-
MLTask mlTask = MLTask
90-
.builder()
91-
.async(true)
92-
.taskType(MLTaskType.BATCH_INGEST)
93-
.createTime(Instant.now())
94-
.lastUpdateTime(Instant.now())
95-
.state(MLTaskState.CREATED)
96-
.build();
97-
98-
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
99-
String taskId = response.getId();
100-
try {
101-
mlTask.setTaskId(taskId);
102-
mlTaskManager.add(mlTask);
103-
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
104-
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
105-
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(Locale.ROOT), client, Client.class);
106-
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
107-
executeWithErrorHandling(() -> {
108-
double successRate = ingestable.ingest(mlBatchIngestionInput);
109-
handleSuccessRate(successRate, taskId);
110-
}, taskId);
111-
});
112-
} catch (Exception ex) {
113-
log.error("Failed in batch ingestion", ex);
114-
mlTaskManager
115-
.updateMLTask(
116-
taskId,
117-
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
118-
TASK_SEMAPHORE_TIMEOUT,
119-
true
92+
if (mlBatchIngestionInput.getConnectorId() != null
93+
&& (mlBatchIngestionInput.getCredential() == null || mlBatchIngestionInput.getCredential().isEmpty())) {
94+
mlModelManager.getConnectorCredential(mlBatchIngestionInput.getConnectorId(), ActionListener.wrap(credentialMap -> {
95+
mlBatchIngestionInput.setCredential(credentialMap);
96+
createMLTaskandExecute(mlBatchIngestionInput, listener);
97+
}, e -> {
98+
log.error(e.getMessage());
99+
listener
100+
.onFailure(
101+
new OpenSearchStatusException(
102+
"Fail to fetch credentials from the connector in the batch ingestion input: " + e.getMessage(),
103+
RestStatus.BAD_REQUEST
104+
)
120105
);
121-
listener.onFailure(ex);
122-
}
123-
}, exception -> {
124-
log.error("Failed to create batch ingestion task", exception);
125-
listener.onFailure(exception);
126-
}));
106+
}));
107+
} else {
108+
createMLTaskandExecute(mlBatchIngestionInput, listener);
109+
}
127110
} catch (IllegalArgumentException e) {
128111
log.error(e.getMessage());
129112
listener
@@ -138,6 +121,47 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
138121
}
139122
}
140123

124+
protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInput, ActionListener<MLBatchIngestionResponse> listener) {
125+
MLTask mlTask = MLTask
126+
.builder()
127+
.async(true)
128+
.taskType(MLTaskType.BATCH_INGEST)
129+
.createTime(Instant.now())
130+
.lastUpdateTime(Instant.now())
131+
.state(MLTaskState.CREATED)
132+
.build();
133+
134+
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
135+
String taskId = response.getId();
136+
try {
137+
mlTask.setTaskId(taskId);
138+
mlTaskManager.add(mlTask);
139+
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
140+
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
141+
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
142+
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
143+
executeWithErrorHandling(() -> {
144+
double successRate = ingestable.ingest(mlBatchIngestionInput);
145+
handleSuccessRate(successRate, taskId);
146+
}, taskId);
147+
});
148+
} catch (Exception ex) {
149+
log.error("Failed in batch ingestion", ex);
150+
mlTaskManager
151+
.updateMLTask(
152+
taskId,
153+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
154+
TASK_SEMAPHORE_TIMEOUT,
155+
true
156+
);
157+
listener.onFailure(ex);
158+
}
159+
}, exception -> {
160+
log.error("Failed to create batch ingestion task", exception);
161+
listener.onFailure(exception);
162+
}));
163+
}
164+
141165
protected void executeWithErrorHandling(Runnable task, String taskId) {
142166
try {
143167
task.run();
@@ -190,6 +214,9 @@ private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInpu
190214
|| mlBatchIngestionInput.getDataSources().isEmpty()) {
191215
throw new IllegalArgumentException("The batch ingest input data source cannot be null");
192216
}
217+
if (mlBatchIngestionInput.getCredential() == null && mlBatchIngestionInput.getConnectorId() == null) {
218+
throw new IllegalArgumentException("The batch ingest credential or connector_id cannot be null");
219+
}
193220
Map<String, Object> dataSources = mlBatchIngestionInput.getDataSources();
194221
if (dataSources.get(TYPE) == null || dataSources.get(SOURCE) == null) {
195222
throw new IllegalArgumentException("The batch ingest input data source is missing data type or source");

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+8
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,14 @@ private void handleException(FunctionName functionName, String taskId, Exception
969969
mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
970970
}
971971

972+
public void getConnectorCredential(String connectorId, ActionListener<Map<String, String>> connectorCredentialListener) {
973+
getConnector(connectorId, ActionListener.wrap(connector -> {
974+
Map<String, String> credential = mlEngine.getConnectorCredential(connector);
975+
connectorCredentialListener.onResponse(credential);
976+
log.info("Completed loading credential in the connector {}", connectorId);
977+
}, connectorCredentialListener::onFailure));
978+
}
979+
972980
/**
973981
* Read model chunks from model index. Concat chunks into a whole model file,
974982
* then load

plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
4747
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
4848
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
49+
import org.opensearch.ml.model.MLModelManager;
4950
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
5051
import org.opensearch.ml.task.MLTaskManager;
5152
import org.opensearch.tasks.Task;
@@ -63,6 +64,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
6364
@Mock
6465
private MLTaskManager mlTaskManager;
6566
@Mock
67+
MLModelManager mlModelManager;
68+
@Mock
6669
private ActionFilters actionFilters;
6770
@Mock
6871
private MLBatchIngestionRequest mlBatchIngestionRequest;
@@ -90,6 +93,7 @@ public void setup() {
9093
client,
9194
mlTaskManager,
9295
threadPool,
96+
mlModelManager,
9397
mlFeatureEnabledSetting
9498
);
9599

0 commit comments

Comments
 (0)