Skip to content

Commit 4ff1361

Browse files
authored
[Fix] Update pre-commit-config-zh-cn.yaml and add typehints for PointNet2SAMSG (open-mmlab#2396)
1 parent 9f61eff commit 4ff1361

File tree

3 files changed

+43
-29
lines changed

3 files changed

+43
-29
lines changed

.pre-commit-config-zh-cn.yaml

+4-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
exclude: ^tests/data/
21
repos:
32
- repo: https://gitee.com/openmmlab/mirrors-flake8
43
rev: 5.0.4
@@ -25,6 +24,10 @@ repos:
2524
args: ["--remove"]
2625
- id: mixed-line-ending
2726
args: ["--fix=lf"]
27+
- repo: https://gitee.com/openmmlab/mirrors-codespell
28+
rev: v2.2.1
29+
hooks:
30+
- id: codespell
2831
- repo: https://gitee.com/openmmlab/mirrors-mdformat
2932
rev: 0.7.9
3033
hooks:
@@ -34,20 +37,11 @@ repos:
3437
- mdformat-openmmlab
3538
- mdformat_frontmatter
3639
- linkify-it-py
37-
- repo: https://gitee.com/openmmlab/mirrors-codespell
38-
rev: v2.2.1
39-
hooks:
40-
- id: codespell
4140
- repo: https://gitee.com/openmmlab/mirrors-docformatter
4241
rev: v1.3.1
4342
hooks:
4443
- id: docformatter
4544
args: ["--in-place", "--wrap-descriptions", "79"]
46-
- repo: https://gitee.com/openmmlab/mirrors-pyupgrade
47-
rev: v3.0.0
48-
hooks:
49-
- id: pyupgrade
50-
args: ["--py36-plus"]
5145
- repo: https://gitee.com/openmmlab/pre-commit-hooks
5246
rev: v0.2.0
5347
hooks:

mmdet3d/models/backbones/pointnet2_sa_msg.py

+36-18
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Tuple
3+
24
import torch
35
from mmcv.cnn import ConvModule
46
from torch import nn as nn
57

68
from mmdet3d.models.layers.pointnet_modules import build_sa_module
79
from mmdet3d.registry import MODELS
10+
from mmdet3d.utils import OptConfigType
811
from .base_pointnet import BasePointNet
912

13+
ThreeTupleIntType = Tuple[Tuple[Tuple[int, int, int]]]
14+
TwoTupleIntType = Tuple[Tuple[int, int, int]]
15+
TwoTupleStrType = Tuple[Tuple[str]]
16+
1017

1118
@MODELS.register_module()
1219
class PointNet2SAMSG(BasePointNet):
@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
2229
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
2330
aggregation_channels (tuple[int]): Out channels of aggregation
2431
multi-scale grouping features.
25-
fps_mods (tuple[int]): Mod of FPS for each SA module.
32+
fps_mods Sequence[Tuple[str]]: Mod of FPS for each SA module.
2633
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
2734
points which each SA module samples.
2835
dilated_group (tuple[bool]): Whether to use dilated ball query for
@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
3845
"""
3946

4047
def __init__(self,
41-
in_channels,
42-
num_points=(2048, 1024, 512, 256),
43-
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
44-
num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
45-
sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
46-
((64, 64, 128), (64, 64, 128), (64, 96, 128)),
47-
((128, 128, 256), (128, 192, 256), (128, 256,
48-
256))),
49-
aggregation_channels=(64, 128, 256),
50-
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
51-
fps_sample_range_lists=((-1), (-1), (512, -1)),
52-
dilated_group=(True, True, True),
53-
out_indices=(2, ),
54-
norm_cfg=dict(type='BN2d'),
55-
sa_cfg=dict(
48+
in_channels: int,
49+
num_points: Tuple[int] = (2048, 1024, 512, 256),
50+
radii: Tuple[Tuple[float, float, float]] = (
51+
(0.2, 0.4, 0.8),
52+
(0.4, 0.8, 1.6),
53+
(1.6, 3.2, 4.8),
54+
),
55+
num_samples: TwoTupleIntType = ((32, 32, 64), (32, 32, 64),
56+
(32, 32, 32)),
57+
sa_channels: ThreeTupleIntType = (((16, 16, 32), (16, 16, 32),
58+
(32, 32, 64)),
59+
((64, 64, 128),
60+
(64, 64, 128), (64, 96,
61+
128)),
62+
((128, 128, 256),
63+
(128, 192, 256), (128, 256,
64+
256))),
65+
aggregation_channels: Tuple[int] = (64, 128, 256),
66+
fps_mods: TwoTupleStrType = (('D-FPS'), ('FS'), ('F-FPS',
67+
'D-FPS')),
68+
fps_sample_range_lists: TwoTupleIntType = ((-1), (-1), (512,
69+
-1)),
70+
dilated_group: Tuple[bool] = (True, True, True),
71+
out_indices: Tuple[int] = (2, ),
72+
norm_cfg: dict = dict(type='BN2d'),
73+
sa_cfg: dict = dict(
5674
type='PointSAModuleMSG',
5775
pool_mod='max',
5876
use_xyz=True,
5977
normalize_xyz=False),
60-
init_cfg=None):
78+
init_cfg: OptConfigType = None):
6179
super().__init__(init_cfg=init_cfg)
6280
self.num_sa = len(sa_channels)
6381
self.out_indices = out_indices
@@ -123,7 +141,7 @@ def __init__(self,
123141
bias=True))
124142
sa_in_channel = cur_aggregation_channel
125143

126-
def forward(self, points):
144+
def forward(self, points: torch.Tensor):
127145
"""Forward pass.
128146
129147
Args:

mmdet3d/models/layers/pointnet_modules/builder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from mmengine.registry import Registry
55
from torch import nn as nn
66

7-
SA_MODULES = Registry('point_sa_module')
7+
SA_MODULES = Registry(
8+
name='point_sa_module',
9+
locations=['mmdet3d.models.layers.pointnet_modules'])
810

911

1012
def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module:

0 commit comments

Comments
 (0)