Skip to content

Commit b12a536

Browse files
committed
use connector credential in offline batch ingestion
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent d7e0fe4 commit b12a536

File tree

6 files changed

+112
-44
lines changed

6 files changed

+112
-44
lines changed

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+2
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,6 @@ default void validateConnectorURL(List<String> urlRegexes) {
179179
}
180180

181181
/**
 * Returns the connector's decrypted headers as a name-to-value map.
 * NOTE(review): presumably only populated after the connector has been decrypted — confirm against implementations.
 */
Map<String, String> getDecryptedHeaders();
182+
183+
/**
 * Returns the connector's decrypted credential as a key-to-value map.
 * NOTE(review): callers in this change invoke {@code decrypt(...)} on the connector before reading this
 * value; presumably the map is not available before that call — confirm against implementations.
 */
Map<String, String> getDecryptedCredential();
182184
}

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+18-7
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,37 @@
2020

2121
import lombok.Builder;
2222
import lombok.Getter;
23+
import lombok.Setter;
2324

2425
/**
2526
* ML batch ingestion data: index, field mapping, and input and output files.
2627
*/
28+
@Getter
2729
public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2830

2931
public static final String INDEX_NAME_FIELD = "index_name";
3032
public static final String FIELD_MAP_FIELD = "field_map";
3133
public static final String INGEST_FIELDS = "ingest_fields";
3234
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
3335
public static final String DATA_SOURCE_FIELD = "data_source";
36+
public static final String CONNECTOR_ID_FIELD = "connector_id";
3437

35-
@Getter
3638
private String indexName;
37-
@Getter
3839
private Map<String, Object> fieldMapping;
39-
@Getter
4040
private String[] ingestFields;
41-
@Getter
4241
private Map<String, Object> dataSources;
43-
@Getter
42+
@Setter
4443
private Map<String, String> credential;
44+
private String connectorId;
4545

4646
@Builder(toBuilder = true)
4747
public MLBatchIngestionInput(
4848
String indexName,
4949
Map<String, Object> fieldMapping,
5050
String[] ingestFields,
5151
Map<String, Object> dataSources,
52-
Map<String, String> credential
52+
Map<String, String> credential,
53+
String connectorId
5354
) {
5455
if (indexName == null) {
5556
throw new IllegalArgumentException(
@@ -66,6 +67,7 @@ public MLBatchIngestionInput(
6667
this.ingestFields = ingestFields;
6768
this.dataSources = dataSources;
6869
this.credential = credential;
70+
this.connectorId = connectorId;
6971
}
7072

7173
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
@@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
7476
String[] ingestFields = null;
7577
Map<String, Object> dataSources = null;
7678
Map<String, String> credential = new HashMap<>();
79+
String connectorId = null;
7780

7881
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7982
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
9396
case CONNECTOR_CREDENTIAL_FIELD:
9497
credential = parser.mapStrings();
9598
break;
99+
case CONNECTOR_ID_FIELD:
100+
connectorId = parser.text();
101+
break;
96102
case DATA_SOURCE_FIELD:
97103
dataSources = parser.map();
98104
break;
@@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
101107
break;
102108
}
103109
}
104-
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
110+
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId);
105111
}
106112

107113
@Override
@@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
119125
if (credential != null) {
120126
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
121127
}
128+
if (connectorId != null) {
129+
builder.field(CONNECTOR_ID_FIELD, connectorId);
130+
}
122131
if (dataSources != null) {
123132
builder.field(DATA_SOURCE_FIELD, dataSources);
124133
}
@@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException {
147156
} else {
148157
output.writeBoolean(false);
149158
}
159+
output.writeOptionalString(connectorId);
150160
if (dataSources != null) {
151161
output.writeBoolean(true);
152162
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
@@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
166176
if (input.readBoolean()) {
167177
credential = input.readMap(s -> s.readString(), s -> s.readString());
168178
}
179+
this.connectorId = input.readOptionalString();
169180
if (input.readBoolean()) {
170181
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
171182
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+15
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55

66
package org.opensearch.ml.engine;
77

8+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9+
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
10+
811
import java.nio.file.Path;
912
import java.util.Locale;
1013
import java.util.Map;
1114

1215
import org.opensearch.core.action.ActionListener;
1316
import org.opensearch.ml.common.FunctionName;
1417
import org.opensearch.ml.common.MLModel;
18+
import org.opensearch.ml.common.connector.Connector;
1519
import org.opensearch.ml.common.dataframe.DataFrame;
1620
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
1721
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -120,6 +124,17 @@ public MLModel train(Input input) {
120124
return trainable.train(mlInput);
121125
}
122126

127+
public Map<String, String> getConnectorCredential(Connector connector) {
128+
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
129+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
130+
String region = connector.getParameters().get(REGION_FIELD);
131+
if (region != null) {
132+
decryptedCredential.putIfAbsent(REGION_FIELD, region);
133+
}
134+
135+
return decryptedCredential;
136+
}
137+
123138
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
124139
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
125140
predictable.initModel(mlModel, params, encryptor);

plugin/src/main/java/org/opensearch/ml/action/batch/TransportBatchIngestionAction.java

+65-37
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
3737
import org.opensearch.ml.engine.MLEngineClassLoader;
3838
import org.opensearch.ml.engine.ingest.Ingestable;
39+
import org.opensearch.ml.model.MLModelManager;
3940
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
4041
import org.opensearch.ml.task.MLTaskManager;
4142
import org.opensearch.ml.utils.MLExceptionUtils;
@@ -55,6 +56,7 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
5556
public static final String SOURCE = "source";
5657
TransportService transportService;
5758
MLTaskManager mlTaskManager;
59+
MLModelManager mlModelManager;
5860
private final Client client;
5961
private ThreadPool threadPool;
6062
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
@@ -66,13 +68,15 @@ public TransportBatchIngestionAction(
6668
Client client,
6769
MLTaskManager mlTaskManager,
6870
ThreadPool threadPool,
71+
MLModelManager mlModelManager,
6972
MLFeatureEnabledSetting mlFeatureEnabledSetting
7073
) {
7174
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
7275
this.transportService = transportService;
7376
this.client = client;
7477
this.mlTaskManager = mlTaskManager;
7578
this.threadPool = threadPool;
79+
this.mlModelManager = mlModelManager;
7680
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
7781
}
7882

@@ -85,44 +89,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
8589
throw new IllegalStateException(OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
8690
}
8791
validateBatchIngestInput(mlBatchIngestionInput);
88-
MLTask mlTask = MLTask
89-
.builder()
90-
.async(true)
91-
.taskType(MLTaskType.BATCH_INGEST)
92-
.createTime(Instant.now())
93-
.lastUpdateTime(Instant.now())
94-
.state(MLTaskState.CREATED)
95-
.build();
96-
97-
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
98-
String taskId = response.getId();
99-
try {
100-
mlTask.setTaskId(taskId);
101-
mlTaskManager.add(mlTask);
102-
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
103-
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
104-
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
105-
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
106-
executeWithErrorHandling(() -> {
107-
double successRate = ingestable.ingest(mlBatchIngestionInput);
108-
handleSuccessRate(successRate, taskId);
109-
}, taskId);
110-
});
111-
} catch (Exception ex) {
112-
log.error("Failed in batch ingestion", ex);
113-
mlTaskManager
114-
.updateMLTask(
115-
taskId,
116-
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
117-
TASK_SEMAPHORE_TIMEOUT,
118-
true
92+
93+
if (mlBatchIngestionInput.getConnectorId() != null && mlBatchIngestionInput.getCredential() == null) {
94+
mlModelManager.getConnectorCredential(mlBatchIngestionInput.getConnectorId(), ActionListener.wrap(credentialMap -> {
95+
mlBatchIngestionInput.setCredential(credentialMap);
96+
createMLTaskandExecute(mlBatchIngestionInput, listener);
97+
}, e -> {
98+
log.error(e.getMessage());
99+
listener
100+
.onFailure(
101+
new OpenSearchStatusException(
102+
"Fail to fetch credentials from the connector in the batch ingestion input: " + e.getMessage(),
103+
RestStatus.BAD_REQUEST
104+
)
119105
);
120-
listener.onFailure(ex);
121-
}
122-
}, exception -> {
123-
log.error("Failed to create batch ingestion task", exception);
124-
listener.onFailure(exception);
125-
}));
106+
}));
107+
}
108+
109+
createMLTaskandExecute(mlBatchIngestionInput, listener);
126110
} catch (IllegalArgumentException e) {
127111
log.error(e.getMessage());
128112
listener
@@ -137,6 +121,47 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatc
137121
}
138122
}
139123

124+
protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInput, ActionListener<MLBatchIngestionResponse> listener) {
125+
MLTask mlTask = MLTask
126+
.builder()
127+
.async(true)
128+
.taskType(MLTaskType.BATCH_INGEST)
129+
.createTime(Instant.now())
130+
.lastUpdateTime(Instant.now())
131+
.state(MLTaskState.CREATED)
132+
.build();
133+
134+
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
135+
String taskId = response.getId();
136+
try {
137+
mlTask.setTaskId(taskId);
138+
mlTaskManager.add(mlTask);
139+
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
140+
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
141+
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
142+
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
143+
executeWithErrorHandling(() -> {
144+
double successRate = ingestable.ingest(mlBatchIngestionInput);
145+
handleSuccessRate(successRate, taskId);
146+
}, taskId);
147+
});
148+
} catch (Exception ex) {
149+
log.error("Failed in batch ingestion", ex);
150+
mlTaskManager
151+
.updateMLTask(
152+
taskId,
153+
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
154+
TASK_SEMAPHORE_TIMEOUT,
155+
true
156+
);
157+
listener.onFailure(ex);
158+
}
159+
}, exception -> {
160+
log.error("Failed to create batch ingestion task", exception);
161+
listener.onFailure(exception);
162+
}));
163+
}
164+
140165
protected void executeWithErrorHandling(Runnable task, String taskId) {
141166
try {
142167
task.run();
@@ -189,6 +214,9 @@ private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInpu
189214
|| mlBatchIngestionInput.getDataSources().isEmpty()) {
190215
throw new IllegalArgumentException("The batch ingest input data source cannot be null");
191216
}
217+
if (mlBatchIngestionInput.getCredential() == null && mlBatchIngestionInput.getConnectorId() == null) {
218+
throw new IllegalArgumentException("The batch ingest credential or connector_id cannot be null");
219+
}
192220
Map<String, Object> dataSources = mlBatchIngestionInput.getDataSources();
193221
if (dataSources.get(TYPE) == null || dataSources.get(SOURCE) == null) {
194222
throw new IllegalArgumentException("The batch ingest input data source is missing data type or source");

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+8
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,14 @@ private void handleException(FunctionName functionName, String taskId, Exception
965965
mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
966966
}
967967

968+
public void getConnectorCredential(String connectorId, ActionListener<Map<String, String>> connectorCredentialListener) {
969+
getConnector(connectorId, ActionListener.wrap(connector -> {
970+
Map<String, String> credential = mlEngine.getConnectorCredential(connector);
971+
connectorCredentialListener.onResponse(credential);
972+
log.info("Completed loading credential in the connector {}", connectorId);
973+
}, connectorCredentialListener::onFailure));
974+
}
975+
968976
/**
969977
* Read model chunks from model index. Concat chunks into a whole model file,
970978
* then load

plugin/src/test/java/org/opensearch/ml/action/batch/TransportBatchIngestionActionTests.java

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
4747
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
4848
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
49+
import org.opensearch.ml.model.MLModelManager;
4950
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
5051
import org.opensearch.ml.task.MLTaskManager;
5152
import org.opensearch.tasks.Task;
@@ -63,6 +64,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
6364
@Mock
6465
private MLTaskManager mlTaskManager;
6566
@Mock
67+
MLModelManager mlModelManager;
68+
@Mock
6669
private ActionFilters actionFilters;
6770
@Mock
6871
private MLBatchIngestionRequest mlBatchIngestionRequest;
@@ -90,6 +93,7 @@ public void setup() {
9093
client,
9194
mlTaskManager,
9295
threadPool,
96+
mlModelManager,
9397
mlFeatureEnabledSetting
9498
);
9599

0 commit comments

Comments
 (0)