fairn2v.py
import torch
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
from torch_geometric.utils.num_nodes import maybe_num_nodes

try:
    import torch_cluster  # noqa
    random_walk = torch.ops.torch_cluster.random_walk
except ImportError:
    random_walk = None

EPS = 1e-15


class Node2Vec(torch.nn.Module):
    r"""The Node2Vec model from the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper, where random walks of
    length :obj:`walk_length` are sampled in a given graph, and node embeddings
    are learned via negative sampling optimization.

    .. note::

        For an example of using Node2Vec, see `examples/node2vec.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        node2vec.py>`_.

    Args:
        edge_index (LongTensor): The edge indices.
        embedding_dim (int): The size of each embedding vector.
        walk_length (int): The walk length.
        context_size (int): The actual context size which is considered for
            positive samples. This parameter increases the effective sampling
            rate by reusing samples across different source nodes.
        walks_per_node (int, optional): The number of walks to sample for each
            node. (default: :obj:`1`)
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy.
            (default: :obj:`1`)
        num_negative_samples (int, optional): The number of negative samples to
            use for each positive sample. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            weight matrix will be sparse. (default: :obj:`False`)
    """

    def __init__(
        self,
        edge_index,
        embedding_dim,
        walk_length,
        context_size,
        walks_per_node=1,
        p=1,
        q=1,
        num_negative_samples=1,
        num_nodes=None,
        sparse=False,
    ):
        super(Node2Vec, self).__init__()

        if random_walk is None:
            raise ImportError("`Node2Vec` requires `torch-cluster`.")

        N = maybe_num_nodes(edge_index, num_nodes)
        row, col = edge_index
        self.adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
        self.adj = self.adj.to("cpu")

        assert walk_length >= context_size

        self.embedding_dim = embedding_dim
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(N, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, batch=None):
        """Returns the embeddings for the nodes in :obj:`batch`."""
        emb = self.embedding.weight
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        r"""Returns a :class:`DataLoader` over all nodes that yields
        ``(pos_rw, neg_rw)`` pairs via :meth:`sample`."""
        return DataLoader(
            range(self.adj.sparse_size(0)), collate_fn=self.sample, **kwargs
        )

    def pos_sample(self, batch):
        # Sample positive random walks and split each walk into overlapping
        # windows of size `context_size`.
        batch = batch.repeat(self.walks_per_node)
        rowptr, col, _ = self.adj.csr()
        rw = random_walk(rowptr, col, batch, self.walk_length, self.p, self.q)
        if not isinstance(rw, torch.Tensor):
            rw = rw[0]

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j : j + self.context_size])
        return torch.cat(walks, dim=0)

    def neg_sample(self, batch):
        # Negative "walks" are drawn uniformly at random over all nodes and
        # windowed in the same way as the positive walks.
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

        rw = torch.randint(self.adj.sparse_size(0), (batch.size(0), self.walk_length))
        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j : j + self.context_size])
        return torch.cat(walks, dim=0)

    def sample(self, batch):
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)

    def loss(self, pos_rw, neg_rw):
        r"""Computes the loss given positive and negative random walks."""

        # Positive loss.
        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(pos_rw.size(0), 1, self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(
            pos_rw.size(0), -1, self.embedding_dim
        )

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()

        # Negative loss.
        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()

        h_start = self.embedding(start).view(neg_rw.size(0), 1, self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(
            neg_rw.size(0), -1, self.embedding_dim
        )

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()

        return pos_loss + neg_loss

    def test(
        self,
        train_z,
        train_y,
        test_z,
        test_y,
        solver="lbfgs",
        multi_class="auto",
        *args,
        **kwargs
    ):
        r"""Evaluates latent space quality via a logistic regression downstream
        task, returning the test accuracy and ROC AUC (assumes binary
        labels)."""
        clf = LogisticRegression(
            solver=solver, multi_class=multi_class, *args, **kwargs
        ).fit(train_z, train_y)
        acc = clf.score(test_z, test_y)
        auc = metrics.roc_auc_score(test_y, clf.predict_proba(test_z)[:, 1])
        return acc, auc

    def __repr__(self):
        return "{}({}, {})".format(
            self.__class__.__name__,
            self.embedding.weight.size(0),
            self.embedding.weight.size(1),
        )
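

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): trains the model on
# a tiny toy graph and evaluates it with `test`. It assumes `torch-cluster`
# is installed; the edge list, labels, and hyperparameters below are
# illustrative placeholders only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy undirected graph on 4 nodes (each edge listed in both directions).
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 2, 3, 3, 0], [1, 0, 2, 1, 3, 2, 0, 3]]
    )
    model = Node2Vec(
        edge_index,
        embedding_dim=16,
        walk_length=5,
        context_size=3,
        walks_per_node=10,
        num_negative_samples=1,
    )
    loader = model.loader(batch_size=4, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    model.train()
    for epoch in range(5):
        total_loss = 0.0
        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model.loss(pos_rw, neg_rw)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"epoch {epoch:02d} | loss {total_loss / len(loader):.4f}")

    # Downstream evaluation on made-up binary labels, using the learned
    # embeddings as features (toy setup: train and test on the same nodes).
    model.eval()
    with torch.no_grad():
        z = model().detach()
    y = torch.tensor([0, 0, 1, 1])  # placeholder labels
    acc, auc = model.test(z, y, z, y)
    print(f"accuracy={acc:.3f}, auc={auc:.3f}")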