forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
222 lines (193 loc) · 7.65 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import math
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax
class HGTLayer(nn.Module):
    """One Heterogeneous Graph Transformer (HGT) layer.

    For every canonical relation ``(srctype, etype, dsttype)`` the layer:

    - projects source nodes to keys/values and destination nodes to queries
      with node-type-specific linear maps;
    - mixes keys and values with relation-specific per-head ``d_k x d_k``
      matrices;
    - computes multi-head attention scores, softmax-normalised over each
      destination node's incoming edges;
    - aggregates the attention-weighted values.

    Results from different relations are averaged (``cross_reducer="mean"``).
    Each node type then applies its own output projection, dropout and a
    learned sigmoid-gated skip connection, plus an optional LayerNorm.

    Args:
        in_dim: width of the input node features.
        out_dim: width of the output node features; must be divisible by
            ``n_heads``. NOTE(review): the skip connection adds the input
            features to the projected output, so forward presumably needs
            ``in_dim == out_dim``.
        node_dict: maps node-type name -> index into the per-type module lists.
        edge_dict: maps edge-type name -> index into the per-relation params.
        n_heads: number of attention heads.
        dropout: dropout probability applied after the output projection.
        use_norm: if True, apply a per-node-type LayerNorm to the output.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide
            ``out_dim``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        node_dict,
        edge_dict,
        n_heads,
        dropout=0.2,
        use_norm=False,
    ):
        super(HGTLayer, self).__init__()
        # Every head gets exactly d_k = out_dim / n_heads channels; forward
        # reshapes (N, out_dim) <-> (N, n_heads, d_k), which is only valid
        # when the division is exact.
        if n_heads <= 0 or out_dim % n_heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.num_types = len(node_dict)
        self.num_relations = len(edge_dict)
        # NOTE(review): total_rel and att are not read anywhere in this class.
        self.total_rel = self.num_types * self.num_relations * self.num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.att = None

        # Per-node-type projections: K/V for sources, Q for destinations,
        # A for the final output transform.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm
        for _ in range(self.num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # Per-relation, per-head parameters: an attention prior (scalar per
        # head) and d_k x d_k matrices applied to keys and to values.
        self.relation_pri = nn.Parameter(
            torch.ones(self.num_relations, self.n_heads)
        )
        self.relation_att = nn.Parameter(
            torch.empty(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        self.relation_msg = nn.Parameter(
            torch.empty(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        # Per-node-type logit of the skip-connection gate (sigmoid in forward).
        self.skip = nn.Parameter(torch.ones(self.num_types))
        self.drop = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)

    def forward(self, G, h):
        """Apply the layer to every node type of ``G``.

        Args:
            G: heterogeneous DGL graph. NOTE(review): its node/edge type names
                are assumed to match the keys of ``node_dict``/``edge_dict``.
            h: dict mapping node-type name -> feature tensor of shape
                ``(num_nodes_of_type, in_dim)``.

        Returns:
            dict mapping every node type of ``G`` to a tensor of shape
            ``(num_nodes_of_type, out_dim)``.
        """
        with G.local_scope():
            node_dict, edge_dict = self.node_dict, self.edge_dict
            for srctype, etype, dsttype in G.canonical_etypes:
                sub_graph = G[srctype, etype, dsttype]

                k_linear = self.k_linears[node_dict[srctype]]
                v_linear = self.v_linears[node_dict[srctype]]
                q_linear = self.q_linears[node_dict[dsttype]]

                # (num_nodes, n_heads, d_k)
                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

                e_id = edge_dict[etype]
                relation_att = self.relation_att[e_id]
                relation_pri = self.relation_pri[e_id]
                relation_msg = self.relation_msg[e_id]

                # Per-head relation-specific transform of keys and values.
                k = torch.einsum("bij,ijk->bik", k, relation_att)
                v = torch.einsum("bij,ijk->bik", v, relation_msg)

                sub_graph.srcdata["k"] = k
                sub_graph.dstdata["q"] = q
                # Values are stored under a relation-specific key so that all
                # relations' values are still present for the single
                # multi_update_all call after this loop.
                sub_graph.srcdata["v_%d" % e_id] = v

                sub_graph.apply_edges(fn.v_dot_u("q", "k", "t"))
                attn_score = (
                    sub_graph.edata.pop("t").sum(-1)
                    * relation_pri
                    / self.sqrt_dk
                )
                attn_score = edge_softmax(sub_graph, attn_score, norm_by="dst")
                sub_graph.edata["t"] = attn_score.unsqueeze(-1)

            # Weighted-sum the values within each relation, then average the
            # per-relation results at each destination node type.
            G.multi_update_all(
                {
                    etype: (
                        fn.u_mul_e("v_%d" % e_id, "t", "m"),
                        fn.sum("m", "t"),
                    )
                    for etype, e_id in edge_dict.items()
                },
                cross_reducer="mean",
            )

            new_h = {}
            for ntype in G.ntypes:
                # Step 3: target-specific aggregation:
                #   out = alpha * drop(A[ntype](agg)) + (1 - alpha) * h[ntype]
                #   with alpha = sigmoid(skip[ntype]); then optional LayerNorm.
                n_id = node_dict[ntype]
                alpha = torch.sigmoid(self.skip[n_id])
                node_data = G.nodes[ntype].data
                if "t" in node_data:
                    t = node_data["t"].view(-1, self.out_dim)
                else:
                    # This node type is the destination of no relation, so it
                    # received no messages; treat the aggregate as zeros, the
                    # same value fn.sum yields for zero-in-degree nodes.
                    t = h[ntype].new_zeros(h[ntype].shape[0], self.out_dim)
                trans_out = self.drop(self.a_linears[n_id](t))
                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
                if self.use_norm:
                    new_h[ntype] = self.norms[n_id](trans_out)
                else:
                    new_h[ntype] = trans_out
            return new_h
class HGT(nn.Module):
    """Stack of HGT layers with per-node-type input adapters and a linear head.

    Args:
        G: not used by the constructor; accepted for interface compatibility.
        node_dict: maps node-type name -> integer index selecting that type's
            input adapter (also forwarded to every layer).
        edge_dict: maps edge-type name -> integer relation index (forwarded to
            every layer).
        n_inp: width of the raw ``"inp"`` node features.
        n_hid: hidden width of the adapters and of every HGT layer.
        n_out: width of the output head.
        n_layers: number of stacked HGT layers.
        n_heads: attention heads per layer.
        use_norm: forwarded to every layer to enable per-type LayerNorm.
    """

    def __init__(
        self,
        G,
        node_dict,
        edge_dict,
        n_inp,
        n_hid,
        n_out,
        n_layers,
        n_heads,
        use_norm=True,
    ):
        super(HGT, self).__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        # Registered before the adapters but populated after them, so the
        # submodule order is (gcs, adapt_ws, out) while the adapters' weights
        # are created before the layers' weights.
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        # One n_inp -> n_hid adapter per node type, indexed by node_dict value.
        self.adapt_ws = nn.ModuleList(
            [nn.Linear(n_inp, n_hid) for _ in range(len(node_dict))]
        )
        self.gcs.extend(
            [
                HGTLayer(
                    n_hid,
                    n_hid,
                    node_dict,
                    edge_dict,
                    n_heads,
                    use_norm=use_norm,
                )
                for _ in range(n_layers)
            ]
        )
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        """Embed all node types, run every HGT layer, and project ``out_key``.

        Each node type's ``"inp"`` feature goes through its adapter and GELU.
        The resulting dict is threaded through the layers. The output head is
        applied to the final features of node type ``out_key``.
        """
        adapters = self.adapt_ws
        features = {
            ntype: F.gelu(adapters[self.node_dict[ntype]](G.nodes[ntype].data["inp"]))
            for ntype in G.ntypes
        }
        for gc_layer in self.gcs:
            features = gc_layer(G, features)
        return self.out(features[out_key])
class HeteroRGCNLayer(nn.Module):
    """One relational graph-convolution layer over a heterogeneous graph.

    Each relation (edge type) owns a linear map ``W_r``. For every relation,
    the source features are transformed by ``W_r`` and averaged over each
    destination node's incoming edges of that relation. The per-relation
    results are then summed at each destination node type.

    Args:
        in_size: width of the input node features.
        out_size: width of the output node features.
        etypes: iterable of edge-type names; one ``nn.Linear`` is created per
            name.
    """

    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # W_r for each relation, keyed by edge-type name.
        self.weight = nn.ModuleDict(
            {name: nn.Linear(in_size, out_size) for name in etypes}
        )

    def forward(self, G, feat_dict):
        """Run one round of relational message passing.

        Args:
            G: heterogeneous DGL graph whose edge-type names are keys of
                ``self.weight``.
            feat_dict: dict mapping node-type name -> input feature tensor.

        Returns:
            dict mapping each node type of ``G`` to its aggregated ``"h"``
            feature. NOTE(review): a node type that is the destination of no
            relation has no ``"h"`` (KeyError) unless ``G`` already carries
            one; presumably every node type receives messages.
        """
        # Use a local scope so the intermediate "Wh_<etype>" and "h" features
        # are not written back onto the caller's graph. Without it they would
        # persist after the call, and a later layer could read a stale "h".
        with G.local_scope():
            funcs = {}
            for srctype, etype, dsttype in G.canonical_etypes:
                # Compute W_r * h for the source nodes of this relation.
                Wh = self.weight[etype](feat_dict[srctype])
                G.nodes[srctype].data[f"Wh_{etype}"] = Wh
                # (message_func, reduce_func) per relation. All relations
                # reduce into the same destination key "h", so the cross-type
                # reducer below combines them.
                funcs[etype] = (fn.copy_u(f"Wh_{etype}", "m"), fn.mean("m", "h"))
            # Run message passing for all relations; the second argument is the
            # cross-type reducer ("sum", "max", "min", "mean" or "stack").
            G.multi_update_all(funcs, "sum")
            return {ntype: G.nodes[ntype].data["h"] for ntype in G.ntypes}
class HeteroRGCN(nn.Module):
    """Two-layer heterogeneous R-GCN returning outputs (logits) for one node type.

    Args:
        G: graph whose ``etypes`` determine the relations each layer learns a
            weight for.
        in_size: width of the ``"inp"`` node features.
        hidden_size: width after the first layer.
        out_size: width of the returned outputs.
    """

    def __init__(self, G, in_size, hidden_size, out_size):
        super(HeteroRGCN, self).__init__()
        relation_names = G.etypes
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, relation_names)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, relation_names)

    def forward(self, G, out_key):
        """Return the second layer's output for node type ``out_key``.

        Reads every node type's ``"inp"`` feature, applies layer1, then a
        leaky ReLU, then layer2.
        """
        inputs = {nt: G.nodes[nt].data["inp"] for nt in G.ntypes}
        hidden = {
            nt: F.leaky_relu(feat) for nt, feat in self.layer1(G, inputs).items()
        }
        logits = self.layer2(G, hidden)
        return logits[out_key]