|
14 | 14 | import static org.mockito.Mockito.mock;
|
15 | 15 | import static org.mockito.Mockito.verify;
|
16 | 16 | import static org.mockito.Mockito.when;
|
| 17 | +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; |
17 | 18 | import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
|
18 | 19 | import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
|
19 | 20 | import static org.opensearch.ml.utils.TestHelper.clusterSetting;
|
@@ -155,12 +156,14 @@ public void setup() throws IOException {
|
155 | 156 | settings = Settings
|
156 | 157 | .builder()
|
157 | 158 | .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex)
|
| 159 | + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true) |
158 | 160 | .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES)
|
159 | 161 | .build();
|
160 | 162 | threadContext = new ThreadContext(settings);
|
161 | 163 | ClusterSettings clusterSettings = clusterSetting(
|
162 | 164 | settings,
|
163 | 165 | ML_COMMONS_TRUSTED_URL_REGEX,
|
| 166 | + ML_COMMONS_ALLOW_MODEL_URL, |
164 | 167 | ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX
|
165 | 168 | );
|
166 | 169 | when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
|
@@ -294,6 +297,50 @@ public void testDoExecute_invalidURL() {
|
294 | 297 | assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage());
|
295 | 298 | }
|
296 | 299 |
|
| 300 | + public void testRegisterModelUrlNotAllowed() throws Exception { |
| 301 | + Settings settings = Settings |
| 302 | + .builder() |
| 303 | + .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) |
| 304 | + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false) |
| 305 | + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) |
| 306 | + .build(); |
| 307 | + ClusterSettings clusterSettings = clusterSetting( |
| 308 | + settings, |
| 309 | + ML_COMMONS_TRUSTED_URL_REGEX, |
| 310 | + ML_COMMONS_ALLOW_MODEL_URL, |
| 311 | + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX |
| 312 | + ); |
| 313 | + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); |
| 314 | + when(clusterService.getSettings()).thenReturn(settings); |
| 315 | + transportRegisterModelAction = new TransportRegisterModelAction( |
| 316 | + transportService, |
| 317 | + actionFilters, |
| 318 | + modelHelper, |
| 319 | + mlIndicesHandler, |
| 320 | + mlModelManager, |
| 321 | + mlTaskManager, |
| 322 | + clusterService, |
| 323 | + settings, |
| 324 | + threadPool, |
| 325 | + client, |
| 326 | + nodeFilter, |
| 327 | + mlTaskDispatcher, |
| 328 | + mlStats, |
| 329 | + modelAccessControlHelper, |
| 330 | + connectorAccessControlHelper, |
| 331 | + mlModelGroupManager |
| 332 | + ); |
| 333 | + |
| 334 | + IllegalArgumentException e = assertThrows( |
| 335 | + IllegalArgumentException.class, |
| 336 | + () -> transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener) |
| 337 | + ); |
| 338 | + assertEquals( |
| 339 | + e.getMessage(), |
| 340 | + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." |
| 341 | + ); |
| 342 | + } |
| 343 | + |
297 | 344 | public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() {
|
298 | 345 | when(node1.getId()).thenReturn("NodeId1");
|
299 | 346 | when(node2.getId()).thenReturn("NodeId2");
|
|
0 commit comments