Skip to content

Commit 863255b

Browse files
ylwu-amznzane-neo
authored andcommittedSep 1, 2023
add eligible node role settings (opensearch-project#1197) (opensearch-project#1221)
* add eligible node role settings * add more comment --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 4fe8c46 commit 863255b

28 files changed

+287
-144
lines changed
 

‎plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
129129

130130
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
131131
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
132+
FunctionName functionName = mlModel.getAlgorithm();
132133
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
133134
if (!access) {
134135
listener
@@ -141,7 +142,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
141142
}
142143
// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
143144
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
144-
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes();
145+
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes(functionName);
145146
Map<String, DiscoveryNode> nodeMapping = new HashMap<>();
146147
for (DiscoveryNode node : allEligibleNodes) {
147148
nodeMapping.put(node.getId(), node);
@@ -161,7 +162,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
161162
nodeIds.add(nodeId);
162163
}
163164
}
164-
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
165+
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName);
165166
if (workerNodes != null && workerNodes.length > 0) {
166167
Set<String> difference = new HashSet<String>(Arrays.asList(workerNodes));
167168
difference.removeAll(Arrays.asList(targetNodeIds));

‎plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.opensearch.action.support.HandledTransportAction;
1111
import org.opensearch.common.inject.Inject;
1212
import org.opensearch.core.action.ActionListener;
13+
import org.opensearch.ml.common.FunctionName;
1314
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
1415
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
1516
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
@@ -42,6 +43,7 @@ public TransportExecuteTaskAction(
4243
@Override
4344
protected void doExecute(Task task, ActionRequest request, ActionListener<MLExecuteTaskResponse> listener) {
4445
MLExecuteTaskRequest mlPredictionTaskRequest = MLExecuteTaskRequest.fromActionRequest(request);
45-
mlExecuteTaskRunner.run(mlPredictionTaskRequest, transportService, listener);
46+
FunctionName functionName = mlPredictionTaskRequest.getFunctionName();
47+
mlExecuteTaskRunner.run(functionName, mlPredictionTaskRequest, transportService, listener);
4648
}
4749
}

‎plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.core.action.ActionListener;
2929
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
3030
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
31+
import org.opensearch.ml.common.FunctionName;
3132
import org.opensearch.ml.common.MLModel;
3233
import org.opensearch.ml.common.MLTask;
3334
import org.opensearch.ml.common.MLTaskState;
@@ -116,26 +117,26 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
116117
switch (requestType) {
117118
case DEPLOY_MODEL_DONE:
118119
Set<String> workNodes = mlTaskManager.getWorkNodes(taskId);
120+
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
121+
FunctionName functionName = mlTaskCache.getMlTask().getFunctionName();
119122
if (workNodes != null) {
120123
workNodes.remove(workerNodeId);
121124
}
122-
123125
if (error != null) {
124126
mlTaskManager.addNodeError(taskId, workerNodeId, error);
125127
} else {
126128
mlModelManager.addModelWorkerNode(modelId, workerNodeId);
127-
syncModelWorkerNodes(modelId);
129+
syncModelWorkerNodes(modelId, functionName);
128130
}
129131

130132
if (workNodes == null || workNodes.size() == 0) {
131-
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
132133
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
133134
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
134135
if (mlTaskCache.allNodeFailed()) {
135136
taskState = MLTaskState.FAILED;
136137
currentWorkerNodeCount = 0;
137138
} else {
138-
syncModelWorkerNodes(modelId);
139+
syncModelWorkerNodes(modelId, functionName);
139140
}
140141
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
141142
builder.put(MLTask.STATE_FIELD, taskState);
@@ -196,9 +197,9 @@ private boolean triggerNextModelDeployAndCheckIfRestRetryTimes(Set<String> workN
196197
return false;
197198
}
198199

199-
private void syncModelWorkerNodes(String modelId) {
200+
private void syncModelWorkerNodes(String modelId, FunctionName functionName) {
200201
DiscoveryNode[] allNodes = nodeHelper.getAllNodes();
201-
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
202+
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName);
202203
if (allNodes.length > 1 && workerNodes != null && workerNodes.length > 0) {
203204
log.debug("Sync to other nodes about worker nodes of model {}: {}", modelId, Arrays.toString(workerNodes));
204205
MLSyncUpInput syncUpInput = MLSyncUpInput.builder().addedWorkerNodes(ImmutableMap.of(modelId, workerNodes)).build();

‎plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+9-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.commons.authuser.User;
1616
import org.opensearch.core.action.ActionListener;
1717
import org.opensearch.core.xcontent.NamedXContentRegistry;
18+
import org.opensearch.ml.common.FunctionName;
1819
import org.opensearch.ml.common.exception.MLValidationException;
1920
import org.opensearch.ml.common.transport.MLTaskResponse;
2021
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -86,6 +87,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
8687

8788
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
8889
mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> {
90+
FunctionName functionName = mlModel.getAlgorithm();
8991
modelAccessControlHelper
9092
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
9193
if (!access) {
@@ -97,12 +99,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
9799
String requestId = mlPredictionTaskRequest.getRequestID();
98100
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
99101
long startTime = System.nanoTime();
100-
mlPredictTaskRunner.run(mlPredictionTaskRequest, transportService, ActionListener.runAfter(listener, () -> {
101-
long endTime = System.nanoTime();
102-
double durationInMs = (endTime - startTime) / 1e6;
103-
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
104-
log.debug("completed predict request " + requestId + " for model " + modelId);
105-
}));
102+
mlPredictTaskRunner
103+
.run(functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(listener, () -> {
104+
long endTime = System.nanoTime();
105+
double durationInMs = (endTime - startTime) / 1e6;
106+
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
107+
log.debug("completed predict request " + requestId + " for model " + modelId);
108+
}));
106109
}
107110
}, e -> {
108111
log.error("Failed to Validate Access for ModelId " + modelId, e);

‎plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
264264
}));
265265
return;
266266
}
267-
mlTaskDispatcher.dispatch(ActionListener.wrap(node -> {
267+
mlTaskDispatcher.dispatch(registerModelInput.getFunctionName(), ActionListener.wrap(node -> {
268268
String nodeId = node.getId();
269269
mlTask.setWorkerNodes(ImmutableList.of(nodeId));
270270

‎plugin/src/main/java/org/opensearch/ml/action/training/TransportTrainingTaskAction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ public TransportTrainingTaskAction(
3939
@Override
4040
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
4141
MLTrainingTaskRequest trainingRequest = MLTrainingTaskRequest.fromActionRequest(request);
42-
mlTrainingTaskRunner.run(trainingRequest, transportService, listener);
42+
mlTrainingTaskRunner.run(trainingRequest.getMlInput().getFunctionName(), trainingRequest, transportService, listener);
4343
}
4444
}

‎plugin/src/main/java/org/opensearch/ml/action/trainpredict/TransportTrainAndPredictionTaskAction.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ public TransportTrainAndPredictionTaskAction(
3535
@Override
3636
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
3737
MLTrainingTaskRequest trainingRequest = MLTrainingTaskRequest.fromActionRequest(request);
38-
mlTrainAndPredictTaskRunner.run(trainingRequest, transportService, listener);
38+
mlTrainAndPredictTaskRunner.run(trainingRequest.getMlInput().getFunctionName(), trainingRequest, transportService, listener);
3939
}
4040
}

‎plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.core.common.io.stream.StreamInput;
3131
import org.opensearch.core.xcontent.NamedXContentRegistry;
3232
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
33+
import org.opensearch.ml.common.FunctionName;
3334
import org.opensearch.ml.common.MLModel;
3435
import org.opensearch.ml.common.model.MLModelState;
3536
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
@@ -238,7 +239,8 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo
238239
String[] removedModelIds = specifiedModelIds ? modelIds : mlModelManager.getAllModelIds();
239240
if (removedModelIds != null) {
240241
for (String modelId : removedModelIds) {
241-
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
242+
FunctionName functionName = mlModelManager.getModelFunctionName(modelId);
243+
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName);
242244
modelWorkerNodesMap.put(modelId, workerNodes);
243245
}
244246
}

‎plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java

+61-24
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
package org.opensearch.ml.cluster;
77

88
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES;
9+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES;
910
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
11+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES;
1012

1113
import java.util.ArrayList;
12-
import java.util.Arrays;
1314
import java.util.HashSet;
1415
import java.util.List;
1516
import java.util.Set;
@@ -21,6 +22,7 @@
2122
import org.opensearch.common.settings.Settings;
2223
import org.opensearch.core.common.Strings;
2324
import org.opensearch.ml.common.CommonValue;
25+
import org.opensearch.ml.common.FunctionName;
2426
import org.opensearch.ml.utils.MLNodeUtils;
2527

2628
import lombok.extern.log4j.Log4j2;
@@ -31,6 +33,8 @@ public class DiscoveryNodeHelper {
3133
private final HotDataNodePredicate eligibleNodeFilter;
3234
private volatile Boolean onlyRunOnMLNode;
3335
private volatile Set<String> excludedNodeNames;
36+
private volatile Set<String> remoteModelEligibleNodeRoles;
37+
private volatile Set<String> localModelEligibleNodeRoles;
3438

3539
public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
3640
this.clusterService = clusterService;
@@ -41,44 +45,61 @@ public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
4145
clusterService
4246
.getClusterSettings()
4347
.addSettingsUpdateConsumer(ML_COMMONS_EXCLUDE_NODE_NAMES, it -> excludedNodeNames = Strings.commaDelimitedListToSet(it));
48+
remoteModelEligibleNodeRoles = new HashSet<>();
49+
remoteModelEligibleNodeRoles.addAll(ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES.get(settings));
50+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES, it -> {
51+
remoteModelEligibleNodeRoles = new HashSet<>(it);
52+
});
53+
localModelEligibleNodeRoles = new HashSet<>();
54+
localModelEligibleNodeRoles.addAll(ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES.get(settings));
55+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES, it -> {
56+
localModelEligibleNodeRoles = new HashSet<>(it);
57+
});
4458
}
4559

46-
public String[] getEligibleNodeIds() {
47-
DiscoveryNode[] nodes = getEligibleNodes();
60+
public String[] getEligibleNodeIds(FunctionName functionName) {
61+
DiscoveryNode[] nodes = getEligibleNodes(functionName);
4862
String[] nodeIds = new String[nodes.length];
4963
for (int i = 0; i < nodes.length; i++) {
5064
nodeIds[i] = nodes[i].getId();
5165
}
5266
return nodeIds;
5367
}
5468

55-
public DiscoveryNode[] getEligibleNodes() {
69+
public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
5670
ClusterState state = this.clusterService.state();
57-
final List<DiscoveryNode> eligibleMLNodes = new ArrayList<>();
58-
final List<DiscoveryNode> eligibleDataNodes = new ArrayList<>();
71+
final List<DiscoveryNode> eligibleNodes = new ArrayList<>();
5972
for (DiscoveryNode node : state.nodes()) {
6073
if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) {
6174
continue;
6275
}
63-
if (MLNodeUtils.isMLNode(node)) {
64-
eligibleMLNodes.add(node);
65-
}
66-
if (!onlyRunOnMLNode && node.isDataNode() && isEligibleDataNode(node)) {
67-
eligibleDataNodes.add(node);
76+
if (functionName == FunctionName.REMOTE) {// remote model
77+
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
78+
} else { // local model
79+
if (onlyRunOnMLNode) {
80+
if (MLNodeUtils.isMLNode(node)) {
81+
eligibleNodes.add(node);
82+
}
83+
} else {
84+
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
85+
}
6886
}
6987
}
70-
if (eligibleMLNodes.size() > 0) {
71-
DiscoveryNode[] mlNodes = eligibleMLNodes.toArray(new DiscoveryNode[0]);
72-
log.debug("Find {} dedicated ML nodes: {}", eligibleMLNodes.size(), Arrays.toString(mlNodes));
73-
return mlNodes;
74-
} else {
75-
DiscoveryNode[] dataNodes = eligibleDataNodes.toArray(new DiscoveryNode[0]);
76-
log.debug("Find no dedicated ML nodes. But have {} data nodes: {}", eligibleDataNodes.size(), Arrays.toString(dataNodes));
77-
return dataNodes;
88+
return eligibleNodes.toArray(new DiscoveryNode[0]);
89+
}
90+
91+
private void getEligibleNodes(Set<String> allowedNodeRoles, List<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
92+
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
93+
eligibleNodes.add(node);
94+
}
95+
for (String nodeRole : allowedNodeRoles) {
96+
if (!"data".equals(nodeRole) && node.getRoles().stream().anyMatch(r -> r.roleName().equals(nodeRole))) {
97+
eligibleNodes.add(node);
98+
}
7899
}
79100
}
80101

81-
public String[] filterEligibleNodes(String[] nodeIds) {
102+
public String[] filterEligibleNodes(FunctionName functionName, String[] nodeIds) {
82103
if (nodeIds == null || nodeIds.length == 0) {
83104
return nodeIds;
84105
}
@@ -88,14 +109,30 @@ public String[] filterEligibleNodes(String[] nodeIds) {
88109
if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) {
89110
continue;
90111
}
91-
if (MLNodeUtils.isMLNode(node)) {
92-
eligibleNodes.add(node.getId());
112+
if (functionName == FunctionName.REMOTE) {// remote model
113+
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
114+
} else { // local model
115+
if (onlyRunOnMLNode) {
116+
if (MLNodeUtils.isMLNode(node)) {
117+
eligibleNodes.add(node.getId());
118+
}
119+
} else {
120+
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
121+
}
93122
}
94-
if (!onlyRunOnMLNode && node.isDataNode() && isEligibleDataNode(node)) {
123+
}
124+
return eligibleNodes.toArray(new String[0]);
125+
}
126+
127+
private void getEligibleNodes(Set<String> allowedNodeRoles, Set<String> eligibleNodes, DiscoveryNode node) {
128+
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
129+
eligibleNodes.add(node.getId());
130+
}
131+
for (String nodeRole : allowedNodeRoles) {
132+
if (!"data".equals(nodeRole) && node.getRoles().stream().anyMatch(r -> r.roleName().equals(nodeRole))) {
95133
eligibleNodes.add(node.getId());
96134
}
97135
}
98-
return eligibleNodes.toArray(new String[0]);
99136
}
100137

101138
public DiscoveryNode[] getAllNodes() {

0 commit comments

Comments
 (0)
Please sign in to comment.