Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding multi-tenancy to config api and master key related changes #3439

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dependencies {
compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
// Multi-tenant SDK Client
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
}

lombok {
Expand Down
21 changes: 18 additions & 3 deletions common/src/main/java/org/opensearch/ml/common/MLConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -57,6 +59,7 @@ public class MLConfig implements ToXContentObject, Writeable {
private final Instant createTime;
private Instant lastUpdateTime;
private Instant lastUpdatedTime;
private final String tenantId;

@Builder(toBuilder = true)
public MLConfig(
Expand All @@ -66,7 +69,8 @@ public MLConfig(
Configuration mlConfiguration,
Instant createTime,
Instant lastUpdateTime,
Instant lastUpdatedTime
Instant lastUpdatedTime,
String tenantId
) {
this.type = type;
this.configType = configType;
Expand All @@ -75,6 +79,7 @@ public MLConfig(
this.createTime = createTime;
this.lastUpdateTime = lastUpdateTime;
this.lastUpdatedTime = lastUpdatedTime;
this.tenantId = tenantId;
}

public MLConfig(StreamInput input) throws IOException {
Expand All @@ -92,6 +97,7 @@ public MLConfig(StreamInput input) throws IOException {
}
lastUpdatedTime = input.readOptionalInstant();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
Expand All @@ -116,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeOptionalInstant(lastUpdatedTime);
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand All @@ -133,12 +142,14 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (lastUpdateTime != null || lastUpdatedTime != null) {
builder.field(LAST_UPDATE_TIME_FIELD, lastUpdatedTime == null ? lastUpdateTime.toEpochMilli() : lastUpdatedTime.toEpochMilli());
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
return builder.endObject();
}

public static MLConfig fromStream(StreamInput in) throws IOException {
MLConfig mlConfig = new MLConfig(in);
return mlConfig;
return new MLConfig(in);
}

public static MLConfig parse(XContentParser parser) throws IOException {
Expand All @@ -149,6 +160,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
Instant createTime = null;
Instant lastUpdateTime = null;
Instant lastUpdatedTime = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -177,6 +189,8 @@ public static MLConfig parse(XContentParser parser) throws IOException {
case LAST_UPDATED_TIME_FIELD:
lastUpdatedTime = Instant.ofEpochMilli(parser.longValue());
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
Comment on lines +192 to +193
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the tenant ID field added to the config? We use a hash of the key as part of the document ID and never actually compare vs. this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had is here in the multi-tenancy branch, I will deep dive later (after merging all the code) if we can remove it from here.

default:
parser.skipChildren();
break;
Expand All @@ -191,6 +205,7 @@ public static MLConfig parse(XContentParser parser) throws IOException {
.createTime(createTime)
.lastUpdateTime(lastUpdateTime)
.lastUpdatedTime(lastUpdatedTime)
.tenantId(tenantId)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -79,9 +79,9 @@ public interface Connector extends ToXContentObject, Writeable {

<T> T createPayload(String action, Map<String, String> parameters);

void decrypt(String action, Function<String, String> function);
void decrypt(String action, BiFunction<String, String, String> function, String tenantId);

void encrypt(Function<String, String> function);
void encrypt(BiFunction<String, String, String> function, String tenantId);

Connector cloneConnector();

Expand All @@ -91,7 +91,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, Function<String, String> function);
void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -302,7 +302,7 @@ public void writeTo(StreamOutput out) throws IOException {
}

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
Expand All @@ -320,7 +320,7 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}
if (updateContent.getCredential() != null && !updateContent.getCredential().isEmpty()) {
this.credential = updateContent.getCredential();
encrypt(function);
encrypt(function, this.tenantId);
}
if (updateContent.getActions() != null) {
this.actions = updateContent.getActions();
Expand Down Expand Up @@ -379,10 +379,10 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(String action, Function<String, String> function) {
public void decrypt(String action, BiFunction<String, String, String> function, String tenantId) {
Map<String, String> decrypted = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key)));
decrypted.put(key, function.apply(credential.get(key), tenantId));
}
this.decryptedCredential = decrypted;
Optional<ConnectorAction> connectorAction = findAction(action);
Expand All @@ -402,9 +402,9 @@ public Connector cloneConnector() {
}

@Override
public void encrypt(Function<String, String> function) {
public void encrypt(BiFunction<String, String, String> function, String tenantId) {
for (String key : credential.keySet()) {
String encrypted = function.apply(credential.get(key));
String encrypted = function.apply(credential.get(key), tenantId);
credential.put(key, encrypted);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
package org.opensearch.ml.common.transport.config;

import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
Expand All @@ -26,21 +28,29 @@
public class MLConfigGetRequest extends ActionRequest {

String configId;
String tenantId;

@Builder
public MLConfigGetRequest(String configId) {
public MLConfigGetRequest(String configId, String tenantId) {
this.configId = configId;
this.tenantId = tenantId;
}

public MLConfigGetRequest(StreamInput in) throws IOException {
super(in);
Version streamInputVersion = in.getVersion();
this.configId = in.readString();
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Version streamOutputVersion = out.getVersion();
out.writeString(this.configId);
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -371,4 +374,21 @@ public static void validateSchema(String schemaString, String instanceString) {
throw new OpenSearchParseException("Schema validation failed: " + e.getMessage(), e);
}
}

public static String hashString(String input) {
try {
// Create a MessageDigest instance for SHA-256
MessageDigest digest = MessageDigest.getInstance("SHA-256");

// Perform the hashing and get the byte array
byte[] hashBytes = digest.digest(input.getBytes());

// Convert the byte array to a Base64 encoded string
return Base64.getUrlEncoder().encodeToString(hashBytes);

} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Error: Unable to compute hash", e);
}
}

}
Loading
Loading