Skip to content

Commit 6e94702

Browse files
committed
adding more unit tests
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 4948ea7 commit 6e94702

File tree

7 files changed

+386
-26
lines changed

7 files changed

+386
-26
lines changed

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+20
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
import java.nio.ByteBuffer;
99
import java.nio.charset.StandardCharsets;
1010
import java.security.AccessController;
11+
import java.security.MessageDigest;
12+
import java.security.NoSuchAlgorithmException;
1113
import java.security.PrivilegedActionException;
1214
import java.security.PrivilegedExceptionAction;
1315
import java.util.ArrayList;
16+
import java.util.Base64;
1417
import java.util.HashMap;
1518
import java.util.HashSet;
1619
import java.util.List;
@@ -371,4 +374,21 @@ public static void validateSchema(String schemaString, String instanceString) {
371374
throw new OpenSearchParseException("Schema validation failed: " + e.getMessage(), e);
372375
}
373376
}
377+
378+
public static String hashString(String input) {
379+
try {
380+
// Create a MessageDigest instance for SHA-256
381+
MessageDigest digest = MessageDigest.getInstance("SHA-256");
382+
383+
// Perform the hashing and get the byte array
384+
byte[] hashBytes = digest.digest(input.getBytes());
385+
386+
// Convert the byte array to a Base64 encoded string
387+
return Base64.getUrlEncoder().encodeToString(hashBytes);
388+
389+
} catch (NoSuchAlgorithmException e) {
390+
throw new RuntimeException("Error: Unable to compute hash", e);
391+
}
392+
}
393+
374394
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common;
7+
8+
import java.io.IOException;
9+
import java.time.Instant;
10+
import java.util.Collections;
11+
12+
import org.junit.Assert;
13+
import org.junit.Rule;
14+
import org.junit.Test;
15+
import org.junit.rules.ExpectedException;
16+
import org.opensearch.Version;
17+
import org.opensearch.common.io.stream.BytesStreamOutput;
18+
import org.opensearch.common.settings.Settings;
19+
import org.opensearch.common.xcontent.XContentType;
20+
import org.opensearch.core.common.io.stream.StreamInput;
21+
import org.opensearch.core.xcontent.NamedXContentRegistry;
22+
import org.opensearch.core.xcontent.ToXContent;
23+
import org.opensearch.core.xcontent.XContentBuilder;
24+
import org.opensearch.core.xcontent.XContentParser;
25+
import org.opensearch.search.SearchModule;
26+
27+
public class MLConfigTest {
28+
29+
@Rule
30+
public ExpectedException exceptionRule = ExpectedException.none();
31+
32+
@Test
33+
public void toXContent_Minimal() throws IOException {
34+
MLConfig config = MLConfig.builder().type("test_type").build();
35+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
36+
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
37+
String content = TestHelper.xContentBuilderToString(builder);
38+
Assert.assertEquals("{\"type\":\"test_type\"}", content);
39+
}
40+
41+
@Test
42+
public void toXContent_Full() throws IOException {
43+
Instant now = Instant.now();
44+
Configuration configuration = Configuration.builder().build();
45+
MLConfig config = MLConfig
46+
.builder()
47+
.type("test_type")
48+
.configType("test_config_type")
49+
.configuration(configuration)
50+
.mlConfiguration(configuration)
51+
.createTime(now)
52+
.lastUpdateTime(now)
53+
.lastUpdatedTime(now)
54+
.tenantId("test_tenant")
55+
.build();
56+
57+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
58+
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
59+
String content = TestHelper.xContentBuilderToString(builder);
60+
Assert
61+
.assertTrue(
62+
content.contains("\"type\":\"test_config_type\"")
63+
&& content.contains("\"configuration\":")
64+
&& content.contains("\"create_time\":" + now.toEpochMilli())
65+
&& content.contains("\"last_update_time\":" + now.toEpochMilli())
66+
&& content.contains("\"tenant_id\":\"test_tenant\"")
67+
);
68+
}
69+
70+
@Test
71+
public void parse_Minimal() throws IOException {
72+
String jsonStr = "{\"type\":\"test_type\"}";
73+
XContentParser parser = XContentType.JSON
74+
.xContent()
75+
.createParser(
76+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
77+
null,
78+
jsonStr
79+
);
80+
parser.nextToken();
81+
MLConfig config = MLConfig.parse(parser);
82+
Assert.assertEquals("test_type", config.getType());
83+
Assert.assertNull(config.getConfigType());
84+
Assert.assertNull(config.getConfiguration());
85+
Assert.assertNull(config.getMlConfiguration());
86+
Assert.assertNull(config.getCreateTime());
87+
Assert.assertNull(config.getLastUpdateTime());
88+
Assert.assertNull(config.getLastUpdatedTime());
89+
Assert.assertNull(config.getTenantId());
90+
}
91+
92+
@Test
93+
public void parse_Full() throws IOException {
94+
String jsonStr = "{\"type\":\"test_type\",\"config_type\":\"test_config_type\","
95+
+ "\"configuration\":{},\"ml_configuration\":{},\"create_time\":1672531200000,"
96+
+ "\"last_update_time\":1672534800000,\"last_updated_time\":1672538400000,\"tenant_id\":\"test_tenant\"}";
97+
XContentParser parser = XContentType.JSON
98+
.xContent()
99+
.createParser(
100+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
101+
null,
102+
jsonStr
103+
);
104+
parser.nextToken();
105+
MLConfig config = MLConfig.parse(parser);
106+
Assert.assertEquals("test_type", config.getType());
107+
Assert.assertEquals("test_config_type", config.getConfigType());
108+
Assert.assertNotNull(config.getConfiguration());
109+
Assert.assertNotNull(config.getMlConfiguration());
110+
Assert.assertEquals(Instant.ofEpochMilli(1672531200000L), config.getCreateTime());
111+
Assert.assertEquals(Instant.ofEpochMilli(1672534800000L), config.getLastUpdateTime());
112+
Assert.assertEquals(Instant.ofEpochMilli(1672538400000L), config.getLastUpdatedTime());
113+
Assert.assertEquals("test_tenant", config.getTenantId());
114+
}
115+
116+
@Test
117+
public void writeToAndReadFrom() throws IOException {
118+
Instant now = Instant.now();
119+
Configuration configuration = Configuration.builder().build();
120+
MLConfig originalConfig = MLConfig
121+
.builder()
122+
.type("test_type")
123+
.configType("test_config_type")
124+
.configuration(configuration)
125+
.mlConfiguration(configuration)
126+
.createTime(now)
127+
.lastUpdateTime(now)
128+
.lastUpdatedTime(now)
129+
.tenantId("test_tenant")
130+
.build();
131+
132+
BytesStreamOutput output = new BytesStreamOutput();
133+
originalConfig.writeTo(output);
134+
135+
MLConfig deserializedConfig = new MLConfig(output.bytes().streamInput());
136+
Assert.assertEquals("test_type", deserializedConfig.getType());
137+
Assert.assertEquals("test_config_type", deserializedConfig.getConfigType());
138+
Assert.assertNotNull(deserializedConfig.getConfiguration());
139+
Assert.assertNotNull(deserializedConfig.getMlConfiguration());
140+
Assert.assertEquals(now, deserializedConfig.getCreateTime());
141+
Assert.assertEquals(now, deserializedConfig.getLastUpdateTime());
142+
Assert.assertEquals(now, deserializedConfig.getLastUpdatedTime());
143+
Assert.assertEquals("test_tenant", deserializedConfig.getTenantId());
144+
}
145+
146+
@Test
147+
public void writeToAndReadFrom_Minimal() throws IOException {
148+
MLConfig originalConfig = MLConfig.builder().type("test_type").build();
149+
150+
BytesStreamOutput output = new BytesStreamOutput();
151+
originalConfig.writeTo(output);
152+
153+
MLConfig deserializedConfig = new MLConfig(output.bytes().streamInput());
154+
Assert.assertEquals("test_type", deserializedConfig.getType());
155+
Assert.assertNull(deserializedConfig.getConfigType());
156+
Assert.assertNull(deserializedConfig.getConfiguration());
157+
Assert.assertNull(deserializedConfig.getMlConfiguration());
158+
Assert.assertNull(deserializedConfig.getCreateTime());
159+
Assert.assertNull(deserializedConfig.getLastUpdateTime());
160+
Assert.assertNull(deserializedConfig.getLastUpdatedTime());
161+
Assert.assertNull(deserializedConfig.getTenantId());
162+
}
163+
164+
@Test
165+
public void crossVersionSerialization_NoTenantId() throws IOException {
166+
// Simulate an older version (before VERSION_2_19_0)
167+
Version oldVersion = Version.V_2_18_0;
168+
169+
// Create an MLConfig instance with tenantId set
170+
MLConfig originalConfig = MLConfig.builder().type("test_type").tenantId("test_tenant").build();
171+
172+
// Serialize using the older version
173+
BytesStreamOutput output = new BytesStreamOutput();
174+
output.setVersion(oldVersion);
175+
originalConfig.writeTo(output);
176+
177+
// Deserialize and verify tenantId is not present
178+
StreamInput input = output.bytes().streamInput();
179+
input.setVersion(oldVersion);
180+
MLConfig deserializedConfig = new MLConfig(input);
181+
182+
Assert.assertEquals("test_type", deserializedConfig.getType());
183+
Assert.assertNull(deserializedConfig.getTenantId());
184+
}
185+
186+
@Test
187+
public void crossVersionSerialization_WithTenantId() throws IOException {
188+
// Simulate a newer version (on or after VERSION_2_19_0)
189+
Version newVersion = Version.V_2_19_0;
190+
191+
// Create an MLConfig instance with tenantId set
192+
MLConfig originalConfig = MLConfig.builder().type("test_type").tenantId("test_tenant").build();
193+
194+
// Serialize using the newer version
195+
BytesStreamOutput output = new BytesStreamOutput();
196+
output.setVersion(newVersion);
197+
originalConfig.writeTo(output);
198+
199+
// Deserialize and verify tenantId is present
200+
StreamInput input = output.bytes().streamInput();
201+
input.setVersion(newVersion);
202+
MLConfig deserializedConfig = new MLConfig(input);
203+
204+
Assert.assertEquals("test_type", deserializedConfig.getType());
205+
Assert.assertEquals("test_tenant", deserializedConfig.getTenantId());
206+
}
207+
208+
}

common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java

+67
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55
package org.opensearch.ml.common.transport.config;
66

77
import static org.junit.Assert.assertEquals;
8+
import static org.junit.Assert.assertNull;
89
import static org.opensearch.action.ValidateActions.addValidationError;
910

1011
import java.io.IOException;
1112
import java.io.UncheckedIOException;
1213

1314
import org.junit.Test;
15+
import org.opensearch.Version;
1416
import org.opensearch.action.ActionRequest;
1517
import org.opensearch.action.ActionRequestValidationException;
1618
import org.opensearch.common.io.stream.BytesStreamOutput;
19+
import org.opensearch.core.common.io.stream.StreamInput;
1720
import org.opensearch.core.common.io.stream.StreamOutput;
1821

1922
public class MLConfigGetRequestTest {
@@ -103,4 +106,68 @@ public void writeTo(StreamOutput out) throws IOException {
103106
};
104107
mlConfigGetRequest.fromActionRequest(actionRequest);
105108
}
109+
110+
@Test
111+
public void writeTo_WithTenantId() throws IOException {
112+
configId = "test-with-tenant";
113+
tenantId = "test_tenant";
114+
115+
MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId);
116+
BytesStreamOutput output = new BytesStreamOutput();
117+
mlConfigGetRequest.writeTo(output);
118+
119+
MLConfigGetRequest deserializedRequest = new MLConfigGetRequest(output.bytes().streamInput());
120+
121+
assertEquals(mlConfigGetRequest.getConfigId(), deserializedRequest.getConfigId());
122+
assertEquals(mlConfigGetRequest.getTenantId(), deserializedRequest.getTenantId());
123+
assertEquals(tenantId, deserializedRequest.getTenantId());
124+
}
125+
126+
@Test
127+
public void crossVersionSerialization_WithoutTenantIdForOldVersion() throws IOException {
128+
configId = "test-no-tenant";
129+
tenantId = "test_tenant";
130+
131+
// Simulate an older version (before VERSION_2_19_0)
132+
Version oldVersion = Version.V_2_18_0;
133+
134+
MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId);
135+
BytesStreamOutput output = new BytesStreamOutput();
136+
output.setVersion(oldVersion); // Set the version for the output
137+
mlConfigGetRequest.writeTo(output);
138+
139+
// Set the version for the input to match the older version
140+
StreamInput input = output.bytes().streamInput();
141+
input.setVersion(oldVersion); // Important to match the output version
142+
143+
MLConfigGetRequest deserializedRequest = new MLConfigGetRequest(input);
144+
145+
// Validate that the configId is correctly deserialized and tenantId is null
146+
assertEquals(configId, deserializedRequest.getConfigId());
147+
assertNull(deserializedRequest.getTenantId()); // tenantId should not be present for old versions
148+
}
149+
150+
@Test
151+
public void fromActionRequest_WithTenantId() throws IOException {
152+
configId = "test-with-tenant";
153+
tenantId = "test_tenant";
154+
155+
MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId, tenantId);
156+
157+
ActionRequest actionRequest = new ActionRequest() {
158+
@Override
159+
public ActionRequestValidationException validate() {
160+
return null;
161+
}
162+
163+
@Override
164+
public void writeTo(StreamOutput out) throws IOException {
165+
mlConfigGetRequest.writeTo(out);
166+
}
167+
};
168+
MLConfigGetRequest deserializedRequest = mlConfigGetRequest.fromActionRequest(actionRequest);
169+
170+
assertEquals(configId, deserializedRequest.getConfigId());
171+
assertEquals(tenantId, deserializedRequest.getTenantId());
172+
}
106173
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java

+1-19
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
1111
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1212
import static org.opensearch.ml.common.MLConfig.CREATE_TIME_FIELD;
13+
import static org.opensearch.ml.common.utils.StringUtils.hashString;
1314

1415
import java.nio.charset.StandardCharsets;
15-
import java.security.MessageDigest;
16-
import java.security.NoSuchAlgorithmException;
1716
import java.security.SecureRandom;
1817
import java.time.Instant;
1918
import java.util.Base64;
@@ -226,21 +225,4 @@ private void initMasterKey(String tenantId) {
226225
throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR);
227226
}
228227
}
229-
230-
private String hashString(String input) {
231-
try {
232-
// Create a MessageDigest instance for SHA-256
233-
MessageDigest digest = MessageDigest.getInstance("SHA-256");
234-
235-
// Perform the hashing and get the byte array
236-
byte[] hashBytes = digest.digest(input.getBytes());
237-
238-
// Convert the byte array to a Base64 encoded string
239-
return Base64.getUrlEncoder().encodeToString(hashBytes);
240-
241-
} catch (NoSuchAlgorithmException e) {
242-
throw new RuntimeException("Error: Unable to compute hash", e);
243-
}
244-
}
245-
246228
}

0 commit comments

Comments
 (0)