Commit 264d96c

yzh119, Ubuntu, BarclayII, and VoVAllen authored
[bugfix] Fix a bunch of examples to be compatible with dgl 0.5 (dmlc#1957)
* upd
* upd
* upd
* upd
* upd
* upd
* fix pinsage also
* upd
* upd
* upd

Co-authored-by: Ubuntu <ubuntu@ip-172-31-29-3.us-east-2.compute.internal>
Co-authored-by: Quan Gan <coin2028@hotmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
1 parent dcf4641 commit 264d96c

25 files changed: +101 −118 lines changed

examples/mxnet/gcmc/model.py (+1 −1)

@@ -220,7 +220,7 @@ def forward(self, graph, ufeat, ifeat):
         for i in range(self._num_basis_functions):
             graph.nodes['user'].data['h'] = F.dot(ufeat, self.Ps[i].data())
             graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
-            basis_out.append(graph.edata['sr'].expand_dims(1))
+            basis_out.append(graph.edata['sr'])
         out = F.concat(*basis_out, dim=1)
         out = self.rate_out(out)
         return out
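Both this MXNet decoder and its PyTorch counterpart (examples/pytorch/gcmc/model.py, below) drop the manual dimension expansion on the edge scores. A minimal sketch of the behavior this relies on, assuming DGL >= 0.5 with the PyTorch backend:

```python
# fn.u_dot_v keeps a trailing dimension of size 1 in DGL >= 0.5, so the
# edge score is already 2-D and needs no expand_dims/unsqueeze before concat.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1], [1, 2]))            # toy graph: 3 nodes, 2 edges
g.ndata['h'] = torch.randn(3, 4)
g.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
print(g.edata['sr'].shape)                 # torch.Size([2, 1])
```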

examples/mxnet/gcn/train.py (+1 −1)

@@ -36,7 +36,7 @@ def main(args):
     else:
         cuda = True
         ctx = mx.gpu(args.gpu)
-        g = g.to(ctx)
+        g = g.int().to(ctx)

     features = g.ndata['feat']
     labels = mx.nd.array(g.ndata['label'], dtype="float32", ctx=ctx)
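`g.int().to(ctx)` is the recurring fix in this commit: DGL 0.5 graphs carry an index dtype (idtype), int64 by default, and these examples cast it to int32 before moving the graph onto the device. A hedged sketch of the idiom, assuming DGL >= 0.5, the MXNet backend, and an available GPU:

```python
import mxnet as mx
import dgl

g = dgl.graph(([0, 1], [1, 2]))   # idtype defaults to int64
ctx = mx.gpu(0)                   # assumption: GPU 0 exists
g = g.int().to(ctx)               # cast indices to int32, then move to ctx
print(g.idtype)                   # now int32
```

The cast touches only the graph's index storage, not its node or edge features; `g.long()` would cast back to int64.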

examples/mxnet/graphsage/main.py (+2 −2)

@@ -69,7 +69,7 @@ def main(args):
     else:
         cuda = True
         ctx = mx.gpu(args.gpu)
-        g = g.to(ctx)
+        g = g.int().to(ctx)

     features = g.ndata['feat']
     labels = mx.nd.array(g.ndata['label'], dtype="float32", ctx=ctx)
@@ -164,4 +164,4 @@ def main(args):
     args = parser.parse_args()
     print(args)

-    main(args)
+    main(args)

(The final `main(args)` line is textually identical on both sides; here, and in sgc.py below, the diff most likely reflects adding a trailing newline at end of file.)

examples/mxnet/sgc/sgc.py (+2 −2)

@@ -39,7 +39,7 @@ def main(args):
     else:
         cuda = True
         ctx = mx.gpu(args.gpu)
-        g = g.to(ctx)
+        g = g.int().to(ctx)

     features = g.ndata['feat']
     labels = mx.nd.array(g.ndata['label'], dtype="float32", ctx=ctx)
@@ -123,4 +123,4 @@ def main(args):
     args = parser.parse_args()
     print(args)

-    main(args)
+    main(args)

examples/pytorch/cluster_gcn/cluster_gcn.py (+1 −9)

@@ -11,7 +11,6 @@
 import torch.nn.functional as F
 import dgl
 from dgl.data import register_data_args
-from torch.utils.tensorboard import SummaryWriter

 from modules import GraphSAGE
 from sampler import ClusterIter
@@ -84,7 +83,7 @@ def main(args):
         torch.cuda.set_device(args.gpu)
         val_mask = val_mask.cuda()
         test_mask = test_mask.cuda()
-        g = g.to(args.gpu)
+        g = g.int().to(args.gpu)

     print('labels shape:', g.ndata['label'].shape)
     print("features shape, ", g.ndata['feat'].shape)
@@ -102,7 +101,6 @@ def main(args):

     # logger and so on
     log_dir = save_log_dir(args)
-    writer = SummaryWriter(log_dir)
     logger = Logger(os.path.join(log_dir, 'loggings'))
     logger.write(args)

@@ -148,8 +146,6 @@ def main(args):
             if j % args.log_every == 0:
                 print(f"epoch:{epoch}/{args.n_epochs}, Iteration {j}/"
                       f"{len(cluster_iterator)}:training loss", loss.item())
-                writer.add_scalar('train/loss', loss.item(),
-                                  global_step=j + epoch * len(cluster_iterator))
         print("current memory:",
               torch.cuda.memory_allocated(device=pred.device) / 1024 / 1024)

@@ -164,8 +160,6 @@ def main(args):
                 print('new best val f1:', best_f1)
                 torch.save(model.state_dict(), os.path.join(
                     log_dir, 'best_model.pkl'))
-            writer.add_scalar('val/f1-mic', val_f1_mic, global_step=epoch)
-            writer.add_scalar('val/f1-mac', val_f1_mac, global_step=epoch)

     end_time = time.time()
     print(f'training using time {start_time-end_time}')
@@ -177,8 +171,6 @@ def main(args):
     test_f1_mic, test_f1_mac = evaluate(
         model, g, labels, test_mask, multitask)
     print("Test F1-mic{:.4f}, Test F1-mac{:.4f}". format(test_f1_mic, test_f1_mac))
-    writer.add_scalar('test/f1-mic', test_f1_mic)
-    writer.add_scalar('test/f1-mac', test_f1_mac)

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')

examples/pytorch/gat/train.py (+1 −1)

@@ -53,7 +53,7 @@ def main(args):
         cuda = False
     else:
         cuda = True
-        g = g.to(args.gpu)
+        g = g.int().to(args.gpu)

     features = g.ndata['feat']
     labels = g.ndata['label']

examples/pytorch/gat/train_ppi.py (+3 −6)

@@ -60,7 +60,7 @@ def main(args):
     g = train_dataset[0]
     n_classes = train_dataset.num_labels
     num_feats = g.ndata['feat'].shape[1]
-    g = g.to(device)
+    g = g.int().to(device)
     heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
     # define the model
     model = GAT(g,
@@ -117,12 +117,9 @@ def main(args):
             if cur_step == patience:
                 break
     test_score_list = []
-    for batch, test_data in enumerate(test_dataloader):
-        subgraph, feats, labels = test_data
+    for batch, subgraph in enumerate(test_dataloader):
         subgraph = subgraph.to(device)
-        feats = feats.to(device)
-        labels = labels.to(device)
-        test_score_list.append(evaluate(feats, model, subgraph, labels.float(), loss_fcn)[0])
+        test_score_list.append(evaluate(subgraph.ndata['feat'], model, subgraph, subgraph.ndata['label'], loss_fcn))
     print("Test F1-Score: {:.4f}".format(np.array(test_score_list).mean()))

 if __name__ == '__main__':
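The test loop changes because, in DGL 0.5, `PPIDataset` items are DGLGraphs whose features and labels ride along in `ndata`, so a `dgl.batch` collate yields a single batched graph rather than a `(graph, feats, labels)` tuple. A hedged sketch of the dataloader this implies (the batch size is illustrative):

```python
import dgl
from dgl.data import PPIDataset
from torch.utils.data import DataLoader

test_dataset = PPIDataset(mode='test')
# DataLoader hands the sampled list of graphs to dgl.batch, which merges them
# into one graph; ndata['feat'] and ndata['label'] are batched along with it.
test_dataloader = DataLoader(test_dataset, batch_size=2, collate_fn=dgl.batch)
for batch, subgraph in enumerate(test_dataloader):
    feats = subgraph.ndata['feat']    # features travel with the graph
    labels = subgraph.ndata['label']
```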

examples/pytorch/gcmc/README.md (−17)

@@ -24,7 +24,6 @@ ml-100k, no feature
 python3 train.py --data_name=ml-100k --use_one_hot_fea --gcn_agg_accum=stack
 ```
 Results: RMSE=0.9088 (0.910 reported)
-Speed: 0.0410s/epoch (vanilla implementation: 0.1008s/epoch)

 ml-100k, with feature
 ```bash
@@ -37,7 +36,6 @@ ml-1m, no feature
 python3 train.py --data_name=ml-1m --gcn_agg_accum=sum --use_one_hot_fea
 ```
 Results: RMSE=0.8377 (0.832 reported)
-Speed: 0.0844s/epoch (vanilla implementation: 1.538s/epoch)

 ml-10m, no feature
 ```bash
@@ -46,7 +44,6 @@ python3 train.py --data_name=ml-10m --gcn_agg_accum=stack --gcn_dropout=0.3 \
 --use_one_hot_fea --gen_r_num_basis_func=4
 ```
 Results: RMSE=0.7800 (0.777 reported)
-Speed: 1.1982/epoch (vanilla implementation: OOM)
 Testbed: EC2 p3.2xlarge instance(Amazon Linux 2)

 ### Train with minibatch on a single GPU
@@ -67,8 +64,6 @@ python3 train_sampling.py --data_name=ml-100k \
 --gpu 0
 ```
 Results: RMSE=0.9380
-Speed: 1.059s/epoch (Run with 70 epoches)
-Speed: 1.046s/epoch (mix_cpu_gpu)

 ml-100k, with feature
 ```bash
@@ -97,8 +92,6 @@ python3 train_sampling.py --data_name=ml-1m \
 --gpu 0
 ```
 Results: RMSE=0.8632
-Speed: 7.852s/epoch (Run with 60 epoches)
-Speed: 7.788s/epoch (mix_cpu_gpu)

 ml-10m, no feature
 ```bash
@@ -126,8 +119,6 @@ python3 train_sampling.py --data_name=ml-10m \
 --gpu 0
 ```
 Results: RMSE=0.8050
-Speed: 394.304s/epoch (Run with 60 epoches)
-Speed: 408.749s/epoch (mix_cpu_gpu)
 Testbed: EC2 p3.2xlarge instance

 ### Train with minibatch on multi-GPU
@@ -151,8 +142,6 @@ python train_sampling.py --data_name=ml-100k \
 --gpu 0,1,2,3,4,5,6,7
 ```
 Result: RMSE=0.9397
-Speed: 1.202s/epoch (Run with only 30 epoches)
-Speed: 1.245/epoch (mix_cpu_gpu)

 ml-100k, with feature
 ```bash
@@ -162,7 +151,6 @@ python train_sampling.py --data_name=ml-100k \
 --gpu 0,1,2,3,4,5,6,7
 ```
 Result: RMSE=0.9655
-Speed: 1.265/epoch (Run with 30 epoches)

 ml-1m, no feature
 ```bash
@@ -182,8 +170,6 @@ python train_sampling.py --data_name=ml-1m \
 --gpu 0,1,2,3,4,5,6,7
 ```
 Results: RMSE=0.8621
-Speed: 11.612s/epoch (Run with 40 epoches)
-Speed: 12.483s/epoch (mix_cpu_gpu)

 ml-10m, no feature
 ```bash
@@ -211,8 +197,6 @@ python train_sampling.py --data_name=ml-10m \
 --gpu 0,1,2,3,4,5,6,7
 ```
 Results: RMSE=0.8084
-Speed: 632.868s/epoch (Run with 30 epoches)
-Speed: 633.397s/epoch (mix_cpu_gpu)
 Testbed: EC2 p3.16xlarge instance

 ### Train with minibatch on CPU
@@ -223,5 +207,4 @@ python3 train_sampling.py --data_name=ml-100k \
 --gcn_agg_accum=stack \
 --gpu -1
 ```
-Speed 1.591s/epoch
 Testbed: EC2 r5.xlarge instance

examples/pytorch/gcmc/model.py (+1 −1)

@@ -340,7 +340,7 @@ def forward(self, graph, ufeat, ifeat):
         for i in range(self._num_basis):
             graph.nodes['user'].data['h'] = ufeat @ self.Ps[i]
             graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
-            basis_out.append(graph.edata['sr'].unsqueeze(1))
+            basis_out.append(graph.edata['sr'])
         out = th.cat(basis_out, dim=1)
         out = self.combine_basis(out)
         return out

examples/pytorch/gcmc/train.py (+5 −5)

@@ -105,12 +105,12 @@ def train(args):
     count_num = 0
     count_loss = 0

-    dataset.train_enc_graph = dataset.train_enc_graph.to(args.device)
-    dataset.train_dec_graph = dataset.train_dec_graph.to(args.device)
+    dataset.train_enc_graph = dataset.train_enc_graph.int().to(args.device)
+    dataset.train_dec_graph = dataset.train_dec_graph.int().to(args.device)
     dataset.valid_enc_graph = dataset.train_enc_graph
-    dataset.valid_dec_graph = dataset.valid_dec_graph.to(args.device)
-    dataset.test_enc_graph = dataset.test_enc_graph.to(args.device)
-    dataset.test_dec_graph = dataset.test_dec_graph.to(args.device)
+    dataset.valid_dec_graph = dataset.valid_dec_graph.int().to(args.device)
+    dataset.test_enc_graph = dataset.test_enc_graph.int().to(args.device)
+    dataset.test_dec_graph = dataset.test_dec_graph.int().to(args.device)

     print("Start training ...")
     dur = []
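The GCMC encoder/decoder graphs are heterographs, and the same `.int().to(device)` cast applies to them unchanged. A small sketch (the schema and names below are illustrative, not the example's real ones):

```python
import torch
import dgl

# Toy user-movie heterograph standing in for dataset.train_enc_graph.
enc_graph = dgl.heterograph({
    ('user', 'rates', 'movie'):    ([0, 1], [1, 0]),
    ('movie', 'rated-by', 'user'): ([1, 0], [0, 1]),
})
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
enc_graph = enc_graph.int().to(device)   # int32 indices on the target device
```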

examples/pytorch/gcn/train.py (+1 −1)

@@ -38,7 +38,7 @@ def main(args):
         cuda = False
     else:
         cuda = True
-        g = g.to(args.gpu)
+        g = g.int().to(args.gpu)

     features = g.ndata['feat']
     labels = g.ndata['label']

examples/pytorch/graphsage/train_cv.py (+4 −4)

@@ -109,7 +109,7 @@ def inference(self, g, x, batch_size, device):
             end = start + batch_size
             batch_nodes = nodes[start:end]
             block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
-            block = block.to(device)
+            block = block.int().to(device)
             induced_nodes = block.srcdata[dgl.NID]

             h = x[induced_nodes].to(device)
@@ -188,7 +188,7 @@ def load_subtensor(g, labels, blocks, hist_blocks, dev_id, aggregation_on_device
         hist_block = hist_block.to(dev_id)
         hist_block.update_all(fn.copy_u('hist', 'm'), fn.mean('m', 'agg_hist'))

-        block = block.to(dev_id)
+        block = block.int().to(dev_id)
         if not aggregation_on_device:
             hist_block = hist_block.to(dev_id)
         block.dstdata['agg_hist'] = hist_block.dstdata['agg_hist']
@@ -220,8 +220,8 @@ def run(args, dev_id, data):

     # Unpack data
     train_mask, val_mask, in_feats, labels, n_classes, g = data
-    train_nid = train_mask.nonzero()[:, 0]
-    val_nid = val_mask.nonzero()[:, 0]
+    train_nid = train_mask.nonzero().squeeze()
+    val_nid = val_mask.nonzero().squeeze()

     # Create sampler
     sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
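`mask.nonzero()[:, 0]` and `mask.nonzero().squeeze()` are equivalent ways to turn a boolean mask into a 1-D tensor of node IDs; the commit standardizes on `.squeeze()`. In plain PyTorch:

```python
import torch

train_mask = torch.tensor([True, False, True, True])
print(train_mask.nonzero())            # tensor([[0], [2], [3]]), shape (3, 1)
print(train_mask.nonzero().squeeze())  # tensor([0, 2, 3]), shape (3,)
```

One caveat: with exactly one selected node, `.squeeze()` yields a 0-d tensor whereas `[:, 0]` keeps a length-1 vector; the masks in these examples always select many nodes.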

examples/pytorch/graphsage/train_cv_multi_gpu.py (+2 −2)

@@ -262,8 +262,8 @@ def run(proc_id, n_gpus, args, devices, data):

     # Unpack data
     train_mask, val_mask, in_feats, labels, n_classes, g = data
-    train_nid = train_mask.nonzero()[:, 0]
-    val_nid = val_mask.nonzero()[:, 0]
+    train_nid = train_mask.nonzero().squeeze()
+    val_nid = val_mask.nonzero().squeeze()

     # Split train_nid
     train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]

examples/pytorch/graphsage/train_full.py (+21 −20)

@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import dgl
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data
 from dgl.nn.pytorch.conv import SAGEConv
@@ -48,31 +49,27 @@ def forward(self, graph, inputs):
         return h


-def evaluate(model, graph, features, labels, mask):
+def evaluate(model, graph, features, labels, nid):
     model.eval()
     with torch.no_grad():
         logits = model(graph, features)
-        logits = logits[mask]
-        labels = labels[mask]
+        logits = logits[nid]
+        labels = labels[nid]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)
         return correct.item() * 1.0 / len(labels)

 def main(args):
     # load and preprocess dataset
     data = load_data(args)
-    features = torch.FloatTensor(data.features)
-    labels = torch.LongTensor(data.labels)
-    if hasattr(torch, 'BoolTensor'):
-        train_mask = torch.BoolTensor(data.train_mask)
-        val_mask = torch.BoolTensor(data.val_mask)
-        test_mask = torch.BoolTensor(data.test_mask)
-    else:
-        train_mask = torch.ByteTensor(data.train_mask)
-        val_mask = torch.ByteTensor(data.val_mask)
-        test_mask = torch.ByteTensor(data.test_mask)
+    g = data[0]
+    features = g.ndata['feat']
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
+    test_mask = g.ndata['test_mask']
     in_feats = features.shape[1]
-    n_classes = data.num_labels
+    n_classes = data.num_classes
     n_edges = data.graph.number_of_edges()
     print("""----Data statistics------'
       #Edges %d
@@ -97,11 +94,15 @@ def main(args):
         test_mask = test_mask.cuda()
         print("use cuda:", args.gpu)

+    train_nid = train_mask.nonzero().squeeze()
+    val_nid = val_mask.nonzero().squeeze()
+    test_nid = test_mask.nonzero().squeeze()
+
     # graph preprocess and calculate normalization factor
-    g = data.graph
-    g.remove_edges_from(nx.selfloop_edges(g))
-    g = DGLGraph(g)
+    g = dgl.remove_self_loop(g)
     n_edges = g.number_of_edges()
+    if cuda:
+        g = g.int().to(args.gpu)

     # create GraphSAGE model
     model = GraphSAGE(in_feats,
@@ -126,7 +127,7 @@ def main(args):
         t0 = time.time()
         # forward
         logits = model(g, features)
-        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
+        loss = F.cross_entropy(logits[train_nid], labels[train_nid])

         optimizer.zero_grad()
         loss.backward()
@@ -135,13 +136,13 @@ def main(args):
         if epoch >= 3:
             dur.append(time.time() - t0)

-        acc = evaluate(model, g, features, labels, val_mask)
+        acc = evaluate(model, g, features, labels, val_nid)
         print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
               "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                             acc, n_edges / np.mean(dur) / 1000))

     print()
-    acc = evaluate(model, g, features, labels, test_mask)
+    acc = evaluate(model, g, features, labels, test_nid)
     print("Test Accuracy {:.4f}".format(acc))

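train_full.py is the fullest illustration of the DGL 0.5 dataset API this commit migrates to: datasets index into DGLGraphs, features/labels/split masks live in `ndata`, `num_classes` replaces `num_labels`, and self-loop removal no longer round-trips through networkx. A hedged sketch, using `CoraGraphDataset` as a stand-in for whatever `load_data(args)` returns:

```python
import dgl
from dgl.data import CoraGraphDataset

data = CoraGraphDataset()                  # stand-in dataset (assumption)
g = data[0]                                # the dataset's single graph
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
n_classes = data.num_classes               # was data.num_labels before 0.5
g = dgl.remove_self_loop(g)                # replaces the networkx round-trip
train_nid = train_mask.nonzero().squeeze() # boolean mask -> node IDs
```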