Skip to content

Commit ac04f07

Browse files
add setting to allow private IP (opensearch-project#2534) (opensearch-project#2535)
* add setting to allow private IP Signed-off-by: Yaliang Wu <ylwu@amazon.com> * fix ut Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com> (cherry picked from commit 06d1742) Co-authored-by: Yaliang Wu <ylwu@amazon.com>
1 parent 1d36acc commit ac04f07

File tree

11 files changed

+132
-25
lines changed

11 files changed

+132
-25
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.Locale;
1717
import java.util.Map;
1818
import java.util.concurrent.CompletableFuture;
19+
import java.util.concurrent.atomic.AtomicBoolean;
1920

2021
import org.apache.logging.log4j.Logger;
2122
import org.opensearch.client.Client;
@@ -62,6 +63,8 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {
6263
@Setter
6364
@Getter
6465
private MLGuard mlGuard;
66+
@Setter
67+
private volatile AtomicBoolean connectorPrivateIpEnabled;
6568

6669
private SdkAsyncHttpClient httpClient;
6770

@@ -136,6 +139,6 @@ private void validateHttpClientParameters(String action, Map<String, String> par
136139
String protocol = url.getProtocol();
137140
String host = url.getHost();
138141
int port = url.getPort();
139-
MLHttpClientFactory.validate(protocol, host, port);
142+
MLHttpClientFactory.validate(protocol, host, port, connectorPrivateIpEnabled);
140143
}
141144
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import java.util.Locale;
1616
import java.util.Map;
1717
import java.util.Optional;
18+
import java.util.concurrent.atomic.AtomicBoolean;
1819

1920
import org.apache.logging.log4j.Logger;
2021
import org.opensearch.ExceptionsHelper;
@@ -146,6 +147,8 @@ default void setScriptService(ScriptService scriptService) {}
146147

147148
default void setClient(Client client) {}
148149

150+
default void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {}
151+
149152
default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {}
150153

151154
default void setClusterService(ClusterService clusterService) {}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
99

1010
import java.util.Map;
11+
import java.util.concurrent.atomic.AtomicBoolean;
1112

1213
import org.opensearch.client.Client;
1314
import org.opensearch.cluster.service.ClusterService;
@@ -43,6 +44,7 @@ public class RemoteModel implements Predictable {
4344
public static final String RATE_LIMITER = "rate_limiter";
4445
public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map";
4546
public static final String GUARDRAILS = "guardrails";
47+
public static final String CONNECTOR_PRIVATE_IP_ENABLED = "connectorPrivateIpEnabled";
4648

4749
private RemoteConnectorExecutor connectorExecutor;
4850

@@ -101,6 +103,7 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
101103
this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER));
102104
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
103105
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
106+
this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
104107
} catch (RuntimeException e) {
105108
log.error("Failed to init remote model.", e);
106109
throw e;

ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java

+8-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import java.time.Duration;
1515
import java.util.Arrays;
1616
import java.util.Locale;
17+
import java.util.concurrent.atomic.AtomicBoolean;
1718

1819
import lombok.extern.log4j.Log4j2;
1920
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
@@ -43,9 +44,11 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout,
4344
* @param protocol The protocol supported in remote inference, currently only http and https are supported.
4445
* @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
4546
* @param port The port number of the remote inference server, port number must be in range [0, 65536].
46-
* @throws UnknownHostException
47+
* @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536].
48+
* @throws UnknownHostException Allow to use private IP or not.
4749
*/
48-
public static void validate(String protocol, String host, int port) throws UnknownHostException {
50+
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
51+
throws UnknownHostException {
4952
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
5053
log.error("Remote inference protocol is not http or https: " + protocol);
5154
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
@@ -62,12 +65,12 @@ public static void validate(String protocol, String host, int port) throws Unkno
6265
log.error("Remote inference port out of range: " + port);
6366
throw new IllegalArgumentException("Port out of range: " + port);
6467
}
65-
validateIp(host);
68+
validateIp(host, connectorPrivateIpEnabled);
6669
}
6770

68-
private static void validateIp(String hostName) throws UnknownHostException {
71+
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
6972
InetAddress[] addresses = InetAddress.getAllByName(hostName);
70-
if (hasPrivateIpAddress(addresses)) {
73+
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
7174
log.error("Remote inference host name has private ip address: " + hostName);
7275
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
7376
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

+49
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66
package org.opensearch.ml.engine.algorithms.remote;
77

88
import static org.junit.Assert.assertEquals;
9+
import static org.mockito.ArgumentMatchers.any;
10+
import static org.mockito.Mockito.never;
911
import static org.mockito.Mockito.times;
1012
import static org.mockito.Mockito.verify;
1113
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
1214

1315
import java.lang.reflect.Field;
1416
import java.util.Arrays;
1517
import java.util.HashMap;
18+
import java.util.concurrent.atomic.AtomicBoolean;
1619

1720
import org.junit.Before;
1821
import org.junit.Rule;
@@ -102,6 +105,52 @@ public void invokeRemoteService_invalidIpAddress() {
102105
assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage());
103106
}
104107

108+
@Test
109+
public void invokeRemoteService_EnabledPrivateIpAddress() {
110+
ConnectorAction predictAction = ConnectorAction
111+
.builder()
112+
.actionType(PREDICT)
113+
.method("POST")
114+
.url("http://127.0.0.1/mock")
115+
.requestBody("{\"input\": \"${parameters.input}\"}")
116+
.build();
117+
Connector connector = HttpConnector
118+
.builder()
119+
.name("test connector")
120+
.version("1")
121+
.protocol("http")
122+
.actions(Arrays.asList(predictAction))
123+
.build();
124+
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
125+
AtomicBoolean privateIpEnabled = new AtomicBoolean(true);
126+
executor.setConnectorPrivateIpEnabled(privateIpEnabled);
127+
executor
128+
.invokeRemoteService(
129+
PREDICT.name(),
130+
createMLInput(),
131+
new HashMap<>(),
132+
"{\"input\": \"hello world\"}",
133+
new ExecutionContext(0),
134+
actionListener
135+
);
136+
Mockito.verify(actionListener, never()).onFailure(any());
137+
138+
privateIpEnabled.set(false);
139+
executor
140+
.invokeRemoteService(
141+
PREDICT.name(),
142+
createMLInput(),
143+
new HashMap<>(),
144+
"{\"input\": \"hello world\"}",
145+
new ExecutionContext(0),
146+
actionListener
147+
);
148+
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
149+
Mockito.verify(actionListener, times(1)).onFailure(captor.capture());
150+
assert captor.getValue() instanceof IllegalArgumentException;
151+
assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage());
152+
}
153+
105154
@Test
106155
public void invokeRemoteService_Empty_payload() {
107156
ConnectorAction predictAction = ConnectorAction

ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java

+26-11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.Assert.assertNotNull;
99

1010
import java.time.Duration;
11+
import java.util.concurrent.atomic.AtomicBoolean;
1112

1213
import org.junit.Rule;
1314
import org.junit.Test;
@@ -28,70 +29,84 @@ public void test_getSdkAsyncHttpClient_success() {
2829

2930
@Test
3031
public void test_validateIp_validIp_noException() throws Exception {
31-
MLHttpClientFactory.validate("http", "api.openai.com", 80);
32+
AtomicBoolean privateIpEnabled = new AtomicBoolean(false);
33+
MLHttpClientFactory.validate("http", "api.openai.com", 80, privateIpEnabled);
3234
}
3335

3436
@Test
3537
public void test_validateIp_rarePrivateIp_throwException() throws Exception {
38+
AtomicBoolean privateIpEnabled = new AtomicBoolean(false);
3639
try {
37-
MLHttpClientFactory.validate("http", "0254.020.00.01", 80);
40+
MLHttpClientFactory.validate("http", "0254.020.00.01", 80, privateIpEnabled);
3841
} catch (IllegalArgumentException e) {
3942
assertNotNull(e);
4043
}
4144

4245
try {
43-
MLHttpClientFactory.validate("http", "172.1048577", 80);
46+
MLHttpClientFactory.validate("http", "172.1048577", 80, privateIpEnabled);
4447
} catch (Exception e) {
4548
assertNotNull(e);
4649
}
4750

4851
try {
49-
MLHttpClientFactory.validate("http", "2886729729", 80);
52+
MLHttpClientFactory.validate("http", "2886729729", 80, privateIpEnabled);
5053
} catch (IllegalArgumentException e) {
5154
assertNotNull(e);
5255
}
5356

5457
try {
55-
MLHttpClientFactory.validate("http", "192.11010049", 80);
58+
MLHttpClientFactory.validate("http", "192.11010049", 80, privateIpEnabled);
5659
} catch (IllegalArgumentException e) {
5760
assertNotNull(e);
5861
}
5962

6063
try {
61-
MLHttpClientFactory.validate("http", "3232300545", 80);
64+
MLHttpClientFactory.validate("http", "3232300545", 80, privateIpEnabled);
6265
} catch (IllegalArgumentException e) {
6366
assertNotNull(e);
6467
}
6568

6669
try {
67-
MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80);
70+
MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80, privateIpEnabled);
6871
} catch (IllegalArgumentException e) {
6972
assertNotNull(e);
7073
}
7174

7275
try {
73-
MLHttpClientFactory.validate("http", "153.24.76.232", 80);
76+
MLHttpClientFactory.validate("http", "153.24.76.232", 80, privateIpEnabled);
7477
} catch (IllegalArgumentException e) {
7578
assertNotNull(e);
7679
}
7780
}
7881

82+
@Test
83+
public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception {
84+
AtomicBoolean privateIpEnabled = new AtomicBoolean(true);
85+
MLHttpClientFactory.validate("http", "0254.020.00.01", 80, privateIpEnabled);
86+
MLHttpClientFactory.validate("http", "172.1048577", 80, privateIpEnabled);
87+
MLHttpClientFactory.validate("http", "2886729729", 80, privateIpEnabled);
88+
MLHttpClientFactory.validate("http", "192.11010049", 80, privateIpEnabled);
89+
MLHttpClientFactory.validate("http", "3232300545", 80, privateIpEnabled);
90+
MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80, privateIpEnabled);
91+
MLHttpClientFactory.validate("http", "153.24.76.232", 80, privateIpEnabled);
92+
}
93+
7994
@Test
8095
public void test_validateSchemaAndPort_success() throws Exception {
81-
MLHttpClientFactory.validate("http", "api.openai.com", 80);
96+
MLHttpClientFactory.validate("http", "api.openai.com", 80, new AtomicBoolean(false));
8297
}
8398

8499
@Test
85100
public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception {
86101
expectedException.expect(IllegalArgumentException.class);
87-
MLHttpClientFactory.validate("ftp", "api.openai.com", 80);
102+
MLHttpClientFactory.validate("ftp", "api.openai.com", 80, new AtomicBoolean(false));
88103
}
89104

90105
@Test
91106
public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception {
92107
expectedException.expect(IllegalArgumentException.class);
93108
expectedException.expectMessage("Port out of range: 65537");
94-
MLHttpClientFactory.validate("https", "api.openai.com", 65537);
109+
MLHttpClientFactory.validate("https", "api.openai.com", 65537, new AtomicBoolean(false));
95110
}
96111

97112
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES;
2828
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT;
2929
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE;
30+
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CONNECTOR_PRIVATE_IP_ENABLED;
3031
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.GUARDRAILS;
3132
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.RATE_LIMITER;
3233
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE;
@@ -128,6 +129,7 @@
128129
import org.opensearch.ml.engine.indices.MLIndicesHandler;
129130
import org.opensearch.ml.engine.utils.FileUtils;
130131
import org.opensearch.ml.profile.MLModelProfile;
132+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
131133
import org.opensearch.ml.stats.ActionName;
132134
import org.opensearch.ml.stats.MLActionLevelStat;
133135
import org.opensearch.ml.stats.MLNodeLevelStat;
@@ -169,6 +171,7 @@ public class MLModelManager {
169171
private final MLTaskManager mlTaskManager;
170172
private final MLEngine mlEngine;
171173
private final DiscoveryNodeHelper nodeHelper;
174+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
172175

173176
private volatile Integer maxModelPerNode;
174177
private volatile Integer maxRegisterTasksPerNode;
@@ -198,7 +201,8 @@ public MLModelManager(
198201
MLTaskManager mlTaskManager,
199202
MLModelCacheHelper modelCacheHelper,
200203
MLEngine mlEngine,
201-
DiscoveryNodeHelper nodeHelper
204+
DiscoveryNodeHelper nodeHelper,
205+
MLFeatureEnabledSetting mlFeatureEnabledSetting
202206
) {
203207
this.client = client;
204208
this.threadPool = threadPool;
@@ -213,6 +217,7 @@ public MLModelManager(
213217
this.mlTaskManager = mlTaskManager;
214218
this.mlEngine = mlEngine;
215219
this.nodeHelper = nodeHelper;
220+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
216221

217222
this.maxModelPerNode = ML_COMMONS_MAX_MODELS_PER_NODE.get(settings);
218223
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MAX_MODELS_PER_NODE, it -> maxModelPerNode = it);
@@ -1170,6 +1175,7 @@ private Map<String, Object> setUpParameterMap(String modelId) {
11701175
params.put(GUARDRAILS, mlGuard);
11711176
log.info("Setting up ML guard parameter for ML predictor.");
11721177
}
1178+
params.put(CONNECTOR_PRIVATE_IP_ENABLED, mlFeatureEnabledSetting.isConnectorPrivateIpEnabled());
11731179
return Collections.unmodifiableMap(params);
11741180
}
11751181

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+9-6
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,11 @@ public Collection<Object> createComponents(
496496
mlIndicesHandler = new MLIndicesHandler(clusterService, client);
497497
mlTaskManager = new MLTaskManager(client, threadPool, mlIndicesHandler);
498498
modelHelper = new ModelHelper(mlEngine);
499+
500+
mlInputDatasetHandler = new MLInputDatasetHandler(client);
501+
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
502+
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
503+
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
499504
mlModelManager = new MLModelManager(
500505
clusterService,
501506
scriptService,
@@ -510,12 +515,9 @@ public Collection<Object> createComponents(
510515
mlTaskManager,
511516
modelCacheHelper,
512517
mlEngine,
513-
nodeHelper
518+
nodeHelper,
519+
mlFeatureEnabledSetting
514520
);
515-
mlInputDatasetHandler = new MLInputDatasetHandler(client);
516-
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
517-
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
518-
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
519521

520522
mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper);
521523

@@ -929,7 +931,8 @@ public List<Setting<?>> getSettings() {
929931
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
930932
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
931933
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
932-
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE
934+
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
935+
MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED
933936
);
934937
return settings;
935938
}

plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java

+3
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,7 @@ private MLCommonsSettings() {}
184184
// This setting is to enable/disable agent related API register/execute/delete/get/search agent.
185185
public static final Setting<Boolean> ML_COMMONS_AGENT_FRAMEWORK_ENABLED = Setting
186186
.boolSetting("plugins.ml_commons.agent_framework_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
187+
188+
public static final Setting<Boolean> ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED = Setting
189+
.boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
187190
}

0 commit comments

Comments
 (0)