-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestHaxball.py
94 lines (81 loc) · 3.12 KB
/
testHaxball.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import os
from typing import List
import pygame
from InputManager import InputManager
import trainingConfig
from GameController import GameController
from AgentInput import AgentInput
from HaxballEngine import GameEngine
import rlSamples.PPO.PPO as PPO
def startUserGameplay(args):
    """Run a few evaluation episodes with the latest trained PPO agents.

    For each agent, loads the newest checkpoint from ``models/`` (filenames
    are assumed to embed a sortable datetime and end with ``<agentIdx>.pth``),
    then plays ``total_test_episodes`` episodes with a rendered game window,
    logging each agent's average reward every 1000 frames.

    Args:
        args: parsed CLI namespace (forwarded to TrainingConfig; carries
            the run name used for logging).

    Raises:
        FileNotFoundError: if no checkpoint matching an agent index exists.
    """
    agentsInTeam: int = 1
    total_test_episodes: int = 5

    # Initialize game (screen=True opens a pygame window so the run is visible).
    gameController: GameController = GameController(agentsInTeam, screen=True)
    phase = 0

    # One input struct per agent; both teams are agent-controlled.
    agentsInputs: List[AgentInput] = [AgentInput() for _ in range(agentsInTeam * 2)]

    frameId: int = 0

    # Probe one state to size the network input.
    state0 = gameController.getState(0)
    config = trainingConfig.TrainingConfig(args, state0.size, 5)
    config.action_std = 0.1  # low action noise for evaluation

    # List the checkpoint directory once; it is loop-invariant.
    all_models = os.listdir("models")

    ppo = []
    for i in range(agentsInTeam * 2):
        ppo.append(
            PPO.PPO(
                config.state_dim,
                config.action_dim,
                config.lr_actor,
                config.lr_critic,
                config.gamma,
                config.K_epochs,
                config.eps_clip,
            )
        )
        # Checkpoints for agent i end with "<i>.pth"; names start with a
        # sortable datetime, so the lexicographic max is the latest one.
        candidates = [m for m in all_models if m.endswith(f"{i}.pth")]
        if not candidates:
            # Fail fast with a clear message instead of loading "models/None".
            raise FileNotFoundError(
                f"No checkpoint ending with '{i}.pth' found in 'models/'"
            )
        latest_model = max(candidates)
        ppo[i].load(f"models/{latest_model}")
        print(f"Loaded model {latest_model}")

    # Main loop of the game: rolling per-agent reward, reset every 1000 frames.
    avg_reward = [0 for _ in range(agentsInTeam * 2)]
    for _ in range(total_test_episodes):
        for t in range(1, config.max_ep_len + 1):
            for i in range(len(agentsInputs)):
                state = gameController.getState(i)
                action = ppo[i].select_action(state)
                agentsInputs[i].movementDir.x = action[0]
                agentsInputs[i].movementDir.y = action[1]
                # NOTE: kick actions (action[2:5]) are intentionally disabled:
                # agentsInputs[i].kickPos.x = action[2]
                # agentsInputs[i].kickPos.y = action[3]
                # agentsInputs[i].kick = True if action[4] > 0.5 else False
            frameId += 1

            # Advance the simulation one frame with all agents' inputs.
            # shouldClose = InputManager.parseUserInputs(gameController, agentsInputs[0])
            gameController.nextFrame(agentsInputs)

            for i in range(len(agentsInputs)):
                reward = gameController.generateCurrentReward(i, phase)
                avg_reward[i] += reward

                # Log the windowed average reward every 1000 frames.
                if frameId % 1000 == 0:
                    print(f"Frame {frameId} - Agent {i} reward: {avg_reward[i] / 1000}")
                    config.writer.add_scalar(
                        f"Agent {i} reward", avg_reward[i] / 1000, frameId
                    )
                    avg_reward[i] = 0

        gameController.reset()
if __name__ == "__main__":
    # Parse the run name (tags logs/config), then launch the evaluation run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--name", type=str, default="test")
    parsed = cli.parse_args()
    startUserGameplay(parsed)