Skip to content

Commit c9621f2

Browse files
committed
refractor the position of register.py
1 parent 77f1e7a commit c9621f2

30 files changed

+35
-63
lines changed

graphgym/config.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from graphgym.utils.io import makedirs_rm_exist
66

7-
from graphgym.contrib.config import *
8-
import graphgym.models.register as register
7+
import graphgym.register as register
98

109
# Global config object
1110
cfg = CN()

graphgym/contrib/act/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from graphgym.config import cfg
4-
from graphgym.models.register import register_act
4+
from graphgym.register import register_act
55

66

77
class SWISH(nn.Module):

graphgym/contrib/config/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from yacs.config import CfgNode as CN
22

3-
from graphgym.models.register import register_config
3+
from graphgym.register import register_config
44

55

66
def set_cfg_example(cfg):

graphgym/contrib/feature_augment/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import networkx as nx
22

3-
from graphgym.models.register import register_feature_augment
3+
from graphgym.register import register_feature_augment
44

55

66
def example_node_augmentation_func(graph, **kwargs):

graphgym/contrib/feature_encoder/example.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import torch
2-
import torch.nn as nn
3-
from graphgym.config import cfg
4-
from graphgym.models.register import (register_node_encoder,
5-
register_edge_encoder)
2+
from graphgym.register import (register_node_encoder,
3+
register_edge_encoder)
64

75
from ogb.utils.features import get_bond_feature_dims
86

graphgym/contrib/head/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from graphgym.models.register import register_head
3+
from graphgym.register import register_head
44

55

66
class ExampleNodeHead(nn.Module):

graphgym/contrib/layer/attconv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from torch_geometric.nn.inits import glorot, zeros
1010
from graphgym.config import cfg
11-
from graphgym.models.register import register_layer
11+
from graphgym.register import register_layer
1212

1313

1414
class GeneralAddAttConvLayer(MessagePassing):

graphgym/contrib/layer/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torch_geometric.nn.inits import glorot, zeros
77
from graphgym.config import cfg
8-
from graphgym.models.register import register_layer
8+
from graphgym.register import register_layer
99

1010

1111
# Note: A registered GNN layer should take 'batch' as input

graphgym/contrib/layer/generalconv.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import torch
22
import torch.nn as nn
33
from torch.nn import Parameter
4-
import torch.nn.functional as F
54
from torch_scatter import scatter_add
65
from torch_geometric.nn.conv import MessagePassing
7-
from torch_geometric.utils import add_remaining_self_loops, softmax
6+
from torch_geometric.utils import add_remaining_self_loops
87

98
from torch_geometric.nn.inits import glorot, zeros
109
from graphgym.config import cfg
11-
from graphgym.models.register import register_layer
1210

1311

1412
class GeneralConvLayer(MessagePassing):

graphgym/contrib/layer/generalconv_ogb.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from torch_geometric.nn.inits import glorot, zeros
99
from graphgym.config import cfg
1010

11-
from graphgym.models.register import register_layer
11+
from graphgym.register import register_layer
1212

13-
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
13+
from ogb.utils.features import get_bond_feature_dims
1414

1515
full_bond_feature_dims = get_bond_feature_dims()
1616

graphgym/contrib/layer/idconv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch_geometric.nn.inits import glorot, zeros, reset
1111

1212
from graphgym.config import cfg
13-
from graphgym.models.register import register_layer
13+
from graphgym.register import register_layer
1414

1515

1616
class GeneralIDConvLayer(MessagePassing):

graphgym/contrib/layer/sageinitconv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from torch_geometric.nn.conv import MessagePassing
66
from torch_geometric.utils import add_remaining_self_loops
77

8-
from torch_geometric.nn.inits import glorot, zeros, reset
9-
from graphgym.models.register import register_layer
8+
from torch_geometric.nn.inits import glorot, zeros
9+
from graphgym.register import register_layer
1010

1111

1212
class SAGEConvLayer(MessagePassing):

graphgym/contrib/loader/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from deepsnap.dataset import GraphDataset
22
from torch_geometric.datasets import *
33

4-
from graphgym.models.register import register_loader
4+
from graphgym.register import register_loader
55

66

77
def load_dataset_example(format, name, dataset_dir):

graphgym/contrib/loss/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from graphgym.models.register import register_loss
3+
from graphgym.register import register_loss
44

55
from graphgym.config import cfg
66

graphgym/contrib/network/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from graphgym.config import cfg
77
from graphgym.models.head import head_dict
8-
from graphgym.models.register import register_network
8+
from graphgym.register import register_network
99

1010

1111
class ExampleGNN(torch.nn.Module):

graphgym/contrib/optimizer/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.optim as optim
22

3-
from graphgym.models.register import register_optimizer, register_scheduler
3+
from graphgym.register import register_optimizer, register_scheduler
44

55
from graphgym.config import cfg
66

graphgym/contrib/pooling/example.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import torch
2-
import torch.nn as nn
31
from torch_scatter import scatter
4-
from graphgym.config import cfg
5-
from graphgym.models.register import register_pooling
2+
from graphgym.register import register_pooling
63

74

85
def global_example_pool(x, batch, size=None):

graphgym/contrib/stage/example.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import torch
21
import torch.nn as nn
32
import torch.nn.functional as F
4-
from torch_scatter import scatter
53
from graphgym.config import cfg
6-
from graphgym.models.register import register_stage
4+
from graphgym.register import register_stage
75

86
import graphgym.models.gnn as gnn
97

graphgym/loader.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import networkx as nx
32
import time
43
import logging
@@ -15,14 +14,11 @@
1514
import graphgym.models.feature_augment as preprocess
1615
from graphgym.models.transform import (ego_nets, remove_node_feature,
1716
edge_nets, path_len)
18-
from graphgym.contrib.loader import *
19-
import graphgym.models.register as register
17+
import graphgym.register as register
2018

2119
from ogb.graphproppred import PygGraphPropPredDataset
2220
from deepsnap.batch import Batch
2321

24-
import pdb
25-
2622

2723
def load_pyg(name, dataset_dir):
2824
'''

graphgym/loss.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from graphgym.contrib.loss import *
6-
import graphgym.models.register as register
5+
import graphgym.register as register
76
from graphgym.config import cfg
87

98
def compute_loss(pred, true):

graphgym/model_builder.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from graphgym.config import cfg
44
from graphgym.models.gnn import GNN
55

6-
from graphgym.contrib.network import *
7-
import graphgym.models.register as register
6+
import graphgym.register as register
87

98
network_dict = {
109
'gnn': GNN,

graphgym/models/act.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import torch
21
import torch.nn as nn
3-
import torch.nn.functional as F
42
from graphgym.config import cfg
5-
from graphgym.contrib.act import *
6-
import graphgym.models.register as register
3+
import graphgym.register as register
74

85
act_dict = {
96
'relu': nn.ReLU(inplace=cfg.mem.inplace),

graphgym/models/feature_augment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from graphgym.config import cfg
1010
from graphgym.contrib.transform.identity import compute_identity
1111

12-
import graphgym.models.register as register
12+
import graphgym.register as register
1313

1414

1515
def _key(key, as_label=False):

graphgym/models/feature_encoder.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
33

4-
from graphgym.contrib.feature_encoder import *
5-
import graphgym.models.register as register
4+
import graphgym.register as register
65

76
# Used for the OGB Encoders
87
full_atom_feature_dims = get_atom_feature_dims()

graphgym/models/gnn.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
from graphgym.init import init_weights
1212
from graphgym.models.feature_encoder import node_encoder_dict, edge_encoder_dict
1313

14-
from graphgym.contrib.stage import *
15-
import graphgym.models.register as register
14+
import graphgym.register as register
1615

1716

1817
########### Layer ############

graphgym/models/head.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@
55

66
import torch
77
import torch.nn as nn
8-
import torch.nn.functional as F
98

109
from graphgym.config import cfg
11-
from graphgym.models.layer import layer_dict, MLP, GeneralLayer
10+
from graphgym.models.layer import MLP
1211
from graphgym.models.pooling import pooling_dict
1312

14-
from graphgym.contrib.head import *
15-
import graphgym.models.register as register
16-
17-
import pdb
13+
import graphgym.register as register
1814

1915

2016
########### Head ############
@@ -61,8 +57,8 @@ def __init__(self, dim_in, dim_out):
6157
num_layers=cfg.gnn.layers_post_mp,
6258
bias=True)
6359
# requires parameter
64-
self.decode_module = lambda v1, v2: torch.sigmoid(
65-
self.layer_post_mp(torch.cat((v1, v2), dim=-1)))
60+
self.decode_module = lambda v1, v2: \
61+
self.layer_post_mp(torch.cat((v1, v2), dim=-1))
6662
else:
6763
if dim_out > 1:
6864
raise ValueError(

graphgym/models/layer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from graphgym.contrib.layer.generalconv import (GeneralConvLayer,
99
GeneralEdgeConvLayer)
1010

11-
from graphgym.contrib.layer import *
12-
import graphgym.models.register as register
11+
import graphgym.register as register
1312

1413

1514
## General classes

graphgym/models/pooling.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from torch_scatter import scatter
33
from graphgym.config import cfg
44

5-
from graphgym.contrib.pooling import *
6-
import graphgym.models.register as register
5+
import graphgym.register as register
76

87

98
# Pooling options (pool nodes into graph representations)

graphgym/optimizer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import torch.optim as optim
44

5-
from graphgym.contrib.optimizer import *
6-
import graphgym.models.register as register
5+
import graphgym.register as register
76

87

98
def create_optimizer(params):
File renamed without changes.

0 commit comments

Comments
 (0)