-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_size.py
92 lines (73 loc) · 2.22 KB
/
plot_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import math
import json
import os
sns.set_theme(style="whitegrid", palette="colorblind")

# Every subdirectory of output_dir is expected to hold a config.json and a
# results.json; the two dicts are merged into one record per subdirectory.
output_dir = "out_size"
results = []
# sorted() makes the collection order (and hence the row order before the
# later sort) independent of the filesystem's listing order.
for run_name in sorted(os.listdir(output_dir)):
    run_dir = os.path.join(output_dir, run_name)
    if not os.path.isdir(run_dir):
        continue
    try:
        with open(os.path.join(run_dir, "config.json"), encoding="utf-8") as file:
            config_dict = json.load(file)
        with open(os.path.join(run_dir, "results.json"), encoding="utf-8") as file:
            result_dict = json.load(file)
    except (OSError, ValueError) as err:
        # A missing/unreadable file (OSError) or malformed JSON / bad encoding
        # (json.JSONDecodeError and UnicodeDecodeError are ValueErrors) skips
        # this directory instead of aborting the script. Unlike a bare
        # `except:`, this no longer hides Ctrl-C or unrelated bugs.
        print(f"Skipping directory {run_dir}: {err}")
        continue
    # Dict union: on duplicate keys the value from results.json wins.
    results.append(config_dict | result_dict)
# Build one table from all collected run records, drop three columns that are
# not wanted in the combined CSV (presumably per-step sequences — confirm), sort
# rows by environment and hyperparameters, and write the table to disk.
dropped_columns = [
    "mean_discounted_returns",
    "sem_discounted_returns",
    "mean_policy_entropies",
]
sort_keys = ["env_name", "max_policy_updates", "learning_rate", "entropy_coef"]
results_df = (
    pd.DataFrame(results)
    .drop(columns=dropped_columns)
    .sort_values(by=sort_keys)
)
results_df.to_csv(f"{output_dir}/results_combined.csv", index=False)
# Environments to plot, one figure each.
env_names = [
    "Frozenlake4x4",
    "CartPole-v1",
    "PendulumBangBang",
    "CartPoleSwingup",
]


def _leaf_label(n_leaves):
    """Format a leaf count as a LaTeX power of two, e.g. 1024 -> "$2^{10}$".

    The exponent is log2(n_leaves) truncated to an int, so a count that is not
    a power of two is labelled with the power of two just below it.
    """
    # Doubled braces emit literal "{" / "}" so multi-digit exponents stay
    # grouped in mathtext; "$2^10$" would render as 2 to the 1, followed by 0.
    return f"$2^{{{int(math.log2(n_leaves))}}}$"


# Only decision-tree ("dt") runs are plotted. The .copy() makes the column
# rewrite below act on an independent frame, not a view of the filtered one
# (which pandas flags with SettingWithCopyWarning).
results_df = results_df[results_df["method"] == "dt"].copy()
# Derive the x-axis order from the numeric leaf counts. Sorting the rendered
# label strings would be lexicographic and put "$2^{10}$" before "$2^{2}$".
# dict.fromkeys drops duplicate labels while keeping the numeric order.
leaf_counts = sorted(results_df["max_leaf_nodes"].unique())
order = list(dict.fromkeys(_leaf_label(n) for n in leaf_counts))
results_df["max_leaf_nodes"] = results_df["max_leaf_nodes"].map(_leaf_label)
# One figure per environment: each run's mean_return as a gray dot, plus a line
# through the best mean_return reached at each leaf budget. Saved as PNG and PDF.
for env_name in env_names:
    env_df = results_df[results_df["env_name"] == env_name].copy()
    if env_df.empty:
        continue
    out_stem = f"{output_dir}/varying_size_{env_name}"
    fig, ax = plt.subplots(figsize=(4, 3))
    # Broadcast the per-leaf-count maximum back onto every row. pointplot's
    # default estimator (mean) of this per-group constant is that maximum.
    env_df["max_return"] = env_df.groupby("max_leaf_nodes")["mean_return"].transform(
        "max"
    )
    sns.stripplot(
        data=env_df,
        x="max_leaf_nodes",
        y="mean_return",
        order=order,
        jitter=False,
        color="gray",
        alpha=0.5,
        ax=ax,
    )
    # High zorder keeps the maxima line drawn above the individual run dots.
    sns.pointplot(
        data=env_df, x="max_leaf_nodes", y="max_return", order=order, ax=ax, zorder=100
    )
    ax.set_xlabel("number of leaves")
    ax.set_ylabel("return")
    fig.tight_layout()
    fig.savefig(out_stem + ".png")
    fig.savefig(out_stem + ".pdf")
    plt.close(fig)