6
6
package org .opensearch .ml .cluster ;
7
7
8
8
import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_EXCLUDE_NODE_NAMES ;
9
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES ;
9
10
import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_ONLY_RUN_ON_ML_NODE ;
11
+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES ;
10
12
11
13
import java .util .ArrayList ;
12
- import java .util .Arrays ;
13
14
import java .util .HashSet ;
14
15
import java .util .List ;
15
16
import java .util .Set ;
21
22
import org .opensearch .common .settings .Settings ;
22
23
import org .opensearch .core .common .Strings ;
23
24
import org .opensearch .ml .common .CommonValue ;
25
+ import org .opensearch .ml .common .FunctionName ;
24
26
import org .opensearch .ml .utils .MLNodeUtils ;
25
27
26
28
import lombok .extern .log4j .Log4j2 ;
@@ -31,6 +33,8 @@ public class DiscoveryNodeHelper {
31
33
private final HotDataNodePredicate eligibleNodeFilter ;
32
34
private volatile Boolean onlyRunOnMLNode ;
33
35
private volatile Set <String > excludedNodeNames ;
36
+ private volatile Set <String > remoteModelEligibleNodeRoles ;
37
+ private volatile Set <String > localModelEligibleNodeRoles ;
34
38
35
39
public DiscoveryNodeHelper (ClusterService clusterService , Settings settings ) {
36
40
this .clusterService = clusterService ;
@@ -41,44 +45,61 @@ public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
41
45
clusterService
42
46
.getClusterSettings ()
43
47
.addSettingsUpdateConsumer (ML_COMMONS_EXCLUDE_NODE_NAMES , it -> excludedNodeNames = Strings .commaDelimitedListToSet (it ));
48
+ remoteModelEligibleNodeRoles = new HashSet <>();
49
+ remoteModelEligibleNodeRoles .addAll (ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES .get (settings ));
50
+ clusterService .getClusterSettings ().addSettingsUpdateConsumer (ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES , it -> {
51
+ remoteModelEligibleNodeRoles = new HashSet <>(it );
52
+ });
53
+ localModelEligibleNodeRoles = new HashSet <>();
54
+ localModelEligibleNodeRoles .addAll (ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES .get (settings ));
55
+ clusterService .getClusterSettings ().addSettingsUpdateConsumer (ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES , it -> {
56
+ localModelEligibleNodeRoles = new HashSet <>(it );
57
+ });
44
58
}
45
59
46
- public String [] getEligibleNodeIds () {
47
- DiscoveryNode [] nodes = getEligibleNodes ();
60
+ public String [] getEligibleNodeIds (FunctionName functionName ) {
61
+ DiscoveryNode [] nodes = getEligibleNodes (functionName );
48
62
String [] nodeIds = new String [nodes .length ];
49
63
for (int i = 0 ; i < nodes .length ; i ++) {
50
64
nodeIds [i ] = nodes [i ].getId ();
51
65
}
52
66
return nodeIds ;
53
67
}
54
68
55
- public DiscoveryNode [] getEligibleNodes () {
69
+ public DiscoveryNode [] getEligibleNodes (FunctionName functionName ) {
56
70
ClusterState state = this .clusterService .state ();
57
- final List <DiscoveryNode > eligibleMLNodes = new ArrayList <>();
58
- final List <DiscoveryNode > eligibleDataNodes = new ArrayList <>();
71
+ final List <DiscoveryNode > eligibleNodes = new ArrayList <>();
59
72
for (DiscoveryNode node : state .nodes ()) {
60
73
if (excludedNodeNames != null && excludedNodeNames .contains (node .getName ())) {
61
74
continue ;
62
75
}
63
- if (MLNodeUtils .isMLNode (node )) {
64
- eligibleMLNodes .add (node );
65
- }
66
- if (!onlyRunOnMLNode && node .isDataNode () && isEligibleDataNode (node )) {
67
- eligibleDataNodes .add (node );
76
+ if (functionName == FunctionName .REMOTE ) {// remote model
77
+ getEligibleNodes (remoteModelEligibleNodeRoles , eligibleNodes , node );
78
+ } else { // local model
79
+ if (onlyRunOnMLNode ) {
80
+ if (MLNodeUtils .isMLNode (node )) {
81
+ eligibleNodes .add (node );
82
+ }
83
+ } else {
84
+ getEligibleNodes (localModelEligibleNodeRoles , eligibleNodes , node );
85
+ }
68
86
}
69
87
}
70
- if (eligibleMLNodes .size () > 0 ) {
71
- DiscoveryNode [] mlNodes = eligibleMLNodes .toArray (new DiscoveryNode [0 ]);
72
- log .debug ("Find {} dedicated ML nodes: {}" , eligibleMLNodes .size (), Arrays .toString (mlNodes ));
73
- return mlNodes ;
74
- } else {
75
- DiscoveryNode [] dataNodes = eligibleDataNodes .toArray (new DiscoveryNode [0 ]);
76
- log .debug ("Find no dedicated ML nodes. But have {} data nodes: {}" , eligibleDataNodes .size (), Arrays .toString (dataNodes ));
77
- return dataNodes ;
88
+ return eligibleNodes .toArray (new DiscoveryNode [0 ]);
89
+ }
90
+
91
+ private void getEligibleNodes (Set <String > allowedNodeRoles , List <DiscoveryNode > eligibleNodes , DiscoveryNode node ) {
92
+ if (allowedNodeRoles .contains ("data" ) && isEligibleDataNode (node )) {
93
+ eligibleNodes .add (node );
94
+ }
95
+ for (String nodeRole : allowedNodeRoles ) {
96
+ if (!"data" .equals (nodeRole ) && node .getRoles ().stream ().anyMatch (r -> r .roleName ().equals (nodeRole ))) {
97
+ eligibleNodes .add (node );
98
+ }
78
99
}
79
100
}
80
101
81
- public String [] filterEligibleNodes (String [] nodeIds ) {
102
+ public String [] filterEligibleNodes (FunctionName functionName , String [] nodeIds ) {
82
103
if (nodeIds == null || nodeIds .length == 0 ) {
83
104
return nodeIds ;
84
105
}
@@ -88,14 +109,30 @@ public String[] filterEligibleNodes(String[] nodeIds) {
88
109
if (excludedNodeNames != null && excludedNodeNames .contains (node .getName ())) {
89
110
continue ;
90
111
}
91
- if (MLNodeUtils .isMLNode (node )) {
92
- eligibleNodes .add (node .getId ());
112
+ if (functionName == FunctionName .REMOTE ) {// remote model
113
+ getEligibleNodes (remoteModelEligibleNodeRoles , eligibleNodes , node );
114
+ } else { // local model
115
+ if (onlyRunOnMLNode ) {
116
+ if (MLNodeUtils .isMLNode (node )) {
117
+ eligibleNodes .add (node .getId ());
118
+ }
119
+ } else {
120
+ getEligibleNodes (localModelEligibleNodeRoles , eligibleNodes , node );
121
+ }
93
122
}
94
- if (!onlyRunOnMLNode && node .isDataNode () && isEligibleDataNode (node )) {
123
+ }
124
+ return eligibleNodes .toArray (new String [0 ]);
125
+ }
126
+
127
+ private void getEligibleNodes (Set <String > allowedNodeRoles , Set <String > eligibleNodes , DiscoveryNode node ) {
128
+ if (allowedNodeRoles .contains ("data" ) && isEligibleDataNode (node )) {
129
+ eligibleNodes .add (node .getId ());
130
+ }
131
+ for (String nodeRole : allowedNodeRoles ) {
132
+ if (!"data" .equals (nodeRole ) && node .getRoles ().stream ().anyMatch (r -> r .roleName ().equals (nodeRole ))) {
95
133
eligibleNodes .add (node .getId ());
96
134
}
97
135
}
98
- return eligibleNodes .toArray (new String [0 ]);
99
136
}
100
137
101
138
public DiscoveryNode [] getAllNodes () {
0 commit comments