From cd006f1bc50696f56fe8cdc3798e32285dae232d Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Tue, 28 Jan 2025 12:25:01 -0800 Subject: [PATCH] adding multi-tenancy to config api and master key related changes (#3439) * adding multi-tenancy to config api and master key related changes Signed-off-by: Dhrubo Saha * adding more unit tests Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha (cherry picked from commit 9846e6e2533f814b0d90c08e784edda4f10f53ce) --- common/build.gradle | 2 + .../org/opensearch/ml/common/MLConfig.java | 21 +- .../ml/common/connector/Connector.java | 8 +- .../ml/common/connector/HttpConnector.java | 14 +- .../transport/config/MLConfigGetRequest.java | 12 +- .../ml/common/utils/StringUtils.java | 20 ++ .../opensearch/ml/common/MLConfigTest.java | 208 ++++++++++++++++++ .../ml/common/connector/AwsConnectorTest.java | 22 +- .../common/connector/HttpConnectorTest.java | 14 +- .../config/MLConfigGetRequestTest.java | 82 ++++++- .../config/MLConfigGetResponseTest.java | 4 +- .../org/opensearch/ml/engine/MLEngine.java | 11 +- .../engine/algorithms/remote/RemoteModel.java | 3 +- .../ml/engine/encryptor/Encryptor.java | 16 +- .../ml/engine/encryptor/EncryptorImpl.java | 93 +++++--- .../opensearch/ml/engine/MLEngineTest.java | 8 +- .../MetricsCorrelationTest.java | 2 +- .../QuestionAnsweringModelTest.java | 2 +- .../remote/AwsConnectorExecutorTest.java | 56 ++--- .../remote/RemoteConnectorExecutorTest.java | 4 +- .../algorithms/remote/RemoteModelTest.java | 6 +- .../TextEmbeddingSparseEncodingModelTest.java | 2 +- .../text_embedding/ModelHelperTest.java | 2 +- .../TextEmbeddingDenseModelTest.java | 2 +- .../TextSimilarityCrossEncoderModelTest.java | 2 +- .../tokenize/SparseTokenizerModelTest.java | 2 +- .../engine/encryptor/EncryptorImplTest.java | 202 ++++++++++++----- .../config/GetConfigTransportAction.java | 19 +- .../ExecuteConnectorTransportAction.java | 3 +- .../TransportCreateConnectorAction.java | 2 +- .../TransportDeployModelOnNodeAction.java | 8 +- .../forward/TransportForwardAction.java | 2 +- .../tasks/CancelBatchJobTransportAction.java | 21 +- .../action/tasks/GetTaskTransportAction.java | 3 +- .../opensearch/ml/cluster/MLSyncUpCron.java | 8 +- .../ml/model/MLModelCacheHelper.java | 3 +- .../opensearch/ml/model/MLModelManager.java | 6 +- .../ml/plugin/MachineLearningPlugin.java | 2 +- .../ml/rest/RestMLGetConfigAction.java | 11 +- .../ml/rest/RestMLPredictionAction.java | 7 +- .../ml/task/MLPredictTaskRunner.java | 27 ++- .../config/GetConfigTransportActionTests.java | 23 +- .../UpdateConnectorTransportActionTests.java | 2 +- .../TransportDeployModelActionTests.java | 2 +- .../TransportSyncUpOnNodeActionTests.java | 4 +- .../ml/cluster/MLSyncUpCronTests.java | 34 ++- .../ml/model/MLModelCacheHelperTests.java | 6 - .../ml/model/MLModelManagerTests.java | 2 +- .../ml/rest/RestMLGetConfigActionTests.java | 26 ++- .../ml/rest/RestMLPredictionActionTests.java | 32 +++ .../ml/task/MLExecuteTaskRunnerTests.java | 2 +- .../ml/task/MLPredictTaskRunnerTests.java | 6 +- .../MLTrainAndPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTrainingTaskRunnerTests.java | 2 +- 54 files changed, 814 insertions(+), 271 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/MLConfigTest.java diff --git a/common/build.gradle b/common/build.gradle index 81245a556e..6424947572 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -39,6 +39,8 @@ dependencies { compileOnly("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") compileOnly("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' + // Multi-tenant SDK Client + compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}" } lombok { diff --git a/common/src/main/java/org/opensearch/ml/common/MLConfig.java b/common/src/main/java/org/opensearch/ml/common/MLConfig.java index 20bc1853d7..23a415483d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/MLConfig.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.time.Instant; @@ -57,6 +59,7 @@ public class MLConfig implements ToXContentObject, Writeable { private final Instant createTime; private Instant lastUpdateTime; private Instant lastUpdatedTime; + private final String tenantId; @Builder(toBuilder = true) public MLConfig( @@ -66,7 +69,8 @@ public MLConfig( Configuration mlConfiguration, Instant createTime, Instant lastUpdateTime, - Instant lastUpdatedTime + Instant lastUpdatedTime, + String tenantId ) { this.type = type; this.configType = configType; @@ -75,6 +79,7 @@ public MLConfig( this.createTime = createTime; this.lastUpdateTime = lastUpdateTime; this.lastUpdatedTime = lastUpdatedTime; + this.tenantId = tenantId; } public MLConfig(StreamInput input) throws IOException { @@ -92,6 +97,7 @@ public MLConfig(StreamInput input) throws IOException { } lastUpdatedTime = input.readOptionalInstant(); } + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; } @Override @@ -116,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalInstant(lastUpdatedTime); } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override @@ -133,12 +142,14 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (lastUpdateTime != null || lastUpdatedTime != null) { builder.field(LAST_UPDATE_TIME_FIELD, lastUpdatedTime == null ? lastUpdateTime.toEpochMilli() : lastUpdatedTime.toEpochMilli()); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } return builder.endObject(); } public static MLConfig fromStream(StreamInput in) throws IOException { - MLConfig mlConfig = new MLConfig(in); - return mlConfig; + return new MLConfig(in); } public static MLConfig parse(XContentParser parser) throws IOException { @@ -149,6 +160,7 @@ public static MLConfig parse(XContentParser parser) throws IOException { Instant createTime = null; Instant lastUpdateTime = null; Instant lastUpdatedTime = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -177,6 +189,8 @@ public static MLConfig parse(XContentParser parser) throws IOException { case LAST_UPDATED_TIME_FIELD: lastUpdatedTime = Instant.ofEpochMilli(parser.longValue()); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); default: parser.skipChildren(); break; @@ -191,6 +205,7 @@ public static MLConfig parse(XContentParser parser) throws IOException { .createTime(createTime) .lastUpdateTime(lastUpdateTime) .lastUpdatedTime(lastUpdatedTime) + .tenantId(tenantId) .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 1bdb6747d1..d8306882a5 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -16,7 +16,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; +import java.util.function.BiFunction; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -79,9 +79,9 @@ public interface Connector extends ToXContentObject, Writeable { T createPayload(String action, Map parameters); - void decrypt(String action, Function function); + void decrypt(String action, BiFunction function, String tenantId); - void encrypt(Function function); + void encrypt(BiFunction function, String tenantId); Connector cloneConnector(); @@ -91,7 +91,7 @@ public interface Connector extends ToXContentObject, Writeable { void writeTo(StreamOutput out) throws IOException; - void update(MLCreateConnectorInput updateContent, Function function); + void update(MLCreateConnectorInput updateContent, BiFunction function); void parseResponse(T orElse, List modelTensors, boolean b) throws IOException; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index c33a401b04..2e2f56c7b7 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -21,7 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; +import java.util.function.BiFunction; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -300,7 +300,7 @@ public void writeTo(StreamOutput out) throws IOException { } @Override - public void update(MLCreateConnectorInput updateContent, Function function) { + public void update(MLCreateConnectorInput updateContent, BiFunction function) { if (updateContent.getName() != null) { this.name = updateContent.getName(); } @@ -318,7 +318,7 @@ public void update(MLCreateConnectorInput updateContent, Function findStringParametersWithNullDefaultValue(String input) { } @Override - public void decrypt(String action, Function function) { + public void decrypt(String action, BiFunction function, String tenantId) { Map decrypted = new HashMap<>(); for (String key : credential.keySet()) { - decrypted.put(key, function.apply(credential.get(key))); + decrypted.put(key, function.apply(credential.get(key), tenantId)); } this.decryptedCredential = decrypted; Optional connectorAction = findAction(action); @@ -400,9 +400,9 @@ public Connector cloneConnector() { } @Override - public void encrypt(Function function) { + public void encrypt(BiFunction function, String tenantId) { for (String key : credential.keySet()) { - String encrypted = function.apply(credential.get(key)); + String encrypted = function.apply(credential.get(key), tenantId); credential.put(key, encrypted); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java index bfc1c156db..dbdc8afe9c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.config; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -26,21 +28,29 @@ public class MLConfigGetRequest extends ActionRequest { String configId; + String tenantId; @Builder - public MLConfigGetRequest(String configId) { + public MLConfigGetRequest(String configId, String tenantId) { this.configId = configId; + this.tenantId = tenantId; } public MLConfigGetRequest(StreamInput in) throws IOException { super(in); + Version streamInputVersion = in.getVersion(); this.configId = in.readString(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + Version streamOutputVersion = out.getVersion(); out.writeString(this.configId); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index efa627f4e4..4fd9332519 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -8,9 +8,12 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Base64; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -477,4 +480,21 @@ public static void validateSchema(String schemaString, String instanceString) { throw new OpenSearchParseException("Schema validation failed: " + e.getMessage(), e); } } + + public static String hashString(String input) { + try { + // Create a MessageDigest instance for SHA-256 + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + + // Perform the hashing and get the byte array + byte[] hashBytes = digest.digest(input.getBytes()); + + // Convert the byte array to a Base64 encoded string + return Base64.getUrlEncoder().encodeToString(hashBytes); + + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Error: Unable to compute hash", e); + } + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/MLConfigTest.java b/common/src/test/java/org/opensearch/ml/common/MLConfigTest.java new file mode 100644 index 0000000000..67187cc70f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/MLConfigTest.java @@ -0,0 +1,208 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class MLConfigTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void toXContent_Minimal() throws IOException { + MLConfig config = MLConfig.builder().type("test_type").build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + Assert.assertEquals("{\"type\":\"test_type\"}", content); + } + + @Test + public void toXContent_Full() throws IOException { + Instant now = Instant.now(); + Configuration configuration = Configuration.builder().build(); + MLConfig config = MLConfig + .builder() + .type("test_type") + .configType("test_config_type") + .configuration(configuration) + .mlConfiguration(configuration) + .createTime(now) + .lastUpdateTime(now) + .lastUpdatedTime(now) + .tenantId("test_tenant") + .build(); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + Assert + .assertTrue( + content.contains("\"type\":\"test_config_type\"") + && content.contains("\"configuration\":") + && content.contains("\"create_time\":" + now.toEpochMilli()) + && content.contains("\"last_update_time\":" + now.toEpochMilli()) + && content.contains("\"tenant_id\":\"test_tenant\"") + ); + } + + @Test + public void parse_Minimal() throws IOException { + String jsonStr = "{\"type\":\"test_type\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLConfig config = MLConfig.parse(parser); + Assert.assertEquals("test_type", config.getType()); + Assert.assertNull(config.getConfigType()); + Assert.assertNull(config.getConfiguration()); + Assert.assertNull(config.getMlConfiguration()); + Assert.assertNull(config.getCreateTime()); + Assert.assertNull(config.getLastUpdateTime()); + Assert.assertNull(config.getLastUpdatedTime()); + Assert.assertNull(config.getTenantId()); + } + + @Test + public void parse_Full() throws IOException { + String jsonStr = "{\"type\":\"test_type\",\"config_type\":\"test_config_type\"," + + "\"configuration\":{},\"ml_configuration\":{},\"create_time\":1672531200000," + + "\"last_update_time\":1672534800000,\"last_updated_time\":1672538400000,\"tenant_id\":\"test_tenant\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLConfig config = MLConfig.parse(parser); + Assert.assertEquals("test_type", config.getType()); + Assert.assertEquals("test_config_type", config.getConfigType()); + Assert.assertNotNull(config.getConfiguration()); + Assert.assertNotNull(config.getMlConfiguration()); + Assert.assertEquals(Instant.ofEpochMilli(1672531200000L), config.getCreateTime()); + Assert.assertEquals(Instant.ofEpochMilli(1672534800000L), config.getLastUpdateTime()); + Assert.assertEquals(Instant.ofEpochMilli(1672538400000L), config.getLastUpdatedTime()); + Assert.assertEquals("test_tenant", config.getTenantId()); + } + + @Test + public void writeToAndReadFrom() throws IOException { + Instant now = Instant.now(); + Configuration configuration = Configuration.builder().build(); + MLConfig originalConfig = MLConfig + .builder() + .type("test_type") + .configType("test_config_type") + .configuration(configuration) + .mlConfiguration(configuration) + .createTime(now) + .lastUpdateTime(now) + .lastUpdatedTime(now) + .tenantId("test_tenant") + .build(); + + BytesStreamOutput output = new BytesStreamOutput(); + originalConfig.writeTo(output); + + MLConfig deserializedConfig = new MLConfig(output.bytes().streamInput()); + Assert.assertEquals("test_type", deserializedConfig.getType()); + Assert.assertEquals("test_config_type", deserializedConfig.getConfigType()); + Assert.assertNotNull(deserializedConfig.getConfiguration()); + Assert.assertNotNull(deserializedConfig.getMlConfiguration()); + Assert.assertEquals(now, deserializedConfig.getCreateTime()); + Assert.assertEquals(now, deserializedConfig.getLastUpdateTime()); + Assert.assertEquals(now, deserializedConfig.getLastUpdatedTime()); + Assert.assertEquals("test_tenant", deserializedConfig.getTenantId()); + } + + @Test + public void writeToAndReadFrom_Minimal() throws IOException { + MLConfig originalConfig = MLConfig.builder().type("test_type").build(); + + BytesStreamOutput output = new BytesStreamOutput(); + originalConfig.writeTo(output); + + MLConfig deserializedConfig = new MLConfig(output.bytes().streamInput()); + Assert.assertEquals("test_type", deserializedConfig.getType()); + Assert.assertNull(deserializedConfig.getConfigType()); + Assert.assertNull(deserializedConfig.getConfiguration()); + Assert.assertNull(deserializedConfig.getMlConfiguration()); + Assert.assertNull(deserializedConfig.getCreateTime()); + Assert.assertNull(deserializedConfig.getLastUpdateTime()); + Assert.assertNull(deserializedConfig.getLastUpdatedTime()); + Assert.assertNull(deserializedConfig.getTenantId()); + } + + @Test + public void crossVersionSerialization_NoTenantId() throws IOException { + // Simulate an older version (before VERSION_2_19_0) + Version oldVersion = Version.V_2_18_0; + + // Create an MLConfig instance with tenantId set + MLConfig originalConfig = MLConfig.builder().type("test_type").tenantId("test_tenant").build(); + + // Serialize using the older version + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(oldVersion); + originalConfig.writeTo(output); + + // Deserialize and verify tenantId is not present + StreamInput input = output.bytes().streamInput(); + input.setVersion(oldVersion); + MLConfig deserializedConfig = new MLConfig(input); + + Assert.assertEquals("test_type", deserializedConfig.getType()); + Assert.assertNull(deserializedConfig.getTenantId()); + } + + @Test + public void crossVersionSerialization_WithTenantId() throws IOException { + // Simulate a newer version (on or after VERSION_2_19_0) + Version newVersion = Version.V_2_19_0; + + // Create an MLConfig instance with tenantId set + MLConfig originalConfig = MLConfig.builder().type("test_type").tenantId("test_tenant").build(); + + // Serialize using the newer version + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(newVersion); + originalConfig.writeTo(output); + + // Deserialize and verify tenantId is present + StreamInput input = output.bytes().streamInput(); + input.setVersion(newVersion); + MLConfig deserializedConfig = new MLConfig(input); + + Assert.assertEquals("test_type", deserializedConfig.getType()); + Assert.assertEquals("test_tenant", deserializedConfig.getTenantId()); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index a60d3ac1cf..2b679b8bbe 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -18,7 +18,7 @@ import java.util.HashMap; import java.util.Locale; import java.util.Map; -import java.util.function.Function; +import java.util.function.BiFunction; import org.junit.Assert; import org.junit.Before; @@ -39,13 +39,13 @@ public class AwsConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - Function encryptFunction; - Function decryptFunction; + BiFunction encryptFunction; + BiFunction decryptFunction; @Before public void setUp() { - encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT); - decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT); + encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT); } @Test @@ -115,8 +115,8 @@ public void constructor_NoPredictAction() { .build(); Assert.assertNotNull(connector); - connector.encrypt(encryptFunction); - connector.decrypt(PREDICT.name(), decryptFunction); + connector.encrypt(encryptFunction, null); + connector.decrypt(PREDICT.name(), decryptFunction, null); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals(null, connector.getSessionToken()); @@ -159,8 +159,8 @@ public void constructor() { String url = "https://${parameters.endpoint}/model1"; AwsConnector connector = createAwsConnector(parameters, credential, url); - connector.encrypt(encryptFunction); - connector.decrypt(PREDICT.name(), decryptFunction); + connector.encrypt(encryptFunction, null); + connector.decrypt(PREDICT.name(), decryptFunction, null); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); @@ -180,8 +180,8 @@ public void constructor_NoParameter() { String url = "https://test.com"; AwsConnector connector = createAwsConnector(null, credential, url); - connector.encrypt(encryptFunction); - connector.decrypt(PREDICT.name(), decryptFunction); + connector.encrypt(encryptFunction, null); + connector.decrypt(PREDICT.name(), decryptFunction, null); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 7f83444c66..17756a6736 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -15,7 +15,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; +import java.util.function.BiFunction; import org.junit.Assert; import org.junit.Before; @@ -38,8 +38,8 @@ public class HttpConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - Function encryptFunction; - Function decryptFunction; + BiFunction encryptFunction; + BiFunction decryptFunction; String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," @@ -55,8 +55,8 @@ public class HttpConnectorTest { @Before public void setUp() { - encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT); - decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT); + encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT); } @Test @@ -124,7 +124,7 @@ public void cloneConnector() { @Test public void decrypt() { HttpConnector connector = createHttpConnector(); - connector.decrypt(PREDICT.name(), decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction, null); Map decryptedCredential = connector.getDecryptedCredential(); Assert.assertEquals(1, decryptedCredential.size()); Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); @@ -141,7 +141,7 @@ public void decrypt() { @Test public void encrypted() { HttpConnector connector = createHttpConnector(); - connector.encrypt(encryptFunction); + connector.encrypt(encryptFunction, null); Map credential = connector.getCredential(); Assert.assertEquals(1, credential.size()); Assert.assertEquals("encrypted: test_key_value", credential.get("key")); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java index ea16005d14..4524cb65a9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java @@ -5,24 +5,28 @@ package org.opensearch.ml.common.transport.config; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.opensearch.action.ValidateActions.addValidationError; import java.io.IOException; import java.io.UncheckedIOException; import org.junit.Test; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; public class MLConfigGetRequestTest { String configId; + String tenantId = null; @Test public void constructor_configId() { configId = "test-abc"; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); assertEquals(mlConfigGetRequest.getConfigId(), configId); } @@ -30,7 +34,7 @@ public void constructor_configId() { public void writeTo() throws IOException { configId = "test-hij"; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); BytesStreamOutput output = new BytesStreamOutput(); mlConfigGetRequest.writeTo(output); @@ -43,7 +47,7 @@ public void writeTo() throws IOException { @Test public void validate_Success() { configId = "not-null"; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); assertEquals(null, mlConfigGetRequest.validate()); } @@ -51,7 +55,7 @@ public void validate_Success() { @Test public void validate_Failure() { configId = null; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); assertEquals(null, mlConfigGetRequest.configId); ActionRequestValidationException exception = addValidationError("ML config id can't be null", null); @@ -61,14 +65,14 @@ public void validate_Failure() { @Test public void fromActionRequest_Success() throws IOException { configId = "test-lmn"; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); assertEquals(mlConfigGetRequest.fromActionRequest(mlConfigGetRequest), mlConfigGetRequest); } @Test public void fromActionRequest_Success_fromActionRequest() throws IOException { configId = "test-opq"; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); ActionRequest actionRequest = new ActionRequest() { @Override @@ -88,7 +92,7 @@ public void writeTo(StreamOutput out) throws IOException { @Test(expected = UncheckedIOException.class) public void fromActionRequest_IOException() { configId = "test-rst"; - MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -102,4 +106,68 @@ public void writeTo(StreamOutput out) throws IOException { }; mlConfigGetRequest.fromActionRequest(actionRequest); } + + @Test + public void writeTo_WithTenantId() throws IOException { + configId = "test-with-tenant"; + tenantId = "test_tenant"; + + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); + BytesStreamOutput output = new BytesStreamOutput(); + mlConfigGetRequest.writeTo(output); + + MLConfigGetRequest deserializedRequest = new MLConfigGetRequest(output.bytes().streamInput()); + + assertEquals(mlConfigGetRequest.getConfigId(), deserializedRequest.getConfigId()); + assertEquals(mlConfigGetRequest.getTenantId(), deserializedRequest.getTenantId()); + assertEquals(tenantId, deserializedRequest.getTenantId()); + } + + @Test + public void crossVersionSerialization_WithoutTenantIdForOldVersion() throws IOException { + configId = "test-no-tenant"; + tenantId = "test_tenant"; + + // Simulate an older version (before VERSION_2_19_0) + Version oldVersion = Version.V_2_18_0; + + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(oldVersion); // Set the version for the output + mlConfigGetRequest.writeTo(output); + + // Set the version for the input to match the older version + StreamInput input = output.bytes().streamInput(); + input.setVersion(oldVersion); // Important to match the output version + + MLConfigGetRequest deserializedRequest = new MLConfigGetRequest(input); + + // Validate that the configId is correctly deserialized and tenantId is null + assertEquals(configId, deserializedRequest.getConfigId()); + assertNull(deserializedRequest.getTenantId()); // tenantId should not be present for old versions + } + + @Test + public void fromActionRequest_WithTenantId() throws IOException { + configId = "test-with-tenant"; + tenantId = "test_tenant"; + + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId); + + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlConfigGetRequest.writeTo(out); + } + }; + MLConfigGetRequest deserializedRequest = mlConfigGetRequest.fromActionRequest(actionRequest); + + assertEquals(configId, deserializedRequest.getConfigId()); + assertEquals(tenantId, deserializedRequest.getTenantId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java index 5e24bcaa6e..359d9850f7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java @@ -60,7 +60,7 @@ public void MLConfigGetResponse_Builder() throws IOException { @Test public void writeTo() throws IOException { // create ml agent using mlConfig and mlConfigGetResponse - mlConfig = new MLConfig("olly_agent", null, new Configuration("agent_id"), null, Instant.EPOCH, Instant.EPOCH, Instant.EPOCH); + mlConfig = new MLConfig("olly_agent", null, new Configuration("agent_id"), null, Instant.EPOCH, Instant.EPOCH, Instant.EPOCH, null); MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); // use write out for both agents BytesStreamOutput output = new BytesStreamOutput(); @@ -76,7 +76,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - mlConfig = new MLConfig(null, null, null, null, null, null, null); + mlConfig = new MLConfig(null, null, null, null, null, null, null, null); MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); ToXContent.Params params = EMPTY_PARAMS; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index ecbc05c43e..419e573ba9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -125,7 +125,12 @@ public MLModel train(Input input) { } public Map getConnectorCredential(Connector connector) { - connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + connector + .decrypt( + PREDICT.name(), + (credential, tenantId) -> encryptor.decrypt(credential, connector.getTenantId()), + connector.getTenantId() + ); Map decryptedCredential = connector.getDecryptedCredential(); String region = connector.getParameters().get(REGION_FIELD); if (region != null) { @@ -211,8 +216,8 @@ private void validateInput(Input input) { } } - public String encrypt(String credential) { - return encryptor.encrypt(credential); + public String encrypt(String credential, String tenantId) { + return encryptor.encrypt(credential, tenantId); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index f43f0ca0c3..6336425785 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -101,7 +101,8 @@ public boolean isModelReady() { public void initModel(MLModel model, Map params, Encryptor encryptor) { try { Connector connector = model.getConnector().cloneConnector(); - connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + connector + .decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, model.getTenantId()), model.getTenantId()); this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java index 7a37ad355e..5b6d274756 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java @@ -11,26 +11,32 @@ public interface Encryptor { * Takes plaintext and returns encrypted text. * * @param plainText plainText. + * @param tenantId id of the tenant * @return String encryptedText. */ - String encrypt(String plainText); + String encrypt(String plainText, String tenantId); /** * Takes encryptedText and returns plain text. * * @param encryptedText encryptedText. + * @param tenantId id of the tenant * @return String plainText. */ - String decrypt(String encryptedText); + String decrypt(String encryptedText, String tenantId); /** * Set up the masterKey for dynamic updating - * + * @param tenantId ID of the tenant * @param masterKey masterKey to be set. */ - void setMasterKey(String masterKey); + void setMasterKey(String tenantId, String masterKey); - String getMasterKey(); + /** + * Get the masterKey + * @param tenantId ID of the tenant + */ + String getMasterKey(String tenantId); String generateMasterKey(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index ced2d70fc5..42864e7519 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -8,12 +8,17 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.MLConfig.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.hashString; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; import java.time.Instant; import java.util.Base64; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; @@ -48,37 +53,41 @@ public class EncryptorImpl implements Encryptor { "The ML encryption master key has not been initialized yet. Please retry after waiting for 10 seconds."; private ClusterService clusterService; private Client client; - private volatile String masterKey; + private final Map tenantMasterKeys; private MLIndicesHandler mlIndicesHandler; + // concurrent map can't have null as a key. This is to support single tenancy + // assigning some random string so that it can't be duplicate + public static final String DEFAULT_TENANT_ID = "03000200-0400-0500-0006-000700080009"; + public EncryptorImpl(ClusterService clusterService, Client client, MLIndicesHandler mlIndicesHandler) { - this.masterKey = null; + this.tenantMasterKeys = new ConcurrentHashMap<>(); this.clusterService = clusterService; this.client = client; + this.mlIndicesHandler = mlIndicesHandler; } - public EncryptorImpl(String masterKey) { - this.masterKey = masterKey; + public EncryptorImpl(String tenantId, String masterKey) { + this.tenantMasterKeys = new ConcurrentHashMap<>(); + this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), masterKey); } @Override - public void setMasterKey(String masterKey) { - this.masterKey = masterKey; + public void setMasterKey(String tenantId, String masterKey) { + this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), masterKey); } @Override - public String getMasterKey() { - return masterKey; + public String getMasterKey(String tenantId) { + return tenantMasterKeys.get(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID)); } @Override - public String encrypt(String plainText) { - initMasterKey(); + public String encrypt(String plainText, String tenantId) { + initMasterKey(tenantId); final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); - byte[] bytes = Base64.getDecoder().decode(masterKey); - // https://github.com/aws/aws-encryption-sdk-java/issues/1879 - JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); + JceMasterKey jceMasterKey = createJceMasterKey(tenantId); final CryptoResult encryptResult = crypto .encryptData(jceMasterKey, plainText.getBytes(StandardCharsets.UTF_8)); @@ -86,12 +95,10 @@ public String encrypt(String plainText) { } @Override - public String decrypt(String encryptedText) { - initMasterKey(); + public String decrypt(String encryptedText, String tenantId) { + initMasterKey(tenantId); final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); - - byte[] bytes = Base64.getDecoder().decode(masterKey); - JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); + JceMasterKey jceMasterKey = createJceMasterKey(tenantId); final CryptoResult decryptedResult = crypto .decryptData(jceMasterKey, Base64.getDecoder().decode(encryptedText)); @@ -102,47 +109,64 @@ public String decrypt(String encryptedText) { public String generateMasterKey() { byte[] keyBytes = new byte[32]; new SecureRandom().nextBytes(keyBytes); - String base64Key = Base64.getEncoder().encodeToString(keyBytes); - return base64Key; + return Base64.getEncoder().encodeToString(keyBytes); + } + + private JceMasterKey createJceMasterKey(String tenantId) { + byte[] bytes = Base64.getDecoder().decode(tenantMasterKeys.get(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))); + return JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); } - private void initMasterKey() { - if (masterKey != null) { + private void initMasterKey(String tenantId) { + if (tenantMasterKeys.containsKey(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))) { return; } AtomicReference exceptionRef = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { if (!r) { exceptionRef.set(new RuntimeException("No response to create ML Config index")); latch.countDown(); } else { - GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + String masterKeyId = MASTER_KEY; + if (tenantId != null) { + masterKeyId = MASTER_KEY + "_" + hashString(tenantId); + } + final String MASTER_KEY_ID = masterKeyId; + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY_ID); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getRequest, ActionListener.wrap(getResponse -> { if (getResponse == null || !getResponse.isExists()) { - IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY_ID); final String generatedMasterKey = generateMasterKey(); - indexRequest - .source(ImmutableMap.of(MASTER_KEY, generatedMasterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + + ImmutableMap.Builder mapBuilder = ImmutableMap.builder(); + mapBuilder.put(MASTER_KEY_ID, generatedMasterKey); + mapBuilder.put(CREATE_TIME_FIELD, Instant.now().toEpochMilli()); + if (tenantId != null) { + mapBuilder.put(TENANT_ID_FIELD, tenantId); + } + indexRequest.source(mapBuilder.build()); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); indexRequest.opType(DocWriteRequest.OpType.CREATE); client.index(indexRequest, ActionListener.wrap(indexResponse -> { - this.masterKey = generatedMasterKey; + this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), generatedMasterKey); log.info("ML encryption master key initialized successfully"); latch.countDown(); }, e -> { if (ExceptionUtils.getRootCause(e) instanceof VersionConflictEngineException) { - GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY_ID); try ( ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext() ) { client.get(getMasterKeyRequest, ActionListener.wrap(getMasterKeyResponse -> { if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) { - final String masterKey = (String) getMasterKeyResponse.getSourceAsMap().get(MASTER_KEY); - this.masterKey = masterKey; + this.tenantMasterKeys + .put( + Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), + (String) getMasterKeyResponse.getSourceAsMap().get(MASTER_KEY_ID) + ); log.info("ML encryption master key already initialized, no action needed"); latch.countDown(); } else { @@ -162,8 +186,8 @@ private void initMasterKey() { } })); } else { - final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); - this.masterKey = masterKey; + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY_ID); + this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), masterKey); log.info("ML encryption master key already initialized, no action needed"); latch.countDown(); } @@ -197,9 +221,8 @@ private void initMasterKey() { throw new MLException(exceptionRef.get()); } } - if (masterKey == null) { + if (tenantMasterKeys.get(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID)) == null) { throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR); } } - } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 95c3d5d218..30d4902f1b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -65,7 +65,7 @@ public class MLEngineTest extends MLStaticMockBase { @Before public void setUp() { - Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + Encryptor encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); } @@ -389,7 +389,7 @@ public void getModelCachePath_ReturnsCorrectPath() { @Test public void testMLEngineInitialization() { Path testPath = Path.of("/tmp/test" + UUID.randomUUID()); - mlEngine = new MLEngine(testPath, new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); + mlEngine = new MLEngine(testPath, new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); Path expectedMlCachePath = testPath.resolve("ml_cache"); Path expectedMlConfigPath = expectedMlCachePath.resolve("config"); @@ -411,14 +411,14 @@ public void testPredictWithInvalidInput() { @Test public void testEncryptMethod() { String testString = "testString"; - String encryptedString = mlEngine.encrypt(testString); + String encryptedString = mlEngine.encrypt(testString, null); assertNotNull(encryptedString); assertNotEquals(testString, encryptedString); } @Test public void testGetConnectorCredential() throws IOException { - String encryptedValue = mlEngine.encrypt("test_key_value"); + String encryptedValue = mlEngine.encrypt("test_key_value", null); String test_connector_string = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + "\"parameters\":{\"region\":\"test region\"},\"credential\":{\"key\":\"" diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 477ed75ebe..e5cc2a94fd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -182,7 +182,7 @@ public void setUp() throws IOException, URISyntaxException { System.setProperty("testMode", "true"); mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); modelConfig = MetricsCorrelationModelConfig.builder().modelType(MetricsCorrelation.MODEL_TYPE).allConfig(null).build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java index 9869b42def..74e0c0b389 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java @@ -69,7 +69,7 @@ public class QuestionAnsweringModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); model = MLModel .builder() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 98d5feb7ba..0227f2f12c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -105,7 +105,7 @@ public class AwsConnectorExecutorTest { @Before public void setUp() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); when(scriptService.compile(any(), any())) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}")); } @@ -140,7 +140,7 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -152,7 +152,7 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -184,7 +184,7 @@ public void executePredict_TextDocsInferenceInput() { .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -195,7 +195,7 @@ public void executePredict_TextDocsInferenceInput() { .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -223,7 +223,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "2"); Connector connector = AwsConnector @@ -236,7 +236,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -275,7 +275,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "1"); Connector connector = AwsConnector @@ -288,7 +288,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -335,7 +335,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "1"); Connector connector = AwsConnector @@ -348,7 +348,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -392,7 +392,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "1"); Connector connector = AwsConnector @@ -405,7 +405,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -451,7 +451,7 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -462,7 +462,7 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor0 = new AwsConnectorExecutor(connector); Field httpClientField = AwsConnectorExecutor.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); @@ -496,7 +496,7 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "-1"); Connector connector = AwsConnector @@ -508,7 +508,7 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -539,7 +539,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -549,7 +549,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio .parameters(parameters) .credential(credential) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -583,7 +583,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre ) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -594,7 +594,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -623,7 +623,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -634,7 +634,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -662,7 +662,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces .requestBody("{\"input\": ${parameters.input}}") .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); Connector connector = AwsConnector .awsConnectorBuilder() @@ -673,7 +673,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -702,7 +702,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); Map parameters = ImmutableMap .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "5"); // execute with retry disabled @@ -717,7 +717,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(connectorClientConfig) .build(); - connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -751,7 +751,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(connectorClientConfig2) .build(); - connector2.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + connector2.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); executor.initialize(connector2); executor .executeAction( diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java index a5cfacdb21..f464f67bb6 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java @@ -67,7 +67,7 @@ public class RemoteConnectorExecutorTest { @Before public void setUp() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); when(scriptService.compile(any(), any())) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}")); } @@ -81,7 +81,7 @@ private Connector getConnector(Map parameters) { .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key", null), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key", null)); return AwsConnector .awsConnectorBuilder() .name("test connector") diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 075019834c..8bce0d6394 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -64,7 +64,7 @@ public class RemoteModelTest extends MLStaticMockBase { public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); - encryptor = spy(new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); + encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); } @Test @@ -179,7 +179,7 @@ private void initModel_Failure_With_Throwable( exceptionRule.expectMessage(expExceptionMessage); Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); - doThrow(actualException).when(encryptor).decrypt(any()); + doThrow(actualException).when(encryptor).decrypt(any(), any()); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); } @@ -222,7 +222,7 @@ private Connector createConnector(Map headers) { .name("test connector") .protocol(ConnectorProtocols.HTTP) .version("1") - .credential(ImmutableMap.of("key", encryptor.encrypt("test_api_key"))) + .credential(ImmutableMap.of("key", encryptor.encrypt("test_api_key", null))) .actions(Arrays.asList(predictAction)) .build(); return connector; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java index f5d145f21a..2fc6d9f89a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java @@ -72,7 +72,7 @@ public class TextEmbeddingSparseEncodingModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index 3faafcc24f..f12618dd4f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -59,7 +59,7 @@ public void setup() throws URISyntaxException { MockitoAnnotations.openMocks(this); modelFormat = MLModelFormat.TORCH_SCRIPT; modelId = "model_id"; - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor); modelHelper = new ModelHelper(mlEngine); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 138faf65e1..cc141a4b51 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -80,7 +80,7 @@ public class TextEmbeddingDenseModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java index 88e64e2517..c09f2a42c0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java @@ -85,7 +85,7 @@ public class TextSimilarityCrossEncoderModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); model = MLModel .builder() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java index 919a4e4565..eac4af205c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java @@ -68,7 +68,7 @@ public class SparseTokenizerModelTest { @Before public void setUp() throws URISyntaxException { mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(mlCachePath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java index 3f1c8c3948..415b6c72e3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -8,11 +8,14 @@ import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.common.utils.StringUtils.hashString; +import static org.opensearch.ml.engine.encryptor.EncryptorImpl.DEFAULT_TENANT_ID; import static org.opensearch.ml.engine.encryptor.EncryptorImpl.MASTER_KEY_NOT_READY_ERROR; import java.io.IOException; import java.time.Instant; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.junit.Assert; import org.junit.Before; @@ -32,10 +35,14 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; @@ -57,17 +64,21 @@ public class EncryptorImplTest { @Mock private MLIndicesHandler mlIndicesHandler; - String masterKey; + Map masterKey; @Mock ThreadPool threadPool; ThreadContext threadContext; final String USER_STRING = "myuser|role1,role2|myTenant"; + final String TENANT_ID = "myTenant"; + + Encryptor encryptor; @Before public void setUp() { MockitoAnnotations.openMocks(this); - masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; + masterKey = new ConcurrentHashMap<>(); + masterKey.put(DEFAULT_TENANT_ID, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -108,29 +119,29 @@ public void setUp() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); } @Test - public void encrypt_ExistingMasterKey() { + public void encrypt_ExistingMasterKey() throws IOException { doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(0); actionListener.onResponse(true); return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); + + GetResponse response = prepareMLConfigResponse(null); doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); - when(response.getSourceAsMap()).thenReturn(Map.of(MASTER_KEY, masterKey)); actionListener.onResponse(response); return null; }).when(client).get(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - String encrypted = encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + String encrypted = encryptor.encrypt("test", null); Assert.assertNotNull(encrypted); - Assert.assertEquals(masterKey, encryptor.getMasterKey()); + Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); } @Test @@ -155,10 +166,10 @@ public void encrypt_NonExistingMasterKey() { }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - String encrypted = encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); + String encrypted = encryptor.encrypt("test", TENANT_ID); Assert.assertNotNull(encrypted); - Assert.assertNotEquals(masterKey, encryptor.getMasterKey()); + Assert.assertNotEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(TENANT_ID)); } @Test @@ -184,8 +195,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey() { }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); + encryptor.encrypt("test", TENANT_ID); } @Test @@ -211,8 +222,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_NonRuntimeExceptio }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); + encryptor.encrypt("test", TENANT_ID); } @Test @@ -245,8 +256,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict() }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); + encryptor.encrypt("test", TENANT_ID); } @Test @@ -278,8 +289,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_Nu }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); + encryptor.encrypt("test", TENANT_ID); } @Test @@ -311,28 +322,28 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_Nu }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + String encrypted = encryptor.encrypt("test", null); + Assert.assertNotNull(encrypted); + Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); } @Test - public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_GetExistingMasterKey() { + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_GetExistingMasterKey() throws IOException { doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(0); actionListener.onResponse(true); return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); + + GetResponse response = prepareMLConfigResponse(null); + doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); actionListener.onResponse(response); return null; }).doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); - when(response.getSourceAsMap()).thenReturn(Map.of(MASTER_KEY, masterKey)); actionListener.onResponse(response); return null; }).when(client).get(any(), any()); @@ -344,10 +355,10 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_Ge }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - String encrypted = encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + String encrypted = encryptor.encrypt("test", null); Assert.assertNotNull(encrypted); - Assert.assertEquals(masterKey, encryptor.getMasterKey()); + Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); } @Test @@ -378,10 +389,10 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_Fa }).when(client).index(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - String encrypted = encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + String encrypted = encryptor.encrypt("test", null); Assert.assertNotNull(encrypted); - Assert.assertEquals(masterKey, encryptor.getMasterKey()); + Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); } @Test @@ -390,7 +401,7 @@ public void encrypt_ThrowExceptionWhenInitMLConfigIndex() { exceptionRule.expectMessage("test exception"); doThrow(new RuntimeException("test exception")).when(mlIndicesHandler).initMLConfigIndex(any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - encryptor.encrypt(masterKey); + encryptor.encrypt(masterKey.get(DEFAULT_TENANT_ID), null); } @Test @@ -403,7 +414,7 @@ public void encrypt_FailedToInitMLConfigIndex() { return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - encryptor.encrypt(masterKey); + encryptor.encrypt(masterKey.get(DEFAULT_TENANT_ID), null); } @Test @@ -421,43 +432,43 @@ public void encrypt_FailedToGetMasterKey() { return null; }).when(client).get(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - encryptor.encrypt(masterKey); + encryptor.encrypt(masterKey.get(DEFAULT_TENANT_ID), null); } @Test public void encrypt_DifferentMasterKey() { - Encryptor encryptor = new EncryptorImpl(masterKey); - Assert.assertNotNull(encryptor.getMasterKey()); - String encrypted1 = encryptor.encrypt("test"); + Encryptor encryptor = new EncryptorImpl(null, masterKey.get(DEFAULT_TENANT_ID)); + String test = encryptor.getMasterKey(null); + Assert.assertNotNull(test); + String encrypted1 = encryptor.encrypt("test", null); - encryptor.setMasterKey(encryptor.generateMasterKey()); - String encrypted2 = encryptor.encrypt("test"); + encryptor.setMasterKey(null, encryptor.generateMasterKey()); + String encrypted2 = encryptor.encrypt("test", null); Assert.assertNotEquals(encrypted1, encrypted2); } @Test - public void decrypt() { + public void decrypt() throws IOException { doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(0); actionListener.onResponse(true); return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); + + GetResponse response = prepareMLConfigResponse(null); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); - when(response.getSourceAsMap()) - .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); listener.onResponse(response); return null; }).when(client).get(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - String encrypted = encryptor.encrypt("test"); - String decrypted = encryptor.decrypt(encrypted); + Assert.assertNull(encryptor.getMasterKey(null)); + String encrypted = encryptor.encrypt("test", null); + String decrypted = encryptor.decrypt(encrypted, null); Assert.assertEquals("test", decrypted); - Assert.assertEquals(masterKey, encryptor.getMasterKey()); + Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); } @Test @@ -474,8 +485,8 @@ public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() { }).when(client).get(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.encrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + encryptor.encrypt("test", null); } @Test @@ -495,8 +506,8 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() { }).when(client).get(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.decrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + encryptor.decrypt("test", null); } @Test @@ -511,8 +522,8 @@ public void decrypt_NoResponseToInitConfigIndex() { }).when(mlIndicesHandler).initMLConfigIndex(any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.decrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + encryptor.decrypt("test", null); } @Test @@ -530,7 +541,80 @@ public void decrypt_MLConfigIndexNotFound() { }).when(client).get(any(), any()); Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey()); - encryptor.decrypt("test"); + Assert.assertNull(encryptor.getMasterKey(null)); + encryptor.decrypt("test", null); + } + + @Test + public void initMasterKey_AddTenantMasterKeys() throws IOException { + // Mock ML Config Index initialization to succeed + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); // Simulate successful ML Config index initialization + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + + // Mock GetResponse to return a valid MASTER_KEY_ID for the given tenant + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = prepareMLConfigResponse(TENANT_ID); // Response includes dynamic MASTER_KEY_ID + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + // Initialize Encryptor and verify no master key exists initially + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); + + // Encrypt using the specified tenant ID + String encrypted = encryptor.encrypt("test", TENANT_ID); + Assert.assertNotNull(encrypted); + + // Verify that the tenant-specific master key is added + String tenantMasterKey = encryptor.getMasterKey(TENANT_ID); + Assert.assertNotNull(tenantMasterKey); + + // Ensure that the master key for this tenant matches the expected value + String expectedMasterKeyId = MASTER_KEY + "_" + hashString(TENANT_ID); + Assert.assertEquals("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=", encryptor.getMasterKey(TENANT_ID)); + } + + // Helper method to prepare a valid GetResponse + private GetResponse prepareMLConfigResponse(String tenantId) throws IOException { + // Compute the masterKeyId based on tenantId + String masterKeyId = MASTER_KEY; + if (tenantId != null) { + masterKeyId = MASTER_KEY + "_" + hashString(tenantId); + } + + // Create the source map with the expected fields + Map sourceMap = Map + .of( + masterKeyId, + "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=", // Valid MASTER_KEY for this tenant + CREATE_TIME_FIELD, + Instant.now().toEpochMilli() + ); + + // Serialize the source map to JSON + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + for (Map.Entry entry : sourceMap.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + BytesReference sourceBytes = BytesReference.bytes(builder); + + // Create the GetResult + GetResult getResult = new GetResult(ML_CONFIG_INDEX, masterKeyId, 1L, 1L, 1L, true, sourceBytes, null, null); + + // Create and return the GetResponse + return new GetResponse(getResult); + } + + // Helper method to prepare a valid IndexResponse + private IndexResponse prepareIndexResponse() { + ShardId shardId = new ShardId(ML_CONFIG_INDEX, "index_uuid", 0); + return new IndexResponse(shardId, MASTER_KEY, 1L, 1L, 1L, true); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java index c187a0bc14..01903050b2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java @@ -27,6 +27,8 @@ import org.opensearch.ml.common.transport.config.MLConfigGetAction; import org.opensearch.ml.common.transport.config.MLConfigGetRequest; import org.opensearch.ml.common.transport.config.MLConfigGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -40,23 +42,32 @@ public class GetConfigTransportAction extends HandledTransportAction actionListener) { MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.fromActionRequest(request); String configId = mlConfigGetRequest.getConfigId(); + String tenantId = mlConfigGetRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + + // In the get request tenantId will be used as a part of SDKClient migration GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(configId); if (configId.equals(MASTER_KEY)) { @@ -74,7 +85,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener = ActionListener.wrap(connector -> { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { - connector.decrypt(connectorAction, (credential) -> encryptor.decrypt(credential)); + // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. + connector.decrypt(connectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setScriptService(scriptService); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 5e766b093d..87c090d0a6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -134,7 +134,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - connector.encrypt(mlEngine::encrypt); + connector.encrypt(mlEngine::encrypt, connector.getTenantId()); log.info("connector created, indexing into the connector system index"); mlIndicesHandler.initMLConnectorIndex(ActionListener.wrap(indexCreated -> { if (!indexCreated) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index 495ea771f2..ddba443379 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -127,6 +127,7 @@ protected MLDeployModelNodeResponse nodeOperation(MLDeployModelNodeRequest reque private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNodesRequest MLDeployModelNodesRequest) { MLDeployModelInput deployModelInput = MLDeployModelNodesRequest.getMlDeployModelInput(); + final String tenantId = MLDeployModelNodesRequest.getMlDeployModelInput().getTenantId(); String modelId = deployModelInput.getModelId(); String taskId = deployModelInput.getTaskId(); String coordinatingNodeId = deployModelInput.getCoordinatingNodeId(); @@ -140,12 +141,13 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod String localNodeId = clusterService.localNode().getId(); ActionListener taskDoneListener = ActionListener - .wrap(res -> { log.info("deploy model task done " + taskId); }, ex -> { + .wrap(res -> { log.info("deploy model task done {}", taskId); }, ex -> { logException("Deploy model task failed: " + taskId, ex, log); }); deployModel( modelId, + tenantId, modelContentHash, mlTask.getFunctionName(), localNodeId, @@ -159,6 +161,7 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod .taskId(taskId) .modelId(modelId) .workerNodeId(clusterService.localNode().getId()) + .tenantId(tenantId) .build(); MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput); @@ -179,6 +182,7 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod .modelId(modelId) .workerNodeId(clusterService.localNode().getId()) .error(MLExceptionUtils.getRootCauseMessage(e)) + .tenantId(tenantId) .build(); MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput); @@ -211,6 +215,7 @@ private DiscoveryNode getNodeById(String nodeId) { private void deployModel( String modelId, + String tenantId, String modelContentHash, FunctionName functionName, String localNodeId, @@ -224,6 +229,7 @@ private void deployModel( mlModelManager .deployModel( modelId, + tenantId, modelContentHash, functionName, deployToAllNodes, diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 88a3956753..3679a322b8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -209,7 +209,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to update ML model: {}", modelId, e); }); - mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> { + mlModelManager.updateModel(modelId, tenantId, updateFields, ActionListener.runBefore(updateModelListener, () -> { mlModelManager.removeAutoDeployModel(modelId); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index c4e265cd04..5402d47456 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -139,7 +139,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener context.restore())); + }), context::restore)); } catch (Exception e) { - log.error("Failed to get ML task " + taskId, e); + log.error("Failed to get ML task {}", taskId, e); actionListener.onFailure(e); } } @@ -167,7 +167,7 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener listener = ActionListener .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { - log.error("Failed to get connector " + model.getConnectorId(), e); + log.error("Failed to get connector {}", model.getConnectorId(), e); actionListener.onFailure(e); }); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { @@ -233,11 +233,12 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener actionListener) { Optional cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name()); - if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) { + if (cancelBatchPredictAction.isEmpty() || cancelBatchPredictAction.get().getRequestBody() == null) { ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT); connector.addAction(connectorAction); } - connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential)); + // multi-tenancy isn't implemented in batch, so setting null as tenant by default + connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, null), null); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setScriptService(scriptService); connectorExecutor.setClusterService(clusterService); @@ -245,7 +246,7 @@ private void executeConnector(Connector connector, MLInput mlInput, ActionListen connectorExecutor.setXContentRegistry(xContentRegistry); connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> { processTaskResponse(taskResponse, actionListener); - }, e -> { actionListener.onFailure(e); })); + }, actionListener::onFailure)); } private void processTaskResponse(MLTaskResponse taskResponse, ActionListener actionListener) { @@ -256,7 +257,7 @@ private void processTaskResponse(MLTaskResponse taskResponse, ActionListener encryptor.decrypt(credential)); + // as we haven't implemented multi-tenancy in batch prediction yet, assigning null as tenantId + connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential, tenantId) -> encryptor.decrypt(credential, null), null); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setScriptService(scriptService); connectorExecutor.setClusterService(clusterService); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 533ae18899..dc110a53d0 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -254,12 +254,16 @@ void initMLConfig() { indexRequest.opType(DocWriteRequest.OpType.CREATE); client.index(indexRequest, ActionListener.wrap(indexResponse -> { log.info("ML configuration initialized successfully"); - encryptor.setMasterKey(masterKey); + // as this method is not being used for multi-tenancy use case, we are setting + // tenant id null by default + encryptor.setMasterKey(null, masterKey); mlConfigInited = true; }, e -> { log.debug("Failed to save ML encryption master key", e); })); } else { final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); - encryptor.setMasterKey(masterKey); + // as this method is not being used for multi-tenancy use case, we are setting + // tenant id null by default + encryptor.setMasterKey(null, masterKey); mlConfigInited = true; log.info("ML configuration already initialized, no action needed"); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 0e4eb23e11..38c8a35e5d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -735,7 +735,8 @@ public MLModel getModelInfo(String modelId) { private MLModelCache getExistingModelCache(String modelId) { MLModelCache modelCache = modelCaches.get(modelId); if (modelCache == null) { - throw new IllegalArgumentException("Model not found in cache"); + return getOrCreateModelCache(modelId); + // throw new IllegalArgumentException("Model not found in cache"); } return modelCache; } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index bd03408700..4319890b34 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -598,7 +598,7 @@ private void indexRemoteModel( String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; Instant now = Instant.now(); if (registerModelInput.getConnector() != null) { - registerModelInput.getConnector().encrypt(mlEngine::encrypt); + registerModelInput.getConnector().encrypt(mlEngine::encrypt, registerModelInput.getTenantId()); } mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(boolResponse -> { @@ -692,7 +692,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; Instant now = Instant.now(); if (registerModelInput.getConnector() != null) { - registerModelInput.getConnector().encrypt(mlEngine::encrypt); + registerModelInput.getConnector().encrypt(mlEngine::encrypt, registerModelInput.getTenantId()); } mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> { if (!res) { @@ -2097,7 +2097,7 @@ private void getConnector(String connectorId, String tenantId, ActionListener getRestHandlers( RestMLSearchAgentAction restMLSearchAgentAction = new RestMLSearchAgentAction(mlFeatureEnabledSetting); RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories); RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); - RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction(); + RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction(mlFeatureEnabledSetting); RestMLBatchIngestAction restMLBatchIngestAction = new RestMLBatchIngestAction(); RestMLCancelBatchJobAction restMLCancelBatchJobAction = new RestMLCancelBatchJobAction(); return ImmutableList diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java index 81cb02c597..db7f0c9e91 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONFIG_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -17,6 +18,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.config.MLConfigGetAction; import org.opensearch.ml.common.transport.config.MLConfigGetRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -26,11 +28,14 @@ public class RestMLGetConfigAction extends BaseRestHandler { private static final String ML_GET_CONFIG_ACTION = "ml_get_config_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLGetConfigAction() {} + public RestMLGetConfigAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -58,11 +63,11 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLConfigGetRequest getRequest(RestRequest request) throws IOException { String configID = getParameterId(request, PARAMETER_CONFIG_ID); - + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); if (configID.equals(MASTER_KEY)) { throw new IllegalArgumentException("You are not allowed to access this config doc"); } - return new MLConfigGetRequest(configID); + return new MLConfigGetRequest(configID, tenantId); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index e92713be17..7293436b0a 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -115,7 +115,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } }); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - modelManager.getModel(modelId, ActionListener.runBefore(listener, context::restore)); + modelManager + .getModel( + modelId, + getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request), + ActionListener.runBefore(listener, context::restore) + ); } }; } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 784e1f4642..cf359a1c53 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -161,7 +161,7 @@ public void dispatchTask( if (workerNodes == null || workerNodes.length == 0) { if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> { + mlModelManager.getModel(modelId, request.getTenantId(), ActionListener.runBefore(ActionListener.wrap(model -> { Boolean isHidden = model.getIsHidden(); if (!checkModelAutoDeployEnabled(model)) { final String errorMsg = getErrorMessage( @@ -230,6 +230,7 @@ public void dispatchTask( */ @Override protected void executeTask(MLPredictionTaskRequest request, ActionListener listener) { + final String tenantId = request.getTenantId(); MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); String modelId = request.getModelId(); @@ -253,6 +254,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener { @@ -262,7 +264,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener { log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception); @@ -270,7 +272,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener listener ) { switch (inputDataType) { case SEARCH_QUERY: ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { MLInput newInput = mlInput.toBuilder().inputDataset(dataSet).build(); - predict(modelId, mlTask, newInput, listener); + predict(modelId, tenantId, mlTask, newInput, listener); }, e -> { log.error("Failed to generate DataFrame from search query", e); handleAsyncMLTaskFailure(mlTask, e); @@ -298,7 +301,7 @@ private void executePredictionByInputDataType( case TEXT_DOCS: default: String threadPoolName = getPredictThreadPool(functionName); - threadPool.executor(threadPoolName).execute(() -> { predict(modelId, mlTask, mlInput, listener); }); + threadPool.executor(threadPoolName).execute(() -> { predict(modelId, tenantId, mlTask, mlInput, listener); }); break; } } @@ -314,7 +317,7 @@ private String getPredictThreadPool(FunctionName functionName) { return functionName == FunctionName.REMOTE ? REMOTE_PREDICT_THREAD_POOL : PREDICT_THREAD_POOL; } - private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener listener) { + private void predict(String modelId, String tenantId, MLTask mlTask, MLInput mlInput, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache ActionName actionName = getActionNameFromInput(mlInput); @@ -343,21 +346,23 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe .lastUpdateTime(now) .state(MLTaskState.RUNNING) .workerNodes(Arrays.asList(clusterService.localNode().getId())) + .tenantId(tenantId) .build(); - mlModelManager.deployModel(modelId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { - runPredict(modelId, mlTask, mlInput, functionName, actionName, internalListener); + mlModelManager.deployModel(modelId, tenantId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { + runPredict(modelId, tenantId, mlTask, mlInput, functionName, actionName, internalListener); }, e -> { - log.error("Failed to auto deploy model " + modelId, e); + log.error("Failed to auto deploy model {}", modelId, e); internalListener.onFailure(e); })); return; } - runPredict(modelId, mlTask, mlInput, functionName, actionName, internalListener); + runPredict(modelId, tenantId, mlTask, mlInput, functionName, actionName, internalListener); } private void runPredict( String modelId, + String tenantId, MLTask mlTask, MLInput mlInput, FunctionName algorithm, @@ -482,7 +487,7 @@ private void runPredict( } // run predict if (mlTaskManager.contains(mlTask.getTaskId())) { - mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.getTenantId(), mlTask.isAsync()); + mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), tenantId, mlTask.isAsync()); } MLOutput output = mlEngine.predict(mlInput, mlModel); if (output instanceof MLPredictionOutput) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java index afa4153a74..5175783300 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java @@ -43,6 +43,7 @@ import org.opensearch.ml.common.MLConfig; import org.opensearch.ml.common.transport.config.MLConfigGetRequest; import org.opensearch.ml.common.transport.config.MLConfigGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -67,6 +68,9 @@ public class GetConfigTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -79,7 +83,9 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlConfigGetRequest = MLConfigGetRequest.builder().configId("test_id").build(); - getConfigTransportAction = spy(new GetConfigTransportAction(transportService, actionFilters, client, xContentRegistry)); + getConfigTransportAction = spy( + new GetConfigTransportAction(transportService, actionFilters, client, xContentRegistry, mlFeatureEnabledSetting) + ); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -128,7 +134,7 @@ public void testDoExecute_Failure_Context_Exception() { String configId = "test-config-id"; ActionListener actionListener = mock(ActionListener.class); - MLConfigGetRequest getRequest = new MLConfigGetRequest(configId); + MLConfigGetRequest getRequest = new MLConfigGetRequest(configId, null); Task task = mock(Task.class); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenThrow(new RuntimeException()); @@ -149,7 +155,7 @@ public void testDoExecute_Success() throws IOException { String configID = "config_id"; GetResponse getResponse = prepareMLConfig(configID); ActionListener actionListener = mock(ActionListener.class); - MLConfigGetRequest request = new MLConfigGetRequest(configID); + MLConfigGetRequest request = new MLConfigGetRequest(configID, null); Task task = mock(Task.class); doAnswer(invocation -> { @@ -165,14 +171,14 @@ public void testDoExecute_Success() throws IOException { @Test public void testDoExecute_Success_ForNewFields() throws IOException { String configID = "config_id"; - MLConfig mlConfig = new MLConfig(null, "olly_agent", null, new Configuration("agent_id"), Instant.EPOCH, null, Instant.EPOCH); + MLConfig mlConfig = new MLConfig(null, "olly_agent", null, new Configuration("agent_id"), Instant.EPOCH, null, Instant.EPOCH, null); XContentBuilder content = mlConfig.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", configID, 111l, 111l, 111l, true, bytesReference, null, null); GetResponse getResponse = new GetResponse(getResult); ActionListener actionListener = mock(ActionListener.class); - MLConfigGetRequest request = new MLConfigGetRequest(configID); + MLConfigGetRequest request = new MLConfigGetRequest(configID, null); Task task = mock(Task.class); doAnswer(invocation -> { @@ -187,13 +193,12 @@ public void testDoExecute_Success_ForNewFields() throws IOException { public GetResponse prepareMLConfig(String configID) throws IOException { - MLConfig mlConfig = new MLConfig("olly_agent", null, new Configuration("agent_id"), null, Instant.EPOCH, Instant.EPOCH, null); + MLConfig mlConfig = new MLConfig("olly_agent", null, new Configuration("agent_id"), null, Instant.EPOCH, Instant.EPOCH, null, null); XContentBuilder content = mlConfig.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", configID, 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); - return getResponse; + return new GetResponse(getResult); } @Test @@ -201,7 +206,7 @@ public void testDoExecute_Rejected_MASTER_KEY() throws IOException { String configID = MASTER_KEY; GetResponse getResponse = prepareMLConfig(configID); ActionListener actionListener = mock(ActionListener.class); - MLConfigGetRequest request = new MLConfigGetRequest(configID); + MLConfigGetRequest request = new MLConfigGetRequest(configID, null); Task task = mock(Task.class); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index 2136b9ec56..767d71414f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -171,7 +171,7 @@ public void setup() throws IOException { SearchResponse.Clusters.EMPTY ); - Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + Encryptor encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); updateConnectorTransportAction = new UpdateConnectorTransportAction( diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index cc4071714d..8974999b05 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -154,7 +154,7 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); modelHelper = new ModelHelper(mlEngine); when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java index a12cce0b89..ff13f36f7f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java @@ -298,7 +298,7 @@ public void testCleanUpLocalCache_ExpiredMLTask_Register() { when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); action.cleanUpLocalCache(runningDeployModelTasks); verify(mlTaskManager, times(1)).updateMLTask(anyString(), any(), any(), anyLong(), anyBoolean()); - verify(mlModelManager, never()).updateModel(anyString(), (Boolean) any(), any()); + verify(mlModelManager, never()).updateModel(anyString(), any(), (Boolean) any(), any()); } @Test @@ -340,7 +340,7 @@ private void testCleanUpLocalCache_ExpiredMLTask_DeployStatus(MLModelState model action.cleanUpLocalCache(runningDeployModelTasks); verify(mlTaskManager, times(1)).updateMLTask(anyString(), any(), any(), anyLong(), anyBoolean()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); - verify(mlModelManager, never()).updateModel(eq(modelId), eq(false), argumentCaptor.capture()); + verify(mlModelManager, never()).updateModel(eq(modelId), eq(null), eq(false), argumentCaptor.capture()); } private MLSyncUpInput prepareRequest() { diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 696faa7432..3515896fc8 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -64,6 +64,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; @@ -124,7 +125,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - encryptor = spy(new EncryptorImpl(null)); + encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); @@ -161,9 +162,9 @@ public void testInitMlConfig_MasterKeyNotExist() { }).when(client).index(any(), any()); syncUpCron.initMLConfig(); - Assert.assertNotNull(encryptor.encrypt("test")); + Assert.assertNotNull(encryptor.encrypt("test", null)); syncUpCron.initMLConfig(); - verify(encryptor, times(1)).setMasterKey(any()); + verify(encryptor, times(1)).setMasterKey(any(), any()); } public void testInitMlConfig_MasterKeyExists() { @@ -179,9 +180,9 @@ public void testInitMlConfig_MasterKeyExists() { }).when(client).get(any(), any()); syncUpCron.initMLConfig(); - Assert.assertNotNull(encryptor.encrypt("test")); + Assert.assertNotNull(encryptor.encrypt("test", null)); syncUpCron.initMLConfig(); - verify(encryptor, times(1)).setMasterKey(any()); + verify(encryptor, times(1)).setMasterKey(any(), any()); } public void testRun_NoMLModelIndex() { @@ -304,7 +305,8 @@ public void testRefreshModelState_ResetAsDeployFailed() { Map> deployingModels = new HashMap<>(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYED, 2, null, Instant.now().toEpochMilli())); + actionListener + .onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYED, 2, null, Instant.now().toEpochMilli())); return null; }).when(client).search(any(), any()); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); @@ -327,7 +329,8 @@ public void testRefreshModelState_ResetAsPartiallyDeployed() { Map> deployingModels = new HashMap<>(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYED, 2, 0, Instant.now().toEpochMilli())); + actionListener + .onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYED, 2, 0, Instant.now().toEpochMilli())); return null; }).when(client).search(any(), any()); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); @@ -351,7 +354,7 @@ public void testRefreshModelState_ResetCurrentWorkerNodeCountForPartiallyDeploye doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener - .onResponse(createSearchModelResponse("modelId", MLModelState.PARTIALLY_DEPLOYED, 3, 2, Instant.now().toEpochMilli())); + .onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYED, 2, 0, Instant.now().toEpochMilli())); return null; }).when(client).search(any(), any()); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); @@ -375,7 +378,8 @@ public void testRefreshModelState_ResetAsDeploying() { deployingModels.put("modelId", ImmutableSet.of("node2")); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOY_FAILED, 2, 0, Instant.now().toEpochMilli())); + actionListener + .onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYED, 2, 0, Instant.now().toEpochMilli())); return null; }).when(client).search(any(), any()); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); @@ -398,7 +402,10 @@ public void testRefreshModelState_NotResetState_DeployingModelTaskRunning() { deployingModels.put("modelId", ImmutableSet.of("node2")); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli())); + actionListener + .onResponse( + createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli()) + ); return null; }).when(client).search(any(), any()); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); @@ -411,7 +418,10 @@ public void testRefreshModelState_NotResetState_DeployingInGraceTime() { Map> deployingModels = new HashMap<>(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli())); + actionListener + .onResponse( + createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli()) + ); return null; }).when(client).search(any(), any()); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); @@ -454,6 +464,7 @@ private void mockSyncUp_GatherRunningTasks_Failure() { private SearchResponse createSearchModelResponse( String modelId, + String tenantId, MLModelState state, Integer planningWorkerNodeCount, Integer currentWorkerNodeCount, @@ -461,6 +472,7 @@ private SearchResponse createSearchModelResponse( ) throws IOException { XContentBuilder content = TestHelper.builder(); content.startObject(); + content.field(CommonValue.TENANT_ID_FIELD, tenantId); content.field(MLModel.MODEL_STATE_FIELD, state); content.field(MLModel.ALGORITHM_FIELD, FunctionName.KMEANS); content.field(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, planningWorkerNodeCount); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java index 4b4e6ace27..6a04915d35 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -132,12 +132,6 @@ public void testModelState_DuplicateError() { cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true); } - public void testPredictor_NotFoundException() { - expectedEx.expect(IllegalArgumentException.class); - expectedEx.expectMessage("Model not found in cache"); - cacheHelper.setPredictor("modelId1", predictor); - } - public void testPredictor() { cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true); assertNull(cacheHelper.getPredictor(modelId)); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 767768ac8f..b807363d4d 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -202,7 +202,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { public void setup() throws URISyntaxException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl(masterKey); + encryptor = new EncryptorImpl(null, masterKey); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE.getKey(), 10).build(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java index 571bbcbeab..040da94270 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java @@ -11,18 +11,24 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONFIG_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -30,6 +36,7 @@ import org.opensearch.ml.common.transport.config.MLConfigGetAction; import org.opensearch.ml.common.transport.config.MLConfigGetRequest; import org.opensearch.ml.common.transport.config.MLConfigGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -42,18 +49,30 @@ public class RestMLGetConfigActionTests extends OpenSearchTestCase { private RestMLGetConfigAction restMLGetConfigAction; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + NodeClient client; private ThreadPool threadPool; @Mock RestChannel channel; + Settings settings; + @Mock + private ClusterService clusterService; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setup() { - restMLGetConfigAction = new RestMLGetConfigAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLGetConfigAction = new RestMLGetConfigAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -72,7 +91,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetConfigAction mlGetConfigAction = new RestMLGetConfigAction(); + RestMLGetConfigAction mlGetConfigAction = new RestMLGetConfigAction(mlFeatureEnabledSetting); assertNotNull(mlGetConfigAction); } @@ -111,7 +130,6 @@ public void test_InvalidConfigID() throws Exception { private RestRequest getRestRequest(String configID) { Map params = new HashMap<>(); params.put(PARAMETER_CONFIG_ID, configID); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); - return request; + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index c90f765ed0..955cb07469 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -24,6 +24,7 @@ import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -90,17 +91,20 @@ public void tearDown() throws Exception { client.close(); } + @Test public void testConstructor() { RestMLPredictionAction mlPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting); assertNotNull(mlPredictionAction); } + @Test public void testGetName() { String actionName = restMLPredictionAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); assertEquals("ml_prediction_action", actionName); } + @Test public void testRoutes() { List routes = restMLPredictionAction.routes(); assertNotNull(routes); @@ -110,6 +114,7 @@ public void testRoutes() { assertEquals("/_plugins/_ml/_predict/{algorithm}/{model_id}", route.getPath()); } + @Test public void testRoutes_Batch() { List routes = restMLPredictionAction.routes(); assertNotNull(routes); @@ -119,6 +124,7 @@ public void testRoutes_Batch() { assertEquals("/_plugins/_ml/models/{model_id}/_batch_predict", route.getPath()); } + @Test public void testGetRequest() throws IOException { RestRequest request = getRestRequest_PredictModel(); MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request); @@ -127,6 +133,7 @@ public void testGetRequest() throws IOException { verifyParsedKMeansMLInput(mlInput); } + @Test public void testGetRequest_RemoteInferenceDisabled() throws IOException { thrown.expect(IllegalStateException.class); thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); @@ -136,6 +143,7 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException { MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request); } + @Test public void testGetRequest_LocalModelInferenceDisabled() throws IOException { thrown.expect(IllegalStateException.class); thrown.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG); @@ -146,6 +154,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException { .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request); } + @Test public void testPrepareRequest() throws Exception { RestRequest request = getRestRequest_PredictModel(); restMLPredictionAction.handleRequest(request, channel, client); @@ -155,6 +164,7 @@ public void testPrepareRequest() throws Exception { verifyParsedKMeansMLInput(mlInput); } + @Test public void testPrepareBatchRequest() throws Exception { RestRequest request = getBatchRestRequest(); when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true); @@ -165,6 +175,7 @@ public void testPrepareBatchRequest() throws Exception { verifyParsedBatchMLInput(mlInput); } + @Test public void testPrepareBatchRequest_FeatureFlagDisabled() throws Exception { thrown.expect(IllegalStateException.class); thrown @@ -177,6 +188,7 @@ public void testPrepareBatchRequest_FeatureFlagDisabled() throws Exception { restMLPredictionAction.handleRequest(request, channel, client); } + @Test public void testPrepareBatchRequest_WrongActionType() throws Exception { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("Wrong Action Type"); @@ -213,6 +225,26 @@ public void testPrepareRequest_EmptyAlgorithm() throws Exception { verifyParsedKMeansMLInput(mlInput); } + @Test + public void testGetRequest_InvalidActionType() throws IOException { + // Test with an invalid action type + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Wrong Action Type of models"); + + RestRequest request = getBatchRestRequest_WrongActionType(); + restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), request); + } + + @Test + public void testGetRequest_UnsupportedAlgorithm() throws IOException { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Wrong function name"); + + // Create a RestRequest with an unsupported algorithm + RestRequest request = getRestRequest_PredictModel(); + restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", request); + } + private RestRequest getRestRequest_PredictModel() { RestRequest request = getKMeansRestRequest(); request.params().put(PARAMETER_MODEL_ID, "model_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index a4e7a87a82..ff6be115b8 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -92,7 +92,7 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index c6f8a3162c..0254cd619f 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -144,7 +144,7 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); @@ -259,10 +259,10 @@ public void testExecuteTask_OnLocalNode_QueryInput() { public void testExecuteTask_OnLocalNode_RemoteModelAutoDeploy() { setupMocks(true, false, false, false); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any()); + }).when(mlModelManager).getModel(any(), any(), any()); when(mlModelManager.addModelToAutoDeployCache("111", mlModel)).thenReturn(mlModel); taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); verify(client).execute(any(), any(), any()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index e381a6541d..5f2ac53603 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -101,7 +101,7 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().build(); MockitoAnnotations.openMocks(this); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index 943bd5740d..e6938e79a6 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -111,7 +111,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); mlEngine = new MLEngine(Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT);