File tree 5 files changed +8
-7
lines changed
experimental/torch/sparsity/movement
5 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -157,11 +157,13 @@ def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[i
157
157
return [
158
158
layer_attributes .out_channels ,
159
159
layer_attributes .in_channels // layer_attributes .groups ,
160
- * layer_attributes .kernel_size ]
160
+ * layer_attributes .kernel_size ,
161
+ ]
161
162
return [
162
163
layer_attributes .in_channels ,
163
164
layer_attributes .out_channels // layer_attributes .groups ,
164
- * layer_attributes .kernel_size ]
165
+ * layer_attributes .kernel_size ,
166
+ ]
165
167
166
168
if isinstance (layer_attributes , GroupNormLayerAttributes ):
167
169
return [layer_attributes .num_channels ]
@@ -204,4 +206,3 @@ def get_num_filters_legacy(layer_attributes: WeightedLayerAttributes) -> int:
204
206
"""
205
207
weight_shape = layer_attributes .get_weight_shape_legacy ()
206
208
return weight_shape [layer_attributes .get_target_dim_for_compression_legacy ()]
207
-
Original file line number Diff line number Diff line change 18
18
19
19
import nncf
20
20
from nncf .common .graph import NNCFNode
21
- from nncf .common .graph .utils import get_weight_shape_legacy
22
21
from nncf .common .graph .utils import get_bias_shape_legacy
22
+ from nncf .common .graph .utils import get_weight_shape_legacy
23
23
from nncf .experimental .torch .sparsity .movement .functions import binary_mask_by_threshold
24
24
from nncf .torch .layer_utils import COMPRESSION_MODULES
25
25
from nncf .torch .layer_utils import CompressionParameter
Original file line number Diff line number Diff line change 18
18
19
19
from nncf .common .graph .graph import NNCFNodeName
20
20
from nncf .common .graph .layer_attributes import LinearLayerAttributes
21
- from nncf .common .graph .utils import get_weight_shape_legacy
22
21
from nncf .common .graph .utils import get_bias_shape_legacy
22
+ from nncf .common .graph .utils import get_weight_shape_legacy
23
23
from nncf .common .logging import nncf_logger
24
24
from nncf .experimental .common .pruning .nodes_grouping import get_pruning_groups
25
25
from nncf .experimental .common .pruning .nodes_grouping import select_largest_groups
Original file line number Diff line number Diff line change 35
35
from nncf .common .graph .patterns .manager import TargetDevice
36
36
from nncf .common .graph .transformations .commands import TargetType
37
37
from nncf .common .graph .utils import get_first_nodes_of_type
38
- from nncf .common .graph .utils import get_weight_shape_legacy
39
38
from nncf .common .graph .utils import get_target_dim_for_compression_legacy
39
+ from nncf .common .graph .utils import get_weight_shape_legacy
40
40
from nncf .common .hardware .config import HWConfig
41
41
from nncf .common .hardware .config import HWConfigType
42
42
from nncf .common .hardware .config import get_hw_config_type
Original file line number Diff line number Diff line change 18
18
19
19
import nncf
20
20
from nncf .common .graph .layer_attributes import WeightedLayerAttributes
21
- from nncf .common .graph .utils import get_weight_shape_legacy
22
21
from nncf .common .graph .utils import get_target_dim_for_compression_legacy
22
+ from nncf .common .graph .utils import get_weight_shape_legacy
23
23
from nncf .common .quantization .initialization .range import RangeInitCollectorParams
24
24
from nncf .common .quantization .initialization .range import RangeInitConfig
25
25
from nncf .common .quantization .initialization .range import RangeInitParams
You can’t perform that action at this time.
0 commit comments