File tree 3 files changed +6
-3
lines changed
3 files changed +6
-3
lines changed Original file line number Diff line number Diff line change 11
11
from typing import Tuple
12
12
13
13
from nncf .common .graph import NNCFNode
14
+ from nncf .common .graph .utils import get_weight_shape_legacy
14
15
from nncf .common .sparsity .statistics import ConstSparsityStatistics
15
16
from nncf .common .statistics import NNCFStatistics
16
17
from nncf .common .utils .api_marker import api
26
27
@PT_COMPRESSION_ALGORITHMS .register ("const_sparsity" )
27
28
class ConstSparsityBuilder (BaseSparsityAlgoBuilder ):
28
29
def create_weight_sparsifying_operation (self , target_module_node : NNCFNode , compression_lr_multiplier : float ):
29
- return BinaryMask (target_module_node .layer_attributes . get_weight_shape ( ))
30
+ return BinaryMask (get_weight_shape_legacy ( target_module_node .layer_attributes ))
30
31
31
32
def _build_controller (self , model : NNCFNetwork ) -> PTCompressionAlgorithmController :
32
33
return ConstSparsityController (model , self ._sparsified_module_info )
Original file line number Diff line number Diff line change 17
17
from nncf .api .compression import CompressionStage
18
18
from nncf .common .accuracy_aware_training .training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
19
19
from nncf .common .graph import NNCFNode
20
+ from nncf .common .graph .utils import get_weight_shape_legacy
20
21
from nncf .common .initialization .batchnorm_adaptation import BatchnormAdaptationAlgorithm
21
22
from nncf .common .schedulers import StubCompressionScheduler
22
23
from nncf .common .sparsity .schedulers import SPARSITY_SCHEDULERS
44
45
@PT_COMPRESSION_ALGORITHMS .register ("magnitude_sparsity" )
45
46
class MagnitudeSparsityBuilder (BaseSparsityAlgoBuilder ):
46
47
def create_weight_sparsifying_operation (self , target_module_node : NNCFNode , compression_lr_multiplier : float ):
47
- return BinaryMask (target_module_node .layer_attributes . get_weight_shape ( ))
48
+ return BinaryMask (get_weight_shape_legacy ( target_module_node .layer_attributes ))
48
49
49
50
def _build_controller (self , model : NNCFNetwork ) -> PTCompressionAlgorithmController :
50
51
return MagnitudeSparsityController (model , self ._sparsified_module_info , self .config )
Original file line number Diff line number Diff line change 18
18
from nncf .api .compression import CompressionStage
19
19
from nncf .common .accuracy_aware_training .training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
20
20
from nncf .common .graph import NNCFNode
21
+ from nncf .common .graph .utils import get_weight_shape_legacy
21
22
from nncf .common .schedulers import StubCompressionScheduler
22
23
from nncf .common .sparsity .schedulers import SPARSITY_SCHEDULERS
23
24
from nncf .common .sparsity .statistics import RBSparsityStatistics
44
45
class RBSparsityBuilder (BaseSparsityAlgoBuilder ):
45
46
def create_weight_sparsifying_operation (self , target_module_node : NNCFNode , compression_lr_multiplier : float ):
46
47
return RBSparsifyingWeight (
47
- target_module_node .layer_attributes . get_weight_shape ( ),
48
+ get_weight_shape_legacy ( target_module_node .layer_attributes ),
48
49
frozen = False ,
49
50
compression_lr_multiplier = compression_lr_multiplier ,
50
51
)
You can’t perform that action at this time.
0 commit comments