|
1 |
| -from torch.utils.data import Dataset, DataLoader |
| 1 | +import collections |
| 2 | + |
| 3 | +from torch.utils.data import DataLoader, Dataset |
| 4 | + |
2 | 5 | import dgl
|
3 | 6 | from dgl.data import PPIDataset
|
4 |
| -import collections |
5 | 7 |
|
6 |
# Record produced by the collate function below: the batched DGL graph
# together with the node labels taken from that graph.
PPIBatch = collections.namedtuple("PPIBatch", ("graph", "label"))
| 10 | + |
| 11 | + |
def batcher(device):
    """Build a ``collate_fn`` that batches DGL graphs for a ``DataLoader``.

    Args:
        device: Target passed to ``.to(...)`` for the label tensor.

    Returns:
        A function that takes a list of graphs and returns a ``PPIBatch``.
    """

    def batcher_dev(batch):
        """Merge ``batch`` into one graph and move its labels to ``device``."""
        merged = dgl.batch(batch)
        # Only the labels are moved to the device here; the batched graph is
        # returned as produced by ``dgl.batch``.
        labels = merged.ndata["label"]
        labels_on_device = labels.to(device)
        return PPIBatch(graph=merged, label=labels_on_device)

    return batcher_dev
|
14 | 20 |
|
15 |
def _with_self_loop_relation(graph):
    """Rebuild ``graph`` as a heterograph that also has a ``"self"`` relation.

    The graph is reconstructed rather than modified because the schema of a
    heterograph cannot be changed once it has been constructed.
    """
    nodes = graph.nodes()
    rebuilt = dgl.heterograph(
        {
            # Original (untyped) edges keep their default relation name.
            ("_N", "_E", "_N"): graph.edges(),
            # One self-loop per node, stored under its own relation type.
            ("_N", "self", "_N"): (nodes, nodes),
        }
    )
    # Carry the node data over to the rebuilt graph.
    for key in ("label", "feat", "_ID"):
        rebuilt.ndata[key] = graph.ndata[key]
    # Edge IDs only apply to the original relation, not to the self-loops.
    rebuilt.edges["_E"].data["_ID"] = graph.edata["_ID"]
    return rebuilt


def _add_self_loops_in_place(dataset):
    """Replace every graph stored in ``dataset.graphs`` with its rebuilt form."""
    for i in range(len(dataset)):
        dataset.graphs[i] = _with_self_loop_relation(dataset[i])


# add a fresh "self-loop" edge type to the untyped PPI dataset and prepare train, val, test loaders
def load_PPI(batch_size=1, device="cpu"):
    """Load the PPI dataset with an added ``"self"`` relation and build loaders.

    Args:
        batch_size: Number of graphs per batch in every loader.
        device: Device the batched label tensor is moved to by the collate
            function.

    Returns:
        A tuple ``(train_loader, valid_loader, test_loader, etypes, in_size,
        out_size)`` where ``etypes`` are the edge types of the first training
        graph, and ``in_size`` / ``out_size`` are the second dimension of its
        ``"feat"`` and ``"label"`` node data.
    """
    train_set = PPIDataset(mode="train")
    valid_set = PPIDataset(mode="valid")
    test_set = PPIDataset(mode="test")

    # for each graph, add self-loops as a new relation type
    for split in (train_set, valid_set, test_set):
        _add_self_loops_in_place(split)

    first_graph = train_set[0]
    etypes = first_graph.etypes
    in_size = first_graph.ndata["feat"].shape[1]
    out_size = first_graph.ndata["label"].shape[1]

    # prepare train, valid, and test dataloaders
    # NOTE(review): the validation and test loaders shuffle as well; this is
    # kept from the original code -- confirm it is intended.
    train_loader, valid_loader, test_loader = (
        DataLoader(
            split,
            batch_size=batch_size,
            collate_fn=batcher(device),
            shuffle=True,
        )
        for split in (train_set, valid_set, test_set)
    )
    return train_loader, valid_loader, test_loader, etypes, in_size, out_size
|
0 commit comments