Skip to content

Commit 4f70041

Browse files
committed
adding multi-tenancy to config api and master key related changes
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent af96fe0 commit 4f70041

File tree

50 files changed

+448
-265
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+448
-265
lines changed

common/build.gradle

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ dependencies {
3939
compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
4040
compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
4141
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
42+
// Multi-tenant SDK Client
43+
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
4244
}
4345

4446
lombok {

common/src/main/java/org/opensearch/ml/common/MLConfig.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package org.opensearch.ml.common;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
911

1012
import java.io.IOException;
1113
import java.time.Instant;
@@ -57,6 +59,7 @@ public class MLConfig implements ToXContentObject, Writeable {
5759
private final Instant createTime;
5860
private Instant lastUpdateTime;
5961
private Instant lastUpdatedTime;
62+
private final String tenantId;
6063

6164
@Builder(toBuilder = true)
6265
public MLConfig(
@@ -66,7 +69,8 @@ public MLConfig(
6669
Configuration mlConfiguration,
6770
Instant createTime,
6871
Instant lastUpdateTime,
69-
Instant lastUpdatedTime
72+
Instant lastUpdatedTime,
73+
String tenantId
7074
) {
7175
this.type = type;
7276
this.configType = configType;
@@ -75,6 +79,7 @@ public MLConfig(
7579
this.createTime = createTime;
7680
this.lastUpdateTime = lastUpdateTime;
7781
this.lastUpdatedTime = lastUpdatedTime;
82+
this.tenantId = tenantId;
7883
}
7984

8085
public MLConfig(StreamInput input) throws IOException {
@@ -92,6 +97,7 @@ public MLConfig(StreamInput input) throws IOException {
9297
}
9398
lastUpdatedTime = input.readOptionalInstant();
9499
}
100+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
95101
}
96102

97103
@Override
@@ -116,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException {
116122
}
117123
out.writeOptionalInstant(lastUpdatedTime);
118124
}
125+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
126+
out.writeOptionalString(tenantId);
127+
}
119128
}
120129

121130
@Override
@@ -133,12 +142,14 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
133142
if (lastUpdateTime != null || lastUpdatedTime != null) {
134143
builder.field(LAST_UPDATE_TIME_FIELD, lastUpdatedTime == null ? lastUpdateTime.toEpochMilli() : lastUpdatedTime.toEpochMilli());
135144
}
145+
if (tenantId != null) {
146+
builder.field(TENANT_ID_FIELD, tenantId);
147+
}
136148
return builder.endObject();
137149
}
138150

139151
public static MLConfig fromStream(StreamInput in) throws IOException {
140-
MLConfig mlConfig = new MLConfig(in);
141-
return mlConfig;
152+
return new MLConfig(in);
142153
}
143154

144155
public static MLConfig parse(XContentParser parser) throws IOException {
@@ -149,6 +160,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
149160
Instant createTime = null;
150161
Instant lastUpdateTime = null;
151162
Instant lastUpdatedTime = null;
163+
String tenantId = null;
152164

153165
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
154166
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -177,6 +189,8 @@ public static MLConfig parse(XContentParser parser) throws IOException {
177189
case LAST_UPDATED_TIME_FIELD:
178190
lastUpdatedTime = Instant.ofEpochMilli(parser.longValue());
179191
break;
192+
case TENANT_ID_FIELD:
193+
tenantId = parser.textOrNull();
180194
default:
181195
parser.skipChildren();
182196
break;
@@ -191,6 +205,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
191205
.createTime(createTime)
192206
.lastUpdateTime(lastUpdateTime)
193207
.lastUpdatedTime(lastUpdatedTime)
208+
.tenantId(tenantId)
194209
.build();
195210
}
196211
}

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import java.util.List;
1717
import java.util.Map;
1818
import java.util.Optional;
19-
import java.util.function.Function;
19+
import java.util.function.BiFunction;
2020
import java.util.regex.Matcher;
2121
import java.util.regex.Pattern;
2222

@@ -79,9 +79,9 @@ public interface Connector extends ToXContentObject, Writeable {
7979

8080
<T> T createPayload(String action, Map<String, String> parameters);
8181

82-
void decrypt(String action, Function<String, String> function);
82+
void decrypt(String action, BiFunction<String, String, String> function, String tenantId);
8383

84-
void encrypt(Function<String, String> function);
84+
void encrypt(BiFunction<String, String, String> function, String tenantId);
8585

8686
Connector cloneConnector();
8787

@@ -91,7 +91,7 @@ public interface Connector extends ToXContentObject, Writeable {
9191

9292
void writeTo(StreamOutput out) throws IOException;
9393

94-
void update(MLCreateConnectorInput updateContent, Function<String, String> function);
94+
void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function);
9595

9696
<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;
9797

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.Optional;
24-
import java.util.function.Function;
24+
import java.util.function.BiFunction;
2525
import java.util.regex.Matcher;
2626
import java.util.regex.Pattern;
2727

@@ -302,7 +302,7 @@ public void writeTo(StreamOutput out) throws IOException {
302302
}
303303

304304
@Override
305-
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
305+
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function) {
306306
if (updateContent.getName() != null) {
307307
this.name = updateContent.getName();
308308
}
@@ -320,7 +320,7 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
320320
}
321321
if (updateContent.getCredential() != null && !updateContent.getCredential().isEmpty()) {
322322
this.credential = updateContent.getCredential();
323-
encrypt(function);
323+
encrypt(function, this.tenantId);
324324
}
325325
if (updateContent.getActions() != null) {
326326
this.actions = updateContent.getActions();
@@ -379,10 +379,10 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
379379
}
380380

381381
@Override
382-
public void decrypt(String action, Function<String, String> function) {
382+
public void decrypt(String action, BiFunction<String, String, String> function, String tenantId) {
383383
Map<String, String> decrypted = new HashMap<>();
384384
for (String key : credential.keySet()) {
385-
decrypted.put(key, function.apply(credential.get(key)));
385+
decrypted.put(key, function.apply(credential.get(key), tenantId));
386386
}
387387
this.decryptedCredential = decrypted;
388388
Optional<ConnectorAction> connectorAction = findAction(action);
@@ -402,9 +402,9 @@ public Connector cloneConnector() {
402402
}
403403

404404
@Override
405-
public void encrypt(Function<String, String> function) {
405+
public void encrypt(BiFunction<String, String, String> function, String tenantId) {
406406
for (String key : credential.keySet()) {
407-
String encrypted = function.apply(credential.get(key));
407+
String encrypted = function.apply(credential.get(key), tenantId);
408408
credential.put(key, encrypted);
409409
}
410410
}

common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java

+11-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
package org.opensearch.ml.common.transport.config;
77

88
import static org.opensearch.action.ValidateActions.addValidationError;
9+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
910

1011
import java.io.ByteArrayInputStream;
1112
import java.io.ByteArrayOutputStream;
1213
import java.io.IOException;
1314
import java.io.UncheckedIOException;
1415

16+
import org.opensearch.Version;
1517
import org.opensearch.action.ActionRequest;
1618
import org.opensearch.action.ActionRequestValidationException;
1719
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
@@ -26,21 +28,29 @@
2628
public class MLConfigGetRequest extends ActionRequest {
2729

2830
String configId;
31+
String tenantId;
2932

3033
@Builder
31-
public MLConfigGetRequest(String configId) {
34+
public MLConfigGetRequest(String configId, String tenantId) {
3235
this.configId = configId;
36+
this.tenantId = tenantId;
3337
}
3438

3539
public MLConfigGetRequest(StreamInput in) throws IOException {
3640
super(in);
41+
Version streamInputVersion = in.getVersion();
3742
this.configId = in.readString();
43+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
3844
}
3945

4046
@Override
4147
public void writeTo(StreamOutput out) throws IOException {
4248
super.writeTo(out);
49+
Version streamOutputVersion = out.getVersion();
4350
out.writeString(this.configId);
51+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
52+
out.writeOptionalString(tenantId);
53+
}
4454
}
4555

4656
@Override

common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java

+11-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import java.util.HashMap;
1919
import java.util.Locale;
2020
import java.util.Map;
21-
import java.util.function.Function;
21+
import java.util.function.BiFunction;
2222

2323
import org.junit.Assert;
2424
import org.junit.Before;
@@ -39,13 +39,13 @@ public class AwsConnectorTest {
3939
@Rule
4040
public ExpectedException exceptionRule = ExpectedException.none();
4141

42-
Function<String, String> encryptFunction;
43-
Function<String, String> decryptFunction;
42+
BiFunction<String, String, String> encryptFunction;
43+
BiFunction<String, String, String> decryptFunction;
4444

4545
@Before
4646
public void setUp() {
47-
encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT);
48-
decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT);
47+
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
48+
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
4949
}
5050

5151
@Test
@@ -115,8 +115,8 @@ public void constructor_NoPredictAction() {
115115
.build();
116116
Assert.assertNotNull(connector);
117117

118-
connector.encrypt(encryptFunction);
119-
connector.decrypt(PREDICT.name(), decryptFunction);
118+
connector.encrypt(encryptFunction, null);
119+
connector.decrypt(PREDICT.name(), decryptFunction, null);
120120
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
121121
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
122122
Assert.assertEquals(null, connector.getSessionToken());
@@ -159,8 +159,8 @@ public void constructor() {
159159
String url = "https://${parameters.endpoint}/model1";
160160

161161
AwsConnector connector = createAwsConnector(parameters, credential, url);
162-
connector.encrypt(encryptFunction);
163-
connector.decrypt(PREDICT.name(), decryptFunction);
162+
connector.encrypt(encryptFunction, null);
163+
connector.decrypt(PREDICT.name(), decryptFunction, null);
164164
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
165165
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
166166
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
@@ -180,8 +180,8 @@ public void constructor_NoParameter() {
180180

181181
String url = "https://test.com";
182182
AwsConnector connector = createAwsConnector(null, credential, url);
183-
connector.encrypt(encryptFunction);
184-
connector.decrypt(PREDICT.name(), decryptFunction);
183+
connector.encrypt(encryptFunction, null);
184+
connector.decrypt(PREDICT.name(), decryptFunction, null);
185185
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
186186
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
187187
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import java.util.List;
1616
import java.util.Locale;
1717
import java.util.Map;
18-
import java.util.function.Function;
18+
import java.util.function.BiFunction;
1919

2020
import org.junit.Assert;
2121
import org.junit.Before;
@@ -38,8 +38,8 @@ public class HttpConnectorTest {
3838
@Rule
3939
public ExpectedException exceptionRule = ExpectedException.none();
4040

41-
Function<String, String> encryptFunction;
42-
Function<String, String> decryptFunction;
41+
BiFunction<String, String, String> encryptFunction;
42+
BiFunction<String, String, String> decryptFunction;
4343

4444
String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\","
4545
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
@@ -55,8 +55,8 @@ public class HttpConnectorTest {
5555

5656
@Before
5757
public void setUp() {
58-
encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT);
59-
decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT);
58+
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
59+
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
6060
}
6161

6262
@Test
@@ -124,7 +124,7 @@ public void cloneConnector() {
124124
@Test
125125
public void decrypt() {
126126
HttpConnector connector = createHttpConnector();
127-
connector.decrypt(PREDICT.name(), decryptFunction);
127+
connector.decrypt(PREDICT.name(), decryptFunction, null);
128128
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
129129
Assert.assertEquals(1, decryptedCredential.size());
130130
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
@@ -141,7 +141,7 @@ public void decrypt() {
141141
@Test
142142
public void encrypted() {
143143
HttpConnector connector = createHttpConnector();
144-
connector.encrypt(encryptFunction);
144+
connector.encrypt(encryptFunction, null);
145145
Map<String, String> credential = connector.getCredential();
146146
Assert.assertEquals(1, credential.size());
147147
Assert.assertEquals("encrypted: test_key_value", credential.get("key"));

0 commit comments

Comments
 (0)