1
1
# Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Tuple
3
+
2
4
import torch
3
5
from mmcv .cnn import ConvModule
4
6
from torch import nn as nn
5
7
6
8
from mmdet3d .models .layers .pointnet_modules import build_sa_module
7
9
from mmdet3d .registry import MODELS
10
+ from mmdet3d .utils import OptConfigType
8
11
from .base_pointnet import BasePointNet
9
12
13
+ ThreeTupleIntType = Tuple [Tuple [Tuple [int , int , int ]]]
14
+ TwoTupleIntType = Tuple [Tuple [int , int , int ]]
15
+ TwoTupleStrType = Tuple [Tuple [str ]]
16
+
10
17
11
18
@MODELS .register_module ()
12
19
class PointNet2SAMSG (BasePointNet ):
@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
22
29
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
23
30
aggregation_channels (tuple[int]): Out channels of aggregation
24
31
multi-scale grouping features.
25
- fps_mods (tuple[int]) : Mod of FPS for each SA module.
32
+ fps_mods Sequence[Tuple[str]] : Mod of FPS for each SA module.
26
33
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
27
34
points which each SA module samples.
28
35
dilated_group (tuple[bool]): Whether to use dilated ball query for
@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
38
45
"""
39
46
40
47
def __init__ (self ,
41
- in_channels ,
42
- num_points = (2048 , 1024 , 512 , 256 ),
43
- radii = ((0.2 , 0.4 , 0.8 ), (0.4 , 0.8 , 1.6 ), (1.6 , 3.2 , 4.8 )),
44
- num_samples = ((32 , 32 , 64 ), (32 , 32 , 64 ), (32 , 32 , 32 )),
45
- sa_channels = (((16 , 16 , 32 ), (16 , 16 , 32 ), (32 , 32 , 64 )),
46
- ((64 , 64 , 128 ), (64 , 64 , 128 ), (64 , 96 , 128 )),
47
- ((128 , 128 , 256 ), (128 , 192 , 256 ), (128 , 256 ,
48
- 256 ))),
49
- aggregation_channels = (64 , 128 , 256 ),
50
- fps_mods = (('D-FPS' ), ('FS' ), ('F-FPS' , 'D-FPS' )),
51
- fps_sample_range_lists = ((- 1 ), (- 1 ), (512 , - 1 )),
52
- dilated_group = (True , True , True ),
53
- out_indices = (2 , ),
54
- norm_cfg = dict (type = 'BN2d' ),
55
- sa_cfg = dict (
48
+ in_channels : int ,
49
+ num_points : Tuple [int ] = (2048 , 1024 , 512 , 256 ),
50
+ radii : Tuple [Tuple [float , float , float ]] = (
51
+ (0.2 , 0.4 , 0.8 ),
52
+ (0.4 , 0.8 , 1.6 ),
53
+ (1.6 , 3.2 , 4.8 ),
54
+ ),
55
+ num_samples : TwoTupleIntType = ((32 , 32 , 64 ), (32 , 32 , 64 ),
56
+ (32 , 32 , 32 )),
57
+ sa_channels : ThreeTupleIntType = (((16 , 16 , 32 ), (16 , 16 , 32 ),
58
+ (32 , 32 , 64 )),
59
+ ((64 , 64 , 128 ),
60
+ (64 , 64 , 128 ), (64 , 96 ,
61
+ 128 )),
62
+ ((128 , 128 , 256 ),
63
+ (128 , 192 , 256 ), (128 , 256 ,
64
+ 256 ))),
65
+ aggregation_channels : Tuple [int ] = (64 , 128 , 256 ),
66
+ fps_mods : TwoTupleStrType = (('D-FPS' ), ('FS' ), ('F-FPS' ,
67
+ 'D-FPS' )),
68
+ fps_sample_range_lists : TwoTupleIntType = ((- 1 ), (- 1 ), (512 ,
69
+ - 1 )),
70
+ dilated_group : Tuple [bool ] = (True , True , True ),
71
+ out_indices : Tuple [int ] = (2 , ),
72
+ norm_cfg : dict = dict (type = 'BN2d' ),
73
+ sa_cfg : dict = dict (
56
74
type = 'PointSAModuleMSG' ,
57
75
pool_mod = 'max' ,
58
76
use_xyz = True ,
59
77
normalize_xyz = False ),
60
- init_cfg = None ):
78
+ init_cfg : OptConfigType = None ):
61
79
super ().__init__ (init_cfg = init_cfg )
62
80
self .num_sa = len (sa_channels )
63
81
self .out_indices = out_indices
@@ -123,7 +141,7 @@ def __init__(self,
123
141
bias = True ))
124
142
sa_in_channel = cur_aggregation_channel
125
143
126
- def forward (self , points ):
144
+ def forward (self , points : torch . Tensor ):
127
145
"""Forward pass.
128
146
129
147
Args:
0 commit comments