# modules.py
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from utlis import *
import dgl.function as fn
from dgl.nn.functional import edge_softmax
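# `utlis` (spelling as in the repo) is the project's local helper module; the
# star import above is presumably what provides pad() and NODE_TYPE used below.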


class MSA(nn.Module):
    # Multi-head self-attention, used in three ways:
    #   "copy"   mode - scores which entity should be copied (single head, raw keys/values);
    #   "normal" mode - standard attention between two sequence inputs;
    #   the third use reuses "normal" mode with a single query token against a
    #   sequence (gather / attentive pooling), signalled by a 2-D inp1.
    def __init__(self, args, mode="normal"):
        super(MSA, self).__init__()
        if mode == "copy":
            nhead, head_dim = 1, args.nhid
            qninp, kninp = args.dec_ninp, args.nhid
        if mode == "normal":
            nhead, head_dim = args.nhead, args.head_dim
            qninp, kninp = args.nhid, args.nhid
        self.attn_drop = nn.Dropout(0.1)
        self.WQ = nn.Linear(
            qninp, nhead * head_dim, bias=True if mode == "copy" else False
        )
        if mode != "copy":
            self.WK = nn.Linear(kninp, nhead * head_dim, bias=False)
            self.WV = nn.Linear(kninp, nhead * head_dim, bias=False)
        self.args, self.nhead, self.head_dim, self.mode = (
            args,
            nhead,
            head_dim,
            mode,
        )

    def forward(self, inp1, inp2, mask=None):
        B, L2, H = inp2.shape
        NH, HD = self.nhead, self.head_dim
        if self.mode == "copy":
            q, k, v = self.WQ(inp1), inp2, inp2
        else:
            q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2)
        L1 = 1 if inp1.ndim == 2 else inp1.shape[1]
        if self.mode != "copy":
            q = q / math.sqrt(H)
        q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)
        k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)
        v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3)
        pre_attn = torch.matmul(q, k)
        if mask is not None:
            pre_attn = pre_attn.masked_fill(mask[:, None, None, :], -1e8)
        if self.mode == "copy":
            # copy mode returns the unnormalized scores over inp2
            return pre_attn.squeeze(1)
        else:
            alpha = self.attn_drop(torch.softmax(pre_attn, -1))
            attn = (
                torch.matmul(alpha, v)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(B, L1, NH * HD)
            )
            ret = attn
            if inp1.ndim == 2:
                return ret.squeeze(1)
            else:
                return ret
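

# Illustrative usage sketch (not part of the original module): shows the tensor
# shapes MSA expects in "normal" mode. The field names (nhid, nhead, head_dim)
# mirror what the class reads from args; the concrete values are made up.
def _demo_msa():
    from types import SimpleNamespace

    args = SimpleNamespace(nhid=8, nhead=2, head_dim=4)
    attn = MSA(args, mode="normal")
    inp1 = torch.randn(3, 5, args.nhid)         # queries: (batch, L1, nhid)
    inp2 = torch.randn(3, 7, args.nhid)         # keys/values: (batch, L2, nhid)
    mask = torch.zeros(3, 7, dtype=torch.bool)  # True would mark padded key positions
    out = attn(inp1, inp2, mask)                # -> (batch, L1, nhead * head_dim)
    return out.shape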


class BiLSTM(nn.Module):
    # Encoder for either the title sequence or the entity token spans.
    def __init__(self, args, enc_type="title"):
        super(BiLSTM, self).__init__()
        self.enc_type = enc_type
        self.drop = nn.Dropout(args.emb_drop)
        self.bilstm = nn.LSTM(
            args.nhid,
            args.nhid // 2,
            bidirectional=True,
            num_layers=args.enc_lstm_layers,
            batch_first=True,
        )

    def forward(self, inp, mask, ent_len=None):
        inp = self.drop(inp)
        lens = (mask == 0).sum(-1).long().tolist()
        pad_seq = pack_padded_sequence(
            inp, lens, batch_first=True, enforce_sorted=False
        )
        y, (_h, _c) = self.bilstm(pad_seq)
        if self.enc_type == "title":
            y = pad_packed_sequence(y, batch_first=True)[0]
            return y
        if self.enc_type == "entity":
            _h = _h.transpose(0, 1).contiguous()
            _h = _h[:, -2:].view(
                _h.size(0), -1
            )  # concatenate the two directions of the top layer
            ret = pad(_h.split(ent_len), out_type="tensor")
            return ret
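

# Illustrative usage sketch (not part of the original module): encodes a padded
# "title" batch. Mask convention follows forward() above: non-zero entries mark
# padding, so each length is the number of zeros in its row. The args fields are
# the ones BiLSTM reads; the values are made up. The "entity" mode additionally
# needs ent_len and the repo's pad() helper, so it is not shown here.
def _demo_bilstm_title():
    from types import SimpleNamespace

    args = SimpleNamespace(nhid=8, emb_drop=0.1, enc_lstm_layers=2)
    enc = BiLSTM(args, enc_type="title")
    inp = torch.randn(2, 6, args.nhid)
    mask = torch.tensor([[0, 0, 0, 0, 1, 1],   # length 4
                         [0, 0, 0, 0, 0, 0]])  # length 6
    y = enc(inp, mask)                         # -> (2, 6, nhid)
    return y.shape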


class GAT(nn.Module):
    # A graph attention layer with scaled dot-product attention.
    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        ffn_drop=0.0,
        attn_drop=0.0,
        trans=True,
    ):
        super(GAT, self).__init__()
        self._num_heads = num_heads
        self._in_feats = in_feats
        self._out_feats = out_feats
        self.q_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.k_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.v_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.attn_drop = nn.Dropout(0.1)  # note: the ffn_drop/attn_drop arguments are accepted but unused
        self.ln1 = nn.LayerNorm(in_feats)
        self.ln2 = nn.LayerNorm(in_feats)
        if trans:
            self.FFN = nn.Sequential(
                nn.Linear(in_feats, 4 * in_feats),
                nn.PReLU(4 * in_feats),
                nn.Linear(4 * in_feats, in_feats),
                nn.Dropout(0.1),
            )
            # an unusual FFN (PReLU, dropout only at the end); kept as in the author's code
        self._trans = trans

    def forward(self, graph, feat):
        graph = graph.local_var()
        feat_c = feat.clone().detach().requires_grad_(False)
        q, k, v = self.q_proj(feat), self.k_proj(feat_c), self.v_proj(feat_c)
        q = q.view(-1, self._num_heads, self._out_feats)
        k = k.view(-1, self._num_heads, self._out_feats)
        v = v.view(-1, self._num_heads, self._out_feats)
        graph.ndata.update(
            {"ft": v, "el": k, "er": q}
        )  # k and q are deliberately swapped: edge_softmax normalizes over incoming edges
        # compute edge attention scores
        graph.apply_edges(fn.u_dot_v("el", "er", "e"))
        e = graph.edata.pop("e") / math.sqrt(self._out_feats * self._num_heads)
        graph.edata["a"] = edge_softmax(graph, e)
        # message passing
        graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft2"))
        rst = graph.ndata["ft2"]
        # residual connection
        rst = rst.view(feat.shape) + feat
        if self._trans:
            rst = self.ln1(rst)
            rst = self.ln1(rst + self.FFN(rst))
            # the same LayerNorm is reused for both steps, as in the author's code
        return rst
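

# Illustrative usage sketch (not part of the original module): runs one GAT layer
# over a toy DGL graph. The residual connection requires
# num_heads * out_feats == in_feats; the tiny graph below is only an example.
def _demo_gat():
    import dgl

    g = dgl.graph(([0, 1, 2, 2], [1, 2, 0, 1]))  # 3 nodes, 4 directed edges
    layer = GAT(in_feats=8, out_feats=2, num_heads=4, trans=True)
    feat = torch.randn(g.num_nodes(), 8)
    out = layer(g, feat)                         # -> (3, 8), same shape as feat
    return out.shape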


class GraphTrans(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if args.graph_enc == "gat":
            # plain GAT stack without the transformer-style FFN; only "gtrans"
            # is supported, so this branch is untested
            self.gat = nn.ModuleList(
                [
                    GAT(
                        args.nhid,
                        args.nhid // 4,
                        4,
                        attn_drop=args.attn_drop,
                        trans=False,
                    )
                    for _ in range(args.prop)
                ]
            )
        else:
            self.gat = nn.ModuleList(
                [
                    GAT(
                        args.nhid,
                        args.nhid // 4,
                        4,
                        attn_drop=args.attn_drop,
                        ffn_drop=args.drop,
                        trans=True,
                    )
                    for _ in range(args.prop)
                ]
            )
        self.prop = args.prop

    def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):
        device = ent.device
        graphs = graphs.to(device)
        ent_mask = ent_mask == 0  # invert the masks: True now marks valid (non-pad) positions
        rel_mask = rel_mask == 0
        init_h = []
        for i in range(graphs.batch_size):
            init_h.append(ent[i][ent_mask[i]])
            init_h.append(rel[i][rel_mask[i]])
        init_h = torch.cat(init_h, 0)
        feats = init_h
        for i in range(self.prop):
            feats = self.gat[i](graphs, feats)
        g_root = feats.index_select(
            0,
            graphs.filter_nodes(
                lambda x: x.data["type"] == NODE_TYPE["root"]
            ).to(device),
        )
        g_ent = pad(
            feats.index_select(
                0,
                graphs.filter_nodes(
                    lambda x: x.data["type"] == NODE_TYPE["entity"]
                ).to(device),
            ).split(ent_len),
            out_type="tensor",
        )
        return g_ent, g_root
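

# Notes on GraphTrans.forward, summarizing the code above (shapes inferred from
# the code, not an official spec):
#   ent      : (batch, max_ent, nhid) entity encodings (e.g. from BiLSTM "entity" mode)
#   ent_mask : (batch, max_ent), non-zero at padded entity slots
#   ent_len  : number of entities per example, used to re-pad the output
#   rel      : (batch, max_rel, nhid) relation encodings; rel_mask analogous to ent_mask
#   graphs   : a batched DGLGraph whose per-graph node ordering must match the order
#              in which init_h is built (all entities of a graph, then its relations),
#              with ndata["type"] drawn from the repo's NODE_TYPE mapping
#   returns  : g_ent (batch, max_ent, nhid) per-entity states and g_root, one
#              root-node state per graph (assuming each graph has exactly one root node)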