Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] adding multi-tenancy to config api and master key related changes #3444

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
adding multi-tenancy to config api and master key related changes (#3439
)

* adding multi-tenancy to config api and master key related changes

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* adding more unit tests

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
(cherry picked from commit 9846e6e)
dhrubo-os authored and github-actions[bot] committed Jan 28, 2025
commit cd006f1bc50696f56fe8cdc3798e32285dae232d
2 changes: 2 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
@@ -39,6 +39,8 @@ dependencies {
compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
// Multi-tenant SDK Client
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
}

lombok {
21 changes: 18 additions & 3 deletions common/src/main/java/org/opensearch/ml/common/MLConfig.java
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@
package org.opensearch.ml.common;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.IOException;
import java.time.Instant;
@@ -57,6 +59,7 @@ public class MLConfig implements ToXContentObject, Writeable {
private final Instant createTime;
private Instant lastUpdateTime;
private Instant lastUpdatedTime;
private final String tenantId;

@Builder(toBuilder = true)
public MLConfig(
@@ -66,7 +69,8 @@ public MLConfig(
Configuration mlConfiguration,
Instant createTime,
Instant lastUpdateTime,
Instant lastUpdatedTime
Instant lastUpdatedTime,
String tenantId
) {
this.type = type;
this.configType = configType;
@@ -75,6 +79,7 @@ public MLConfig(
this.createTime = createTime;
this.lastUpdateTime = lastUpdateTime;
this.lastUpdatedTime = lastUpdatedTime;
this.tenantId = tenantId;
}

public MLConfig(StreamInput input) throws IOException {
@@ -92,6 +97,7 @@ public MLConfig(StreamInput input) throws IOException {
}
lastUpdatedTime = input.readOptionalInstant();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
@@ -116,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeOptionalInstant(lastUpdatedTime);
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
@@ -133,12 +142,14 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (lastUpdateTime != null || lastUpdatedTime != null) {
builder.field(LAST_UPDATE_TIME_FIELD, lastUpdatedTime == null ? lastUpdateTime.toEpochMilli() : lastUpdatedTime.toEpochMilli());
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
return builder.endObject();
}

public static MLConfig fromStream(StreamInput in) throws IOException {
MLConfig mlConfig = new MLConfig(in);
return mlConfig;
return new MLConfig(in);
}

public static MLConfig parse(XContentParser parser) throws IOException {
@@ -149,6 +160,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
Instant createTime = null;
Instant lastUpdateTime = null;
Instant lastUpdatedTime = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -177,6 +189,8 @@ public static MLConfig parse(XContentParser parser) throws IOException {
case LAST_UPDATED_TIME_FIELD:
lastUpdatedTime = Instant.ofEpochMilli(parser.longValue());
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
default:
parser.skipChildren();
break;
@@ -191,6 +205,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
.createTime(createTime)
.lastUpdateTime(lastUpdateTime)
.lastUpdatedTime(lastUpdatedTime)
.tenantId(tenantId)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@@ -79,9 +79,9 @@ public interface Connector extends ToXContentObject, Writeable {

<T> T createPayload(String action, Map<String, String> parameters);

void decrypt(String action, Function<String, String> function);
void decrypt(String action, BiFunction<String, String, String> function, String tenantId);

void encrypt(Function<String, String> function);
void encrypt(BiFunction<String, String, String> function, String tenantId);

Connector cloneConnector();

@@ -91,7 +91,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, Function<String, String> function);
void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@@ -300,7 +300,7 @@ public void writeTo(StreamOutput out) throws IOException {
}

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
@@ -318,7 +318,7 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}
if (updateContent.getCredential() != null && !updateContent.getCredential().isEmpty()) {
this.credential = updateContent.getCredential();
encrypt(function);
encrypt(function, this.tenantId);
}
if (updateContent.getActions() != null) {
this.actions = updateContent.getActions();
@@ -377,10 +377,10 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(String action, Function<String, String> function) {
public void decrypt(String action, BiFunction<String, String, String> function, String tenantId) {
Map<String, String> decrypted = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key)));
decrypted.put(key, function.apply(credential.get(key), tenantId));
}
this.decryptedCredential = decrypted;
Optional<ConnectorAction> connectorAction = findAction(action);
@@ -400,9 +400,9 @@ public Connector cloneConnector() {
}

@Override
public void encrypt(Function<String, String> function) {
public void encrypt(BiFunction<String, String, String> function, String tenantId) {
for (String key : credential.keySet()) {
String encrypted = function.apply(credential.get(key));
String encrypted = function.apply(credential.get(key), tenantId);
credential.put(key, encrypted);
}
}
Original file line number Diff line number Diff line change
@@ -6,12 +6,14 @@
package org.opensearch.ml.common.transport.config;

import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
@@ -26,21 +28,29 @@
public class MLConfigGetRequest extends ActionRequest {

String configId;
String tenantId;

@Builder
public MLConfigGetRequest(String configId) {
public MLConfigGetRequest(String configId, String tenantId) {
this.configId = configId;
this.tenantId = tenantId;
}

public MLConfigGetRequest(StreamInput in) throws IOException {
super(in);
Version streamInputVersion = in.getVersion();
this.configId = in.readString();
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Version streamOutputVersion = out.getVersion();
out.writeString(this.configId);
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Original file line number Diff line number Diff line change
@@ -8,9 +8,12 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -477,4 +480,21 @@ public static void validateSchema(String schemaString, String instanceString) {
throw new OpenSearchParseException("Schema validation failed: " + e.getMessage(), e);
}
}

public static String hashString(String input) {
try {
// Create a MessageDigest instance for SHA-256
MessageDigest digest = MessageDigest.getInstance("SHA-256");

// Perform the hashing and get the byte array
byte[] hashBytes = digest.digest(input.getBytes());

// Convert the byte array to a Base64 encoded string
return Base64.getUrlEncoder().encodeToString(hashBytes);

} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Error: Unable to compute hash", e);
}
}

}
Loading