Skip to content

Commit a842d6b

Browse files
committed
hotfix for pytorch missing locations
nncf/torch/sparsity/const/algo.py nncf/torch/sparsity/magnitude/algo.py nncf/torch/sparsity/rb/algo.py
1 parent fbaa831 commit a842d6b

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

nncf/torch/sparsity/const/algo.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Tuple
1212

1313
from nncf.common.graph import NNCFNode
14+
from nncf.common.graph.utils import get_weight_shape_legacy
1415
from nncf.common.sparsity.statistics import ConstSparsityStatistics
1516
from nncf.common.statistics import NNCFStatistics
1617
from nncf.common.utils.api_marker import api
@@ -26,7 +27,7 @@
2627
@PT_COMPRESSION_ALGORITHMS.register("const_sparsity")
2728
class ConstSparsityBuilder(BaseSparsityAlgoBuilder):
2829
def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
29-
return BinaryMask(target_module_node.layer_attributes.get_weight_shape())
30+
return BinaryMask(get_weight_shape_legacy(target_module_node.layer_attributes))
3031

3132
def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
3233
return ConstSparsityController(model, self._sparsified_module_info)

nncf/torch/sparsity/magnitude/algo.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from nncf.api.compression import CompressionStage
1818
from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
1919
from nncf.common.graph import NNCFNode
20+
from nncf.common.graph.utils import get_weight_shape_legacy
2021
from nncf.common.initialization.batchnorm_adaptation import BatchnormAdaptationAlgorithm
2122
from nncf.common.schedulers import StubCompressionScheduler
2223
from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
@@ -44,7 +45,7 @@
4445
@PT_COMPRESSION_ALGORITHMS.register("magnitude_sparsity")
4546
class MagnitudeSparsityBuilder(BaseSparsityAlgoBuilder):
4647
def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
47-
return BinaryMask(target_module_node.layer_attributes.get_weight_shape())
48+
return BinaryMask(get_weight_shape_legacy(target_module_node.layer_attributes))
4849

4950
def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
5051
return MagnitudeSparsityController(model, self._sparsified_module_info, self.config)

nncf/torch/sparsity/rb/algo.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from nncf.api.compression import CompressionStage
1919
from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
2020
from nncf.common.graph import NNCFNode
21+
from nncf.common.graph.utils import get_weight_shape_legacy
2122
from nncf.common.schedulers import StubCompressionScheduler
2223
from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
2324
from nncf.common.sparsity.statistics import RBSparsityStatistics
@@ -44,7 +45,7 @@
4445
class RBSparsityBuilder(BaseSparsityAlgoBuilder):
4546
def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
4647
return RBSparsifyingWeight(
47-
target_module_node.layer_attributes.get_weight_shape(),
48+
get_weight_shape_legacy(target_module_node.layer_attributes),
4849
frozen=False,
4950
compression_lr_multiplier=compression_lr_multiplier,
5051
)

0 commit comments

Comments
 (0)