node.py
import numpy as np
import math
class Node:
"""
A standard MCTS Node implementation.
Each node stores:
- The current game state.
- Its parent node and the action taken to reach it.
    - A list of its child nodes.
- The number of visits and cumulative value.
- Possible moves that can still be expanded.
"""
def __init__(self, game, args, state, parent=None, action_taken=None):
"""
Initializes the Node with the given state, game, and MCTS arguments.
:param game: The game logic object, which provides functions like get_valid_moves, etc.
:param args: A dictionary of MCTS parameters, e.g., exploration constant (C).
:param state: The current game state at this node.
:param parent: (Optional) The parent node in the tree.
:param action_taken: (Optional) The action leading from the parent node to this node.
"""
self.game = game
self.args = args
self.state = state
self.parent = parent
self.action_taken = action_taken
# Children and moves
self.children = []
# A binary array indicating which actions are valid and not yet expanded
self.expandable_moves = game.get_valid_moves(state)
# MCTS statistics
self.visit_count = 0 # How many times we have visited this node
self.value_sum = 0 # Sum of values from simulations
def is_fully_expanded(self):
"""
Checks if there are no remaining moves to expand and
if this node has at least one child.
:return: True if no moves remain and the node already has children.
"""
return np.all(self.expandable_moves == 0) and len(self.children) > 0
def is_leaf(self):
"""
Checks if the node has no children.
:return: True if the node is a leaf.
"""
return len(self.children) == 0
def select(self):
"""
Selection step: choose the child with the highest UCB value.
:return: The child node that maximizes UCB.
"""
best_child = None
best_ucb = -np.inf
# Iterate over children and compute their UCB score
for child in self.children:
            ucb = self.get_ucb(child)
if ucb > best_ucb:
best_child = child
best_ucb = ucb
return best_child
def get_ucb(self, child):
"""
Computes the UCB (Upper Confidence Bound) for the given child.
        The child's value is stored from the opponent's perspective, so q_value rescales and inverts it so that a higher q_value means a better outcome for this node.
:param child: The child node for which we calculate UCB.
:return: The UCB score of the child.
"""
        # Map the child's average value from [-1, 1] to [0, 1], then invert it (the child's value is from the opponent's perspective)
q_value = 1 - (child.value_sum / child.visit_count + 1) / 2
# The exploration term includes log of parent's visits and child's visits
return q_value + self.args["C"] * np.sqrt(np.log(self.visit_count) / child.visit_count)
def expand(self):
"""
Expansion step: randomly choose one valid move to expand.
Creates a new child node and updates this node's expandable moves.
:return: The newly created child node.
"""
# Randomly pick one valid action among the unexpanded moves
action = np.random.choice(np.where(self.expandable_moves == 1)[0])
# Mark the chosen move as expanded
self.expandable_moves[action] = 0
# Compute the next state for the chosen action.
child_state = self.state.copy()
child_state = self.game.get_next_state(child_state, action, 1)
# Adjust perspective if necessary (for two-player alternation)
child_state = self.game.change_perspective(child_state, self.game.get_opponent(1))
# Create a new child node and append it to this node's children
child = Node(self.game, self.args, child_state, self, action)
self.children.append(child)
return child
def simulate(self):
"""
Simulation (rollout) step: Perform a random rollout from the current state
until reaching a terminal state. Returns the result (win/loss/draw) value
from the current player's perspective.
:return: The value (float) of the final outcome for the current node's player.
"""
        # Check whether the move that led to this node already ended the game
        value, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken)
# Flip perspective if needed
value = self.game.get_opponent_value(value)
if is_terminal:
return value
# Otherwise, perform a random simulation until terminal
rollout_state = self.state.copy()
rollout_player = 1
while True:
valid_moves = self.game.get_valid_moves(rollout_state)
# Randomly pick a valid action to simulate
action = np.random.choice(np.where(valid_moves == 1)[0])
rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
# Check if this new state is terminal
value, is_terminal = self.game.get_value_and_terminated(rollout_state, action)
if is_terminal:
# If the last player was -1, we invert perspective
if rollout_player == -1:
value = self.game.get_opponent_value(value)
return value
# Switch player for the next turn
rollout_player = self.game.get_opponent(rollout_player)
def backpropagate(self, value):
"""
Backpropagation step: increment the visit count, add the simulation result
to the value sum, and recursively update the parent until reaching the root.
:param value: The simulation outcome (float) for the current node's player.
"""
# Update current node's statistics
self.visit_count += 1
self.value_sum += value
# Invert perspective for the parent
value = self.game.get_opponent_value(value)
# Recursively backpropagate up the tree
if self.parent is not None:
self.parent.backpropagate(value)
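
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal sketch of how this Node class is typically driven by a search loop,
# assuming a `game` object exposing the same methods used above and an `args`
# dict containing "C" plus a hypothetical "num_searches" count; `game.action_size`
# is also assumed here only to build the final action distribution.
def mcts_search_sketch(game, args, state):
    root = Node(game, args, state)
    for _ in range(args["num_searches"]):
        node = root
        # Selection: descend while nodes are fully expanded
        while node.is_fully_expanded():
            node = node.select()
        # Stop early if the selected node is already terminal
        value, is_terminal = game.get_value_and_terminated(node.state, node.action_taken)
        value = game.get_opponent_value(value)
        if not is_terminal:
            # Expansion, then a random rollout from the new child
            node = node.expand()
            value = node.simulate()
        # Backpropagation of the rollout result up to the root
        node.backpropagate(value)
    # Normalized visit counts of the root's children act as the search policy
    action_probs = np.zeros(game.action_size)
    for child in root.children:
        action_probs[child.action_taken] = child.visit_count
    return action_probs / np.sum(action_probs)
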
class AlphaZeroNode:
"""
A node structure that incorporates prior probabilities (policy) from
a neural network (as in AlphaZero). UCB is adjusted to include priors.
"""
def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
"""
Initialize the AlphaZeroNode.
:param game: The game logic object.
:param args: A dictionary of parameters (e.g., C for exploration).
:param state: The current game state.
:param parent: (Optional) The parent node in the tree.
:param action_taken: (Optional) The action that led from the parent to this node.
:param prior: The prior probability of taking action 'action_taken'.
:param visit_count: Initial visit count (usually 0, except for root).
"""
self.game = game
self.args = args
self.state = state
self.parent = parent
self.action_taken = action_taken
self.prior = prior # Policy prior from the NN for this action
self.children = []
# MCTS statistics
self.visit_count = visit_count
self.value_sum = 0
def is_fully_expanded(self):
"""
In the AlphaZero version, a node is considered fully expanded if it has
at least one child, because the network policy can propose all actions at once.
:return: True if there is at least one child.
"""
return len(self.children) > 0
def select(self):
"""
Selection step, adjusted for AlphaZero logic using 'prior' in the UCB formula.
:return: The child node that has the highest AlphaZero-based UCB value.
"""
best_child = None
best_ucb = -np.inf
# Evaluate each child's UCB and pick the highest
for child in self.children:
ucb = self.get_ucb(child)
if ucb > best_ucb:
best_child = child
best_ucb = ucb
return best_child
def get_ucb(self, child):
"""
Compute the UCB value for the AlphaZero approach:
UCB = Q(s,a) + c * P(s,a) * sqrt(N(s)) / (1 + N(s,a))
Where:
        - Q(s,a) is the child's average value, mapped from [-1, 1] to [0, 1] and inverted to this node's perspective.
- P(s,a) is the 'prior' from the neural network policy.
- N(s) is the parent's visit_count.
- N(s,a) is the child's visit_count.
:param child: The child node for which we calculate UCB.
:return: The UCB score of the child.
"""
# If child not visited, its average value is 0
if child.visit_count == 0:
q_value = 0
else:
q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
# Combine Q-value and exploration term weighted by the prior
return (
q_value
+ self.args['C']
* (math.sqrt(self.visit_count) / (child.visit_count + 1))
* child.prior
)
def expand(self, policy):
"""
Expand the node by creating a child for each possible action with non-zero probability.
:param policy: An array of probabilities from the network for each possible action.
:return: The last created child node (optional usage).
"""
# For each action, if the probability is > 0, add a child node
for action, prob in enumerate(policy):
if prob > 0:
child_state = self.state.copy()
child_state = self.game.get_next_state(child_state, action, 1)
# Switch perspective to opponent
child_state = self.game.change_perspective(child_state, player=-1)
# Create the child node and store the prior
child = AlphaZeroNode(
self.game, self.args, child_state,
parent=self,
action_taken=action,
prior=prob
)
self.children.append(child)
# Return the last child created (not always necessary)
return child
def backpropagate(self, value):
"""
Update the node's statistics and recursively backpropagate the inverted value to the parent.
:param value: The value from the perspective of the current node's player.
"""
# Update this node's stats
self.value_sum += value
self.visit_count += 1
# Flip perspective for the parent node
value = self.game.get_opponent_value(value)
# Continue backpropagation if there is a parent node
if self.parent is not None:
self.parent.backpropagate(value)
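
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal sketch of how AlphaZeroNode is typically driven. The network is
# abstracted as a `predict(state) -> (policy, value)` callable, and
# "num_searches" / `game.action_size` are assumed names; none of these are
# defined in this file.
def alphazero_search_sketch(game, args, state, predict):
    root = AlphaZeroNode(game, args, state, visit_count=1)
    for _ in range(args["num_searches"]):
        node = root
        # Selection guided by the prior-weighted UCB
        while node.is_fully_expanded():
            node = node.select()
        value, is_terminal = game.get_value_and_terminated(node.state, node.action_taken)
        value = game.get_opponent_value(value)
        if not is_terminal:
            # Ask the (assumed) network for a policy and value, mask invalid
            # moves, renormalize, and expand all children at once
            policy, value = predict(node.state)
            valid_moves = game.get_valid_moves(node.state)
            policy = policy * valid_moves
            policy = policy / np.sum(policy)
            node.expand(policy)
        # No rollout: the network value replaces the random simulation
        node.backpropagate(value)
    # Normalized visit counts of the root's children act as the search policy
    action_probs = np.zeros(game.action_size)
    for child in root.children:
        action_probs[child.action_taken] = child.visit_count
    return action_probs / np.sum(action_probs)
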