Skip to content

Commit 39ceaef

Browse files
committed
use connector credential in offline batch ingestion
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent ff6fe67 commit 39ceaef

File tree

8 files changed

+197
-52
lines changed

8 files changed

+197
-52
lines changed

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+2
Original file line number | Diff line number | Diff line change
@@ -184,4 +184,6 @@ default void validateConnectorURL(List<String> urlRegexes) {
184184
}
185185

186186
Map<String, String> getDecryptedHeaders();
187+
188+
Map<String, String> getDecryptedCredential();
187189
}

common/src/main/java/org/opensearch/ml/common/transport/batch/MLBatchIngestionInput.java

+18-7
Original file line number | Diff line number | Diff line change
@@ -20,36 +20,37 @@
2020

2121
import lombok.Builder;
2222
import lombok.Getter;
23+
import lombok.Setter;
2324

2425
/**
2526
* ML batch ingestion data: index, field mapping, and input and output files.
2627
*/
28+
@Getter
2729
public class MLBatchIngestionInput implements ToXContentObject, Writeable {
2830

2931
public static final String INDEX_NAME_FIELD = "index_name";
3032
public static final String FIELD_MAP_FIELD = "field_map";
3133
public static final String INGEST_FIELDS = "ingest_fields";
3234
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
3335
public static final String DATA_SOURCE_FIELD = "data_source";
36+
public static final String CONNECTOR_ID_FIELD = "connector_id";
3437

35-
@Getter
3638
private String indexName;
37-
@Getter
3839
private Map<String, Object> fieldMapping;
39-
@Getter
4040
private String[] ingestFields;
41-
@Getter
4241
private Map<String, Object> dataSources;
43-
@Getter
42+
@Setter
4443
private Map<String, String> credential;
44+
private String connectorId;
4545

4646
@Builder(toBuilder = true)
4747
public MLBatchIngestionInput(
4848
String indexName,
4949
Map<String, Object> fieldMapping,
5050
String[] ingestFields,
5151
Map<String, Object> dataSources,
52-
Map<String, String> credential
52+
Map<String, String> credential,
53+
String connectorId
5354
) {
5455
if (indexName == null) {
5556
throw new IllegalArgumentException(
@@ -66,6 +67,7 @@ public MLBatchIngestionInput(
6667
this.ingestFields = ingestFields;
6768
this.dataSources = dataSources;
6869
this.credential = credential;
70+
this.connectorId = connectorId;
6971
}
7072

7173
public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
@@ -74,6 +76,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
7476
String[] ingestFields = null;
7577
Map<String, Object> dataSources = null;
7678
Map<String, String> credential = new HashMap<>();
79+
String connectorId = null;
7780

7881
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7982
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -93,6 +96,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
9396
case CONNECTOR_CREDENTIAL_FIELD:
9497
credential = parser.mapStrings();
9598
break;
99+
case CONNECTOR_ID_FIELD:
100+
connectorId = parser.text();
101+
break;
96102
case DATA_SOURCE_FIELD:
97103
dataSources = parser.map();
98104
break;
@@ -101,7 +107,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
101107
break;
102108
}
103109
}
104-
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
110+
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential, connectorId);
105111
}
106112

107113
@Override
@@ -119,6 +125,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
119125
if (credential != null) {
120126
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
121127
}
128+
if (connectorId != null) {
129+
builder.field(CONNECTOR_ID_FIELD, connectorId);
130+
}
122131
if (dataSources != null) {
123132
builder.field(DATA_SOURCE_FIELD, dataSources);
124133
}
@@ -147,6 +156,7 @@ public void writeTo(StreamOutput output) throws IOException {
147156
} else {
148157
output.writeBoolean(false);
149158
}
159+
output.writeOptionalString(connectorId);
150160
if (dataSources != null) {
151161
output.writeBoolean(true);
152162
output.writeMap(dataSources, StreamOutput::writeString, StreamOutput::writeGenericValue);
@@ -166,6 +176,7 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
166176
if (input.readBoolean()) {
167177
credential = input.readMap(s -> s.readString(), s -> s.readString());
168178
}
179+
this.connectorId = input.readOptionalString();
169180
if (input.readBoolean()) {
170181
dataSources = input.readMap(s -> s.readString(), s -> s.readGenericValue());
171182
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+14
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,17 @@
55

66
package org.opensearch.ml.engine;
77

8+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
9+
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
10+
811
import java.nio.file.Path;
912
import java.util.Locale;
1013
import java.util.Map;
1114

1215
import org.opensearch.core.action.ActionListener;
1316
import org.opensearch.ml.common.FunctionName;
1417
import org.opensearch.ml.common.MLModel;
18+
import org.opensearch.ml.common.connector.Connector;
1519
import org.opensearch.ml.common.dataframe.DataFrame;
1620
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
1721
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -120,6 +124,16 @@ public MLModel train(Input input) {
120124
return trainable.train(mlInput);
121125
}
122126

127+
public Map<String, String> getConnectorCredential(Connector connector) {
128+
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
129+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
130+
String region = connector.getParameters().get(REGION_FIELD);
131+
if (region != null) {
132+
decryptedCredential.putIfAbsent(REGION_FIELD, region);
133+
}
134+
return decryptedCredential;
135+
}
136+
123137
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
124138
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
125139
predictable.initModel(mlModel, params, encryptor);

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

+35
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616
import java.io.IOException;
1717
import java.nio.file.Path;
1818
import java.util.Arrays;
19+
import java.util.Collections;
20+
import java.util.Map;
1921
import java.util.UUID;
2022

2123
import org.junit.Assert;
@@ -24,11 +26,17 @@
2426
import org.junit.Test;
2527
import org.junit.rules.ExpectedException;
2628
import org.mockito.MockedStatic;
29+
import org.opensearch.common.settings.Settings;
30+
import org.opensearch.common.xcontent.XContentType;
2731
import org.opensearch.core.action.ActionListener;
2832
import org.opensearch.core.common.io.stream.StreamOutput;
33+
import org.opensearch.core.xcontent.NamedXContentRegistry;
2934
import org.opensearch.core.xcontent.XContentBuilder;
35+
import org.opensearch.core.xcontent.XContentParser;
3036
import org.opensearch.ml.common.FunctionName;
3137
import org.opensearch.ml.common.MLModel;
38+
import org.opensearch.ml.common.connector.Connector;
39+
import org.opensearch.ml.common.connector.HttpConnector;
3240
import org.opensearch.ml.common.dataframe.ColumnMeta;
3341
import org.opensearch.ml.common.dataframe.DataFrame;
3442
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
@@ -47,6 +55,7 @@
4755
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
4856
import org.opensearch.ml.engine.encryptor.Encryptor;
4957
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
58+
import org.opensearch.search.SearchModule;
5059

5160
// TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic
5261
public class MLEngineTest extends MLStaticMockBase {
@@ -408,4 +417,30 @@ public void testEncryptMethod() {
408417
assertNotEquals(testString, encryptedString);
409418
}
410419

420+
@Test
421+
public void testGetConnectorCredential() throws IOException {
422+
String encryptedValue = mlEngine.encrypt("test_key_value");
423+
String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\","
424+
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
425+
+ "\"parameters\":{\"region\":\"test region\"},\"credential\":{\"key\":\"" + encryptedValue + "\"},"
426+
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
427+
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
428+
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}],"
429+
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";
430+
431+
XContentParser parser = XContentType.JSON
432+
.xContent()
433+
.createParser(
434+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
435+
null,
436+
test_connector_string
437+
);
438+
parser.nextToken();
439+
440+
HttpConnector connector = new HttpConnector("http", parser);
441+
Map<String, String> decryptedCredential = mlEngine.getConnectorCredential(connector);
442+
assertNotNull(decryptedCredential);
443+
assertEquals(decryptedCredential.get("key"), "test_key_value");
444+
assertEquals(decryptedCredential.get("region"), "test region");
445+
}
411446
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/ingest/AbstractIngestionTests.java

+10-5
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,8 @@ public void testFilterFieldMapping_ValidInput_EmptyPrefix() {
166166
fieldMap,
167167
ingestFields,
168168
new HashMap<>(),
169-
new HashMap<>()
169+
new HashMap<>(),
170+
null
170171
);
171172
Map<String, Object> result = s3DataIngestion.filterFieldMapping(mlBatchIngestionInput, 0);
172173

@@ -190,7 +191,8 @@ public void testFilterFieldMapping_MatchingPrefix() {
190191
fieldMap,
191192
ingestFields,
192193
new HashMap<>(),
193-
new HashMap<>()
194+
new HashMap<>(),
195+
null
194196
);
195197

196198
// Act
@@ -219,7 +221,8 @@ public void testFilterFieldMappingSoleSource_MatchingPrefix() {
219221
fieldMap,
220222
ingestFields,
221223
new HashMap<>(),
222-
new HashMap<>()
224+
new HashMap<>(),
225+
null
223226
);
224227

225228
// Act
@@ -292,7 +295,8 @@ public void testBatchIngestSuccess_SoleSource() {
292295
fieldMap,
293296
ingestFields,
294297
new HashMap<>(),
295-
new HashMap<>()
298+
new HashMap<>(),
299+
null
296300
);
297301
ActionListener<BulkResponse> bulkResponseListener = mock(ActionListener.class);
298302
s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, true);
@@ -318,7 +322,8 @@ public void testBatchIngestSuccess_returnForNullJasonMap() {
318322
fieldMap,
319323
ingestFields,
320324
new HashMap<>(),
321-
new HashMap<>()
325+
new HashMap<>(),
326+
null
322327
);
323328
ActionListener<BulkResponse> bulkResponseListener = mock(ActionListener.class);
324329
s3DataIngestion.batchIngest(sourceLines, mlBatchIngestionInput, bulkResponseListener, 0, false);

0 commit comments

Comments
 (0)