Skip to content

Commit f19f05c

Browse files
frozenbugsSteve
and
Steve
authored
[Misc] Black auto fix. (dmlc#4651)
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
1 parent 977b1ba commit f19f05c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+9461
-4790
lines changed
+82-45
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,98 @@
1-
from torch.utils.data import Dataset, DataLoader
1+
import collections
2+
3+
from torch.utils.data import DataLoader, Dataset
4+
25
import dgl
36
from dgl.data import PPIDataset
4-
import collections
57

6-
#implement the collate_fn for dgl graph data class
7-
PPIBatch = collections.namedtuple('PPIBatch', ['graph', 'label'])
8+
# implement the collate_fn for dgl graph data class
9+
PPIBatch = collections.namedtuple("PPIBatch", ["graph", "label"])
10+
11+
812
def batcher(device):
913
def batcher_dev(batch):
1014
batch_graphs = dgl.batch(batch)
11-
return PPIBatch(graph=batch_graphs,
12-
label=batch_graphs.ndata['label'].to(device))
15+
return PPIBatch(
16+
graph=batch_graphs, label=batch_graphs.ndata["label"].to(device)
17+
)
18+
1319
return batcher_dev
1420

15-
#add a fresh "self-loop" edge type to the untyped PPI dataset and prepare train, val, test loaders
16-
def load_PPI(batch_size=1, device='cpu'):
17-
train_set = PPIDataset(mode='train')
18-
valid_set = PPIDataset(mode='valid')
19-
test_set = PPIDataset(mode='test')
20-
#for each graph, add self-loops as a new relation type
21-
#here we reconstruct the graph since the schema of a heterograph cannot be changed once constructed
21+
22+
# add a fresh "self-loop" edge type to the untyped PPI dataset and prepare train, val, test loaders
23+
def load_PPI(batch_size=1, device="cpu"):
24+
train_set = PPIDataset(mode="train")
25+
valid_set = PPIDataset(mode="valid")
26+
test_set = PPIDataset(mode="test")
27+
# for each graph, add self-loops as a new relation type
28+
# here we reconstruct the graph since the schema of a heterograph cannot be changed once constructed
2229
for i in range(len(train_set)):
23-
g = dgl.heterograph({
24-
('_N','_E','_N'): train_set[i].edges(),
25-
('_N', 'self', '_N'): (train_set[i].nodes(), train_set[i].nodes())
26-
})
27-
g.ndata['label'] = train_set[i].ndata['label']
28-
g.ndata['feat'] = train_set[i].ndata['feat']
29-
g.ndata['_ID'] = train_set[i].ndata['_ID']
30-
g.edges['_E'].data['_ID'] = train_set[i].edata['_ID']
30+
g = dgl.heterograph(
31+
{
32+
("_N", "_E", "_N"): train_set[i].edges(),
33+
("_N", "self", "_N"): (
34+
train_set[i].nodes(),
35+
train_set[i].nodes(),
36+
),
37+
}
38+
)
39+
g.ndata["label"] = train_set[i].ndata["label"]
40+
g.ndata["feat"] = train_set[i].ndata["feat"]
41+
g.ndata["_ID"] = train_set[i].ndata["_ID"]
42+
g.edges["_E"].data["_ID"] = train_set[i].edata["_ID"]
3143
train_set.graphs[i] = g
3244
for i in range(len(valid_set)):
33-
g = dgl.heterograph({
34-
('_N','_E','_N'): valid_set[i].edges(),
35-
('_N', 'self', '_N'): (valid_set[i].nodes(), valid_set[i].nodes())
36-
})
37-
g.ndata['label'] = valid_set[i].ndata['label']
38-
g.ndata['feat'] = valid_set[i].ndata['feat']
39-
g.ndata['_ID'] = valid_set[i].ndata['_ID']
40-
g.edges['_E'].data['_ID'] = valid_set[i].edata['_ID']
41-
valid_set.graphs[i] = g
45+
g = dgl.heterograph(
46+
{
47+
("_N", "_E", "_N"): valid_set[i].edges(),
48+
("_N", "self", "_N"): (
49+
valid_set[i].nodes(),
50+
valid_set[i].nodes(),
51+
),
52+
}
53+
)
54+
g.ndata["label"] = valid_set[i].ndata["label"]
55+
g.ndata["feat"] = valid_set[i].ndata["feat"]
56+
g.ndata["_ID"] = valid_set[i].ndata["_ID"]
57+
g.edges["_E"].data["_ID"] = valid_set[i].edata["_ID"]
58+
valid_set.graphs[i] = g
4259
for i in range(len(test_set)):
43-
g = dgl.heterograph({
44-
('_N','_E','_N'): test_set[i].edges(),
45-
('_N', 'self', '_N'): (test_set[i].nodes(), test_set[i].nodes())
46-
})
47-
g.ndata['label'] = test_set[i].ndata['label']
48-
g.ndata['feat'] = test_set[i].ndata['feat']
49-
g.ndata['_ID'] = test_set[i].ndata['_ID']
50-
g.edges['_E'].data['_ID'] = test_set[i].edata['_ID']
51-
test_set.graphs[i] = g
60+
g = dgl.heterograph(
61+
{
62+
("_N", "_E", "_N"): test_set[i].edges(),
63+
("_N", "self", "_N"): (
64+
test_set[i].nodes(),
65+
test_set[i].nodes(),
66+
),
67+
}
68+
)
69+
g.ndata["label"] = test_set[i].ndata["label"]
70+
g.ndata["feat"] = test_set[i].ndata["feat"]
71+
g.ndata["_ID"] = test_set[i].ndata["_ID"]
72+
g.edges["_E"].data["_ID"] = test_set[i].edata["_ID"]
73+
test_set.graphs[i] = g
5274

5375
etypes = train_set[0].etypes
54-
in_size = train_set[0].ndata['feat'].shape[1]
55-
out_size = train_set[0].ndata['label'].shape[1]
76+
in_size = train_set[0].ndata["feat"].shape[1]
77+
out_size = train_set[0].ndata["label"].shape[1]
5678

57-
#prepare train, valid, and test dataloaders
58-
train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=batcher(device), shuffle=True)
59-
valid_loader = DataLoader(valid_set, batch_size=batch_size, collate_fn=batcher(device), shuffle=True)
60-
test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=batcher(device), shuffle=True)
79+
# prepare train, valid, and test dataloaders
80+
train_loader = DataLoader(
81+
train_set,
82+
batch_size=batch_size,
83+
collate_fn=batcher(device),
84+
shuffle=True,
85+
)
86+
valid_loader = DataLoader(
87+
valid_set,
88+
batch_size=batch_size,
89+
collate_fn=batcher(device),
90+
shuffle=True,
91+
)
92+
test_loader = DataLoader(
93+
test_set,
94+
batch_size=batch_size,
95+
collate_fn=batcher(device),
96+
shuffle=True,
97+
)
6198
return train_loader, valid_loader, test_loader, etypes, in_size, out_size

0 commit comments

Comments
 (0)