Skip to content

Commit

Permalink
add disjoint sets, fix switch rule
Browse files Browse the repository at this point in the history
  • Loading branch information
srefsland committed Dec 31, 2023
1 parent 2327dee commit 1c86a63
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 110 deletions.
Binary file modified images/losses_7x7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ matplotlib
numpy
tensorflow
tqdm
disjoint-set
2 changes: 2 additions & 0 deletions src/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def _predict_moves(self, X, legal_moves):
else:
prediction = self.litemodel.predict_single(np.squeeze(X, axis=0))

prediction = np.squeeze(prediction, axis=0)

for i in range(len(prediction)):
move = (i // self.board_size, i % self.board_size)
if move not in legal_moves:
Expand Down
8 changes: 4 additions & 4 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Board config
BOARD_SIZE = 7
CLASSIC_DISPLAY = True
SWICH_RULE_ALLOWED = True
SWITCH_RULE_ALLOWED = True

# MCTS config
MCTS_DYNAMIC_SIMS_TIME = 4.0
Expand All @@ -14,8 +14,8 @@

# RL config
NUM_EPISODES = 500
DISPLAY_GAME_RL = True
DISPLAY_GAME_RL_INTERVAL = 50
DISPLAY_GAME_RL = False
DISPLAY_GAME_RL_INTERVAL = 10
REPLAY_BUFFER_SIZE = 2048
MINI_BATCH_SIZE = 256
SAVE_INTERVAL = 50
Expand All @@ -35,7 +35,7 @@
USE_CRITIC = False

# TOPP
MODEL_DIR = "models/best_models"
MODEL_DIR = "models/2023-12-30_16-09-29"
TOPP_TEMPERATURE = 1.0
TOPP_NUM_GAMES = 30
TOPP_VERBOSE = True
Expand Down
9 changes: 6 additions & 3 deletions src/play_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def play_versus_actor(actor, board_display, board_size=4, best_move=True, player1=True):
board = HexStateManager(board_size=board_size, )
board = HexStateManager(board_size=board_size)

new_move = None
is_terminal = False
Expand All @@ -16,7 +16,10 @@ def play_versus_actor(actor, board_display, board_size=4, best_move=True, player
current_player = board.player

# If the switch rule was used, the human and the actor have swapped sides, so the turn check must take board.switched into account as well
if current_player == 1 and player1 and not board.switched:
if (current_player == 1 and player1 and not board.switched
or current_player == -1 and not player1 and not board.switched
or current_player == 1 and not player1 and board.switched
or current_player == -1 and player1 and board.switched):
if config.CLASSIC_DISPLAY:
x = input("Enter position: ")

Expand All @@ -41,7 +44,7 @@ def play_versus_actor(actor, board_display, board_size=4, best_move=True, player
board_display.display_board(board, delay=0.5, winner=current_player)

if __name__ == "__main__":
actor_episodes = 200
actor_episodes = 150

saved_model = f"{config.MODEL_DIR}/model_{config.BOARD_SIZE}x{config.BOARD_SIZE}_{actor_episodes}"
model = BoardGameNetCNN(board_size=config.BOARD_SIZE, bridge_features=config.BRIDGE_FEATURES, saved_model=saved_model)
Expand Down
122 changes: 36 additions & 86 deletions src/statemanager/hexstatemanager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import copy
import pickle

import numpy as np
from disjoint_set import DisjointSet

from .statemanager import StateManager
import pickle


class HexStateManager(StateManager):
def __init__(self, board_size=6, **kwargs):
Expand Down Expand Up @@ -68,7 +69,15 @@ def make_move(self, move, player=None):

if not (len(self.move_history) == 2 and self.switched):
self.board[move[0]][move[1]] = player
self.player = -1 if player == 1 else 1

neighbors = self._expand_neighbors(move, player)
for neighbor in neighbors:
if player == 1:
self.disjoint_set_red.union(neighbor, move)
else:
self.disjoint_set_blue.union(neighbor, move)

self.player = -player

return move

Expand All @@ -89,7 +98,7 @@ def make_random_move(self, player=None):
if len(moves) == 0:
return

move = self.make_move(moves[np.random.randint(0, len(moves))], player)
move = self.make_move(list(moves)[np.random.randint(0, len(moves))], player)

return move

Expand All @@ -111,7 +120,9 @@ def generate_child_states(self, player=None):
state_manager = self.copy_state_manager()
state_manager.make_move(move, player)

yield state_manager.board, state_manager.player, move
node_player = state_manager.player if not state_manager.switched else -state_manager.player

yield state_manager.board, node_player, move

def check_winning_state(self, player=None):
"""Checks if there is a win in the current state of the board.
Expand All @@ -132,32 +143,6 @@ def check_winning_state(self, player=None):
or self._check_winning_state_player2()
)

def get_winning_moves(self, player=None):
    """Collect every legal move that immediately results in a win.

    Each legal move is played on the object returned by
    ``copy_state_manager()`` and that copy is then checked for a win.
    Useful for shortening the number of moves in each episode.

    Args:
        player (int, optional): the player to check winning moves for.
            Defaults to None, in which case ``self.player`` is used.

    Returns:
        list[tuple[int, int]]: the moves that result in a win, or None if
            no move results in a win.
    """
    mover = self.player if player is None else player

    winning = []
    for candidate in self.get_legal_moves():
        # Try the move on a copy rather than on this state manager.
        trial = self.copy_state_manager()
        trial.make_move(candidate, mover)

        if trial.check_winning_state(mover):
            winning.append(candidate)

    # Callers receive None (not an empty list) when nothing wins.
    return winning if winning else None

def reset(self):
self._initialize_state(self.board_size)

Expand Down Expand Up @@ -200,74 +185,37 @@ def _initialize_state(self, board_size):
self.moves_made = set()
self.move_history = []
self.player = 1

self.top_node = (-1, 0)
self.bottom_node = (board_size, 0)
self.left_node = (0, -1)
self.right_node = (0, board_size)

cells = [(i, j) for j in range(board_size) for i in range(board_size)]
self.disjoint_set_red = DisjointSet(cells + [self.top_node, self.bottom_node])
self.disjoint_set_blue = DisjointSet(cells + [self.left_node, self.right_node])

for i in range(board_size):
self.disjoint_set_red.union((0, i), self.top_node)
self.disjoint_set_red.union((board_size-1, i), self.bottom_node)
self.disjoint_set_blue.union((i, 0), self.left_node)
self.disjoint_set_blue.union((i, board_size-1), self.right_node)

def _check_winning_state_player1(self):
"""Checks the winning state of player 1.
Returns:
bool: true if player 1 has won, false if not.
"""
nodes_to_visit = []
nodes_visited = []

for col in range(len(self.board[0])):
if self.board[0][col] == 1:
nodes_to_visit.append((0, col))

while len(nodes_to_visit) > 0:
node = nodes_to_visit.pop()
nodes_visited.append(node)

if node[0] == self.board_size - 1:
return True

neighbors = self._expand_neighbors(node, player=1)

for neighbor in neighbors:
if neighbor not in nodes_to_visit and neighbor not in nodes_visited:
nodes_to_visit.append(neighbor)

return False
return self.disjoint_set_red.find(self.top_node) == self.disjoint_set_red.find(self.bottom_node)

def _check_winning_state_player2(self):
"""Checks the winning state of player 2.
Returns:
bool: true if player 2 has won, false if not.
"""
nodes_to_visit = []
nodes_visited = []

for row in range(len(self.board)):
if self.board[row][0] == -1:
nodes_to_visit.append((row, 0))

while len(nodes_to_visit) > 0:
node = nodes_to_visit.pop()
nodes_visited.append(node)

if node[1] == self.board_size - 1:
return True

neighbors = self._expand_neighbors(node, player=-1)

for neighbor in neighbors:
if neighbor not in nodes_to_visit and neighbor not in nodes_visited:
nodes_to_visit.append(neighbor)

return False

def _is_within_bounds(self, row, col):
"""Ensures that the current row and column are within the bounds of the board.
Args:
row (int): the row index.
col (int): the column index.
Returns:
bool: true if within bounds, false if not.
"""
return row >= 0 and row < self.board_size and col >= 0 and col < self.board_size
return self.disjoint_set_blue.find(self.left_node) == self.disjoint_set_blue.find(self.right_node)

def _expand_neighbors(self, cell, player=None):
"""Finds neighbors that connect to the current node. Used to determine if the state is terminal (game over).
Expand All @@ -294,9 +242,11 @@ def _expand_neighbors(self, cell, player=None):
]

neighbors = []
# When the switch rule was used, ignore the entry at index 1 of the move history
move_history = self.move_history if not self.switched else self.move_history[0:1] + self.move_history[2:]

for neighbor in neighbors_coords:
if (neighbor, player) in self.move_history:
if (neighbor, player) in move_history:
neighbors.append(neighbor)

return neighbors
4 changes: 0 additions & 4 deletions src/statemanager/statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ def generate_child_states(self, player):
@abstractmethod
def check_winning_state(self, player):
pass

@abstractmethod
def get_winning_moves(self, player):
pass

@abstractmethod
def reset(self):
Expand Down
13 changes: 7 additions & 6 deletions src/tournament_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def run_tournament(
actors, state_manager, display, num_games=25, board_size=4, temperature=1.0
):
"""Run the TOPP tournament for different actors.
"""Run tournament for different actors.
Args:
actors (list[Actor]): the actors of different playing strengths.
Expand Down Expand Up @@ -102,10 +102,11 @@ def run_game(
else:
current_player = state_manager.player

if current_player == 1:
move = actor1.predict_best_move(state=state_manager.board, player=state_manager.player)
if (current_player == 1 and not state_manager.switched
or current_player == -1 and state_manager.switched):
move = actor1.predict_best_move(state=state_manager.board, player=state_manager.player, legal_moves=state_manager.legal_moves)
else:
move = actor2.predict_best_move(state=state_manager.board, player=state_manager.player)
move = actor2.predict_best_move(state=state_manager.board, player=state_manager.player, legal_moves=state_manager.legal_moves)

move = state_manager.make_move(move)

Expand All @@ -122,14 +123,14 @@ def run_game(
actor2=actor2.name,
)

winner = current_player
winner = current_player if not state_manager.switched else -current_player

return winner


if __name__ == "__main__":
display = HexBoardDisplayClassic() if config.CLASSIC_DISPLAY else HexBoardDisplay()
state_manager = HexStateManager(board_size=config.BOARD_SIZE)
state_manager = HexStateManager(board_size=config.BOARD_SIZE, switch_rule_allowed=config.SWITCH_RULE_ALLOWED)

save_interval = config.SAVE_INTERVAL

Expand Down
10 changes: 3 additions & 7 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from nn.boardgamenetcnn import BoardGameNetCNN
from statemanager.hexstatemanager import HexStateManager

import cProfile
import pstats

def rl_algorithm(actor, state_manager, mcts_state_manager, display):
"""The reinforcement learning algorithm.
Expand All @@ -32,7 +29,7 @@ def rl_algorithm(actor, state_manager, mcts_state_manager, display):

replay_buf = replay_buffer.ReplayBuffer(maxlen=config.REPLAY_BUFFER_SIZE)
i_s = config.SAVE_INTERVAL
time_stamp = datetime.now()
time_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
replay_buf.clear()

for g_a in tqdm(range(config.NUM_EPISODES + 1)):
Expand Down Expand Up @@ -112,8 +109,8 @@ def rl_algorithm(actor, state_manager, mcts_state_manager, display):
optimizer=config.ANN_OPTIMIZER,
board_size=config.BOARD_SIZE,
)
state_manager = HexStateManager(config.BOARD_SIZE, switch_rule_allowed=config.SWICH_RULE_ALLOWED)
mcts_state_manager = HexStateManager(config.BOARD_SIZE, switch_rule_allowed=config.SWICH_RULE_ALLOWED)
state_manager = HexStateManager(config.BOARD_SIZE, switch_rule_allowed=config.SWITCH_RULE_ALLOWED)
mcts_state_manager = HexStateManager(config.BOARD_SIZE, switch_rule_allowed=config.SWITCH_RULE_ALLOWED)
display = None if not config.DISPLAY_GAME_RL else HexBoardDisplayClassic() if config.CLASSIC_DISPLAY else HexBoardDisplay()
actor = Actor(
name="actor_rl",
Expand All @@ -126,4 +123,3 @@ def rl_algorithm(actor, state_manager, mcts_state_manager, display):
litemodel=None,
)
rl_algorithm(actor=actor, state_manager=state_manager, mcts_state_manager=mcts_state_manager, display=display)

3 changes: 3 additions & 0 deletions tests/test_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ def test_switch_rule():

board.make_move((0, 1))
assert board.player == 1

with pytest.raises(Exception):
board.make_move((0, 1))

0 comments on commit 1c86a63

Please sign in to comment.