# main.py -- forked from dmlc/dgl (153 lines, 126 loc, 4.35 KB)
import torch
from sklearn.metrics import f1_score
from utils import EarlyStopping, load_data
def score(logits, labels):
    """Compute accuracy, micro-F1 and macro-F1 for arg-max predictions.

    Parameters
    ----------
    logits : torch.Tensor
        Per-row class scores; the predicted class of a row is the index of
        its maximum along dimension 1.
    labels : torch.Tensor
        Reference class indices, compared element-wise with the predictions.

    Returns
    -------
    tuple
        ``(accuracy, micro_f1, macro_f1)``.
    """
    # torch.max over a dim yields (values, indices); only the indices matter.
    predicted = torch.max(logits, dim=1)[1].long().cpu().numpy()
    expected = labels.cpu().numpy()

    accuracy = (predicted == expected).sum() / len(predicted)
    micro_f1, macro_f1 = (
        f1_score(expected, predicted, average=mode)
        for mode in ("micro", "macro")
    )
    return accuracy, micro_f1, macro_f1
def evaluate(model, g, features, labels, mask, loss_func):
    """Evaluate ``model`` on the entries selected by ``mask``.

    The model is switched to eval mode (and left in it) and the forward pass
    runs without gradient tracking.

    Parameters
    ----------
    model : callable
        Invoked as ``model(g, features)`` to produce logits.
    g, features
        Inputs forwarded unchanged to ``model``.
    labels : torch.Tensor
        Reference labels, indexed with ``mask``.
    mask : torch.Tensor
        Index/mask applied to both the logits and the labels.
    loss_func : callable
        Invoked as ``loss_func(masked_logits, masked_labels)``.

    Returns
    -------
    tuple
        ``(loss, accuracy, micro_f1, macro_f1)``.
    """
    model.eval()
    with torch.no_grad():  # no gradients needed for an evaluation pass
        logits = model(g, features)

    # Select the evaluated subset once and reuse it for loss and metrics.
    masked_logits = logits[mask]
    masked_labels = labels[mask]

    loss = loss_func(masked_logits, masked_labels)
    return (loss, *score(masked_logits, masked_labels))
def main(args):
    """Train a HAN model with early stopping and print test metrics.

    Parameters
    ----------
    args : dict
        Configuration mapping. This function reads the keys ``dataset``,
        ``device``, ``hetero``, ``hidden_units``, ``num_heads``, ``dropout``,
        ``patience``, ``lr``, ``weight_decay`` and ``num_epochs``.

    The function prints one progress line per epoch and a final test line;
    it returns nothing.
    """
    # If args['hetero'] is True, g would be a heterogeneous graph.
    # Otherwise, it will be a list of homogeneous graphs.
    (
        g,
        features,
        labels,
        num_classes,
        _train_idx,  # index splits are returned by load_data but unused here
        _val_idx,
        _test_idx,
        train_mask,
        val_mask,
        test_mask,
    ) = load_data(args["dataset"])

    masks = (train_mask, val_mask, test_mask)
    if hasattr(torch, "BoolTensor"):
        # Only convert to boolean masks when this torch build has bool tensors.
        masks = tuple(m.bool() for m in masks)

    device = args["device"]
    features = features.to(device)
    labels = labels.to(device)
    train_mask, val_mask, test_mask = (m.to(device) for m in masks)

    use_hetero = args["hetero"]
    if use_hetero:
        from model_hetero import HAN

        variant_kwargs = {"meta_paths": [["pa", "ap"], ["pf", "fp"]]}
    else:
        from model import HAN

        variant_kwargs = {"num_meta_paths": len(g)}

    model = HAN(
        **variant_kwargs,
        in_size=features.shape[1],
        hidden_size=args["hidden_units"],
        out_size=num_classes,
        num_heads=args["num_heads"],
        dropout=args["dropout"],
    ).to(device)
    if use_hetero:
        g = g.to(device)
    else:
        g = [graph.to(device) for graph in g]

    stopper = EarlyStopping(patience=args["patience"])
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )

    epoch_report = (
        "Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | "
        "Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}"
    )
    for epoch in range(args["num_epochs"]):
        model.train()
        logits = model(g, features)
        train_logits = logits[train_mask]
        train_labels = labels[train_mask]
        loss = loss_fcn(train_logits, train_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Training metrics come from the logits computed before this step.
        _train_acc, train_micro_f1, train_macro_f1 = score(
            train_logits, train_labels
        )
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
            model, g, features, labels, val_mask, loss_fcn
        )
        should_stop = stopper.step(val_loss.data.item(), val_acc, model)

        print(
            epoch_report.format(
                epoch + 1,
                loss.item(),
                train_micro_f1,
                train_macro_f1,
                val_loss.item(),
                val_micro_f1,
                val_macro_f1,
            )
        )

        if should_stop:
            break

    # Reload the checkpoint kept by the early stopper before testing.
    stopper.load_checkpoint(model)
    test_loss, _test_acc, test_micro_f1, test_macro_f1 = evaluate(
        model, g, features, labels, test_mask, loss_fcn
    )
    print(
        "Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
            test_loss.item(), test_micro_f1, test_macro_f1
        )
    )
if __name__ == "__main__":
    import argparse

    from utils import setup

    # Command-line entry point: only the seed, the log directory and the
    # hetero switch are exposed as options.
    parser = argparse.ArgumentParser("HAN")
    parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
    parser.add_argument(
        "-ld", "--log-dir", type=str, default="results",
        help="Dir for saving training results",
    )
    parser.add_argument(
        "--hetero", action="store_true",
        help="Use metapath coalescing with DGL's own dataset",
    )

    # main() reads keys (e.g. 'device', 'lr') that are not CLI options, so
    # setup() presumably fills in the remaining configuration -- confirm in utils.
    args = setup(vars(parser.parse_args()))
    main(args)