Skip to content

Commit a47db3a

Browse files
zane-neoylwu-amzn
andauthored
[Backport to main]fix security IT failure caused by weak password (opensearch-project#951) (opensearch-project#1257)
* fix security IT failure caused by weak password (opensearch-project#951) Signed-off-by: Yaliang Wu <ylwu@amazon.com> * Fix pre-trained model metadata parse exception Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com> Signed-off-by: zane-neo <zaniu@amazon.com> Co-authored-by: Yaliang Wu <ylwu@amazon.com>
1 parent f93e789 commit a47db3a

File tree

4 files changed

+39
-65
lines changed

4 files changed

+39
-65
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java

+4-10
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,10 @@ public boolean isModelAllowed(MLRegisterModelInput registerModelInput, List mode
150150
String version = registerModelInput.getVersion();
151151
MLModelFormat modelFormat = registerModelInput.getModelFormat();
152152
for (Object meta: modelMetaList) {
153-
Map<String, Object> metaMap = (Map<String, Object>) meta;
154-
String name = (String) metaMap.get("name");
155-
Map<String, Object> versions = (Map<String, Object>) metaMap.get("versions");
156-
Object versionObj = versions.get(version);
157-
if (versionObj == null) return false;
158-
Map<String, Object> versionMap = (Map<String, Object>) versionObj;
159-
Object formatObj = versionMap.get("format");
160-
if (formatObj == null) return false;
161-
List<String> formats = (List<String>) formatObj;
162-
if (name.equals(modelName) && versions.containsKey(version.toLowerCase(Locale.ROOT)) && formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) {
153+
String name = (String) ((Map<String, Object>)meta).get("name");
154+
List<String> versions = (List) ((Map<String, Object>)meta).get("version");
155+
List<String> formats = (List) ((Map<String, Object>)meta).get("format");
156+
if (name.equals(modelName) && versions.contains(version.toLowerCase(Locale.ROOT)) && formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) {
163157
return true;
164158
}
165159
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ public void testDownloadPrebuiltModelMetaList() throws PrivilegedActionException
184184
.modelNodeIds(new String[]{"node_id1"})
185185
.build();
186186
List modelMetaList = modelHelper.downloadPrebuiltModelMetaList(taskId, registerModelInput);
187-
assertEquals("huggingface/sentence-transformers/all-MiniLM-L12-v2", ((Map<String, String>)modelMetaList.get(0)).get("name"));
187+
assertEquals("huggingface/sentence-transformers/all-distilroberta-v1", ((Map<String, String>)modelMetaList.get(0)).get("name"));
188188
}
189189

190190
@Test

plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java

+19-31
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public class MLModelGroupRestIT extends MLCommonsRestTestCase {
5858
public ExpectedException exceptionRule = ExpectedException.none();
5959

6060
private String modelGroupId;
61+
private String password = "IntegTest@MLModelGroupRestIT123";
6162

6263
@Before
6364
public void setup() throws IOException {
@@ -77,56 +78,43 @@ public void setup() throws IOException {
7778
}
7879
createSearchRole(indexSearchAccessRole, "*");
7980

80-
createUser(mlNoAccessUser, mlNoAccessUser, ImmutableList.of(opensearchBackendRole));
81-
mlNoAccessClient = new SecureRestClientBuilder(
82-
getClusterHosts().toArray(new HttpHost[0]),
83-
isHttps(),
84-
mlNoAccessUser,
85-
mlNoAccessUser
86-
).setSocketTimeout(60000).build();
81+
createUser(mlNoAccessUser, password, ImmutableList.of(opensearchBackendRole));
82+
mlNoAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlNoAccessUser, password)
83+
.setSocketTimeout(60000)
84+
.build();
8785

88-
createUser(mlReadOnlyUser, mlReadOnlyUser, ImmutableList.of(opensearchBackendRole));
89-
mlReadOnlyClient = new SecureRestClientBuilder(
90-
getClusterHosts().toArray(new HttpHost[0]),
91-
isHttps(),
92-
mlReadOnlyUser,
93-
mlReadOnlyUser
94-
).setSocketTimeout(60000).build();
86+
createUser(mlReadOnlyUser, password, ImmutableList.of(opensearchBackendRole));
87+
mlReadOnlyClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlReadOnlyUser, password)
88+
.setSocketTimeout(60000)
89+
.build();
9590

96-
createUser(mlFullAccessNoIndexAccessUser, mlFullAccessNoIndexAccessUser, ImmutableList.of(opensearchBackendRole));
91+
createUser(mlFullAccessNoIndexAccessUser, password, ImmutableList.of(opensearchBackendRole));
9792
mlFullAccessNoIndexAccessClient = new SecureRestClientBuilder(
9893
getClusterHosts().toArray(new HttpHost[0]),
9994
isHttps(),
10095
mlFullAccessNoIndexAccessUser,
101-
mlFullAccessNoIndexAccessUser
96+
password
10297
).setSocketTimeout(60000).build();
10398

104-
createUser(mlFullAccessUser, mlFullAccessUser, ImmutableList.of(opensearchBackendRole));
105-
mlFullAccessClient = new SecureRestClientBuilder(
106-
getClusterHosts().toArray(new HttpHost[0]),
107-
isHttps(),
108-
mlFullAccessUser,
109-
mlFullAccessUser
110-
).setSocketTimeout(60000).build();
99+
createUser(mlFullAccessUser, password, ImmutableList.of(opensearchBackendRole));
100+
mlFullAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlFullAccessUser, password)
101+
.setSocketTimeout(60000)
102+
.build();
111103

112-
createUser(mlNonAdminFullAccessWithoutBackendRoleUser, mlNonAdminFullAccessWithoutBackendRoleUser, ImmutableList.of());
104+
createUser(mlNonAdminFullAccessWithoutBackendRoleUser, password, ImmutableList.of());
113105
mlNonAdminFullAccessWithoutBackendRoleClient = new SecureRestClientBuilder(
114106
getClusterHosts().toArray(new HttpHost[0]),
115107
isHttps(),
116108
mlNonAdminFullAccessWithoutBackendRoleUser,
117-
mlNonAdminFullAccessWithoutBackendRoleUser
109+
password
118110
).setSocketTimeout(60000).build();
119111

120-
createUser(
121-
mlNonOwnerFullAccessWithBackendRoleUser,
122-
mlNonOwnerFullAccessWithBackendRoleUser,
123-
ImmutableList.of(opensearchBackendRole)
124-
);
112+
createUser(mlNonOwnerFullAccessWithBackendRoleUser, password, ImmutableList.of(opensearchBackendRole));
125113
mlNonOwnerFullAccessWithBackendRoleClient = new SecureRestClientBuilder(
126114
getClusterHosts().toArray(new HttpHost[0]),
127115
isHttps(),
128116
mlNonOwnerFullAccessWithBackendRoleUser,
129-
mlNonOwnerFullAccessWithBackendRoleUser
117+
password
130118
).setSocketTimeout(60000).build();
131119

132120
createRoleMapping("ml_read_access", ImmutableList.of(mlReadOnlyUser));

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

+15-23
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
5858
public ExpectedException exceptionRule = ExpectedException.none();
5959

6060
private String modelGroupId;
61+
private String password = "IntegTest@SecureMLRestIT123";
6162

6263
@Before
6364
public void setup() throws IOException, ParseException {
@@ -77,37 +78,28 @@ public void setup() throws IOException, ParseException {
7778
}
7879
createSearchRole(indexSearchAccessRole, "*");
7980

80-
createUser(mlNoAccessUser, mlNoAccessUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
81-
mlNoAccessClient = new SecureRestClientBuilder(
82-
getClusterHosts().toArray(new HttpHost[0]),
83-
isHttps(),
84-
mlNoAccessUser,
85-
mlNoAccessUser
86-
).setSocketTimeout(60000).build();
81+
createUser(mlNoAccessUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
82+
mlNoAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlNoAccessUser, password)
83+
.setSocketTimeout(60000)
84+
.build();
8785

88-
createUser(mlReadOnlyUser, mlReadOnlyUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
89-
mlReadOnlyClient = new SecureRestClientBuilder(
90-
getClusterHosts().toArray(new HttpHost[0]),
91-
isHttps(),
92-
mlReadOnlyUser,
93-
mlReadOnlyUser
94-
).setSocketTimeout(60000).build();
86+
createUser(mlReadOnlyUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
87+
mlReadOnlyClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlReadOnlyUser, password)
88+
.setSocketTimeout(60000)
89+
.build();
9590

96-
createUser(mlFullAccessNoIndexAccessUser, mlFullAccessNoIndexAccessUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
91+
createUser(mlFullAccessNoIndexAccessUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
9792
mlFullAccessNoIndexAccessClient = new SecureRestClientBuilder(
9893
getClusterHosts().toArray(new HttpHost[0]),
9994
isHttps(),
10095
mlFullAccessNoIndexAccessUser,
101-
mlFullAccessNoIndexAccessUser
96+
password
10297
).setSocketTimeout(60000).build();
10398

104-
createUser(mlFullAccessUser, mlFullAccessUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
105-
mlFullAccessClient = new SecureRestClientBuilder(
106-
getClusterHosts().toArray(new HttpHost[0]),
107-
isHttps(),
108-
mlFullAccessUser,
109-
mlFullAccessUser
110-
).setSocketTimeout(60000).build();
99+
createUser(mlFullAccessUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
100+
mlFullAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlFullAccessUser, password)
101+
.setSocketTimeout(60000)
102+
.build();
111103

112104
createRoleMapping("ml_read_access", new ArrayList<>(Arrays.asList(mlReadOnlyUser)));
113105
createRoleMapping("ml_full_access", new ArrayList<>(Arrays.asList(mlFullAccessNoIndexAccessUser, mlFullAccessUser)));

0 commit comments

Comments
 (0)