Skip to content

Commit f0cb70c

Browse files
authored
[Torch][WeightCompression] Add Scale Estimation data-aware support (#3179)
### Changes Added data-aware support for the Torch backend for WeightCompression with Scale Estimation. Introduced support for MeanVarianceReducer, MaxVarianceReducer, and MeanAbsMaxReducer. Incorporated `torch.inference_mode()` context for WeightCompression. ### Reason for changes These changes enable the utilization of data-aware Scale Estimation for the Torch backend, specifically leveraging CUDA devices for improved performance. ### Related tickets Ticket ID: 158974 ### Tests Added a template for WeightCompression tests for both Torch and OV backends, covering data-aware and Scale Estimation scenarios. Extended the test scope to include `tinyllama_data_aware` and `tinyllama_scale_estimation_per_channel` for Torch. Added a new test case `tinyllama_scale_estimation_group_size_64` for both Torch and OV backends. ### Performance Metrics Note: All CUDA results are obtained locally on a single RTX 3090. Model | Backend | Metric Name | Metric Value | Num int4 | Num int8 | Compression Time (from Performance Job) | RAM MiB (from Performance Job) -- | -- | -- | -- | -- | -- | -- | -- tinyllama_data_aware | OV | Similarity | 0.8577 | 94 | 124 | 0:01:28 | 8545 tinyllama_data_aware | TORCH | Similarity | 0.8577 | 94 | 124 | 0:02:15 | 1225 tinyllama_data_aware | TORCH (CUDA) | Similarity | 0.8577 | 94 | 124 | 0:00:28 | - tinyllama_scale_estimation_per_channel | OV | Similarity | 0.8139 | 188 | 124 | 0:02:57 | 8681 tinyllama_scale_estimation_per_channel | TORCH | Similarity | 0.8139 | 188 | 124 | 0:03:25 | 5472 tinyllama_scale_estimation_per_channel | TORCH (CUDA) | Similarity | 0.8139 | 188 | 124 | 0:00:35 | - tinyllama_scale_estimation_group_size_64 | OV | Similarity | 0.8566 | 94 | 124 | 0:04:17 | 8681 tinyllama_scale_estimation_group_size_64 | TORCH | Similarity | 0.8566 | 94 | 124 | 0:04:01 | 5575 tinyllama_scale_estimation_group_size_64 | TORCH (CUDA) | Similarity | 0.8566 | 94 | 124 | 0:00:36 | -
1 parent 4275513 commit f0cb70c

25 files changed

+748
-210
lines changed

nncf/experimental/common/tensor_statistics/collectors.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -464,18 +464,27 @@ def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
464464

465465

466466
class MeanVarianceReducer(TensorReducerBase):
467-
def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
468-
raise NotImplementedError()
467+
def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
468+
x = x[0]
469+
reduction_axes = self._get_reduction_axes(x)
470+
variance = fns.var(x, reduction_axes)
471+
return [fns.mean(variance)]
469472

470473

471474
class MaxVarianceReducer(TensorReducerBase):
472-
def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
473-
raise NotImplementedError()
475+
def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
476+
x = x[0]
477+
reduction_axes = self._get_reduction_axes(x)
478+
variance = fns.var(x, reduction_axes)
479+
return [fns.max(variance)]
474480

475481

476482
class MeanAbsMaxReducer(TensorReducerBase):
477-
def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
478-
raise NotImplementedError()
483+
def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
484+
x = fns.abs(x[0])
485+
reduction_axes = self._get_reduction_axes(x)
486+
abs_max = fns.max(x, reduction_axes, keepdims=self._keepdims)
487+
return [fns.mean(abs_max)]
479488

480489

481490
class QuantileReducerBase(TensorReducerBase):

nncf/experimental/torch/fx/quantization/quantize_model.py

-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def compress_weights_impl(
135135
"""
136136
Implementation of the `compress_weights()` method for the Torch Fx backend.
137137
"""
138-
139138
compression_algorithm = WeightCompression(
140139
mode,
141140
ratio,

nncf/quantization/algorithms/weight_compression/algorithm.py

+18-26
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import operator
1313
from collections import OrderedDict
1414
from collections import defaultdict
15-
from functools import partial
1615
from functools import reduce
1716
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar
1817

@@ -266,6 +265,14 @@ def __init__(
266265
subset_size=gptq_params.subset_size,
267266
scale_estimation=self._scale_estimation,
268267
)
268+
if self._scale_estimation:
269+
scale_estimation_params = self._advanced_parameters.scale_estimation_params
270+
self._scale_estimation_algo = ScaleEstimation(
271+
scale_estimation_params.subset_size,
272+
scale_estimation_params.initial_steps,
273+
scale_estimation_params.scale_steps,
274+
scale_estimation_params.weight_penalty,
275+
)
269276

270277
self._data_aware_mixed_precision = (
271278
self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0
@@ -616,18 +623,13 @@ def apply(
616623
)
617624
else:
618625
if self._scale_estimation:
619-
scale_estimation_params = self._advanced_parameters.scale_estimation_params
620-
scales, zero_points = ScaleEstimation(
621-
model,
622-
self._backend_entity.name_to_node_mapping,
623-
all_weight_params,
624-
nodes_to_compress,
625-
statistics,
626-
scale_estimation_params.subset_size,
627-
scale_estimation_params.initial_steps,
628-
scale_estimation_params.scale_steps,
629-
scale_estimation_params.weight_penalty,
630-
).apply(model, graph)
626+
scales, zero_points = self._scale_estimation_algo.apply(
627+
model=model,
628+
graph=graph,
629+
all_weight_params=all_weight_params,
630+
statistics=statistics,
631+
backend_entity=self._backend_entity,
632+
)
631633

632634
if self._lora_correction:
633635
lora_correction_params = self._advanced_parameters.lora_correction_params
@@ -702,8 +704,6 @@ def get_matmul_input_to_output_nodes_map(
702704
"""
703705
matmul_input_to_output_nodes_map = defaultdict(list)
704706
for node in matmul_nodes:
705-
if node.layer_attributes.input_attributes["transpose"]: # It works only for OV
706-
raise nncf.UnsupportedModelError("Transposed input is not supported")
707707
act_node, output_port_id = self._get_activation_node_and_port(node, graph)
708708
matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node)
709709
return matmul_input_to_output_nodes_map
@@ -811,16 +811,6 @@ def _get_statistics_for_weights_compression(
811811
:return: Collected statistics.
812812
"""
813813

814-
def input_filter_func(point, port_id):
815-
# For the floating-point statistics collected in POST_LAYER style,
816-
# we also need to determine the output port id.
817-
# For the cases when the layer has more than one (0) output port.
818-
return (
819-
self._algorithm_key in point.algorithm_to_tensor_collectors
820-
and point.target_point.type == TargetType.POST_LAYER_OPERATION
821-
and point.target_point.port_id == port_id
822-
)
823-
824814
# For each node we store statistics in a WCTensorStatistics data-class. It contains the following fields:
825815
# mean_values=[mean_value_1, ..., mean_value_n]
826816
# shapes=[shape_1, ..., shape_n]
@@ -830,7 +820,9 @@ def input_filter_func(point, port_id):
830820
for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items():
831821
tensor_collectors = list(
832822
statistic_points.get_algo_statistics_for_node(
833-
act_node.node_name, partial(input_filter_func, port_id=output_port_id), self._algorithm_key
823+
act_node.node_name,
824+
self._backend_entity.get_filter_fn_for_statistics(output_port_id, self._algorithm_key),
825+
self._algorithm_key,
834826
)
835827
)
836828
# Statistics could be empty in case when the statistics is registered for another algorithm,

nncf/quantization/algorithms/weight_compression/backend.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111

1212
from abc import ABC
1313
from abc import abstractmethod
14-
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar
14+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
1515

1616
from nncf.common.graph import NNCFGraph
1717
from nncf.common.graph import NNCFNode
1818
from nncf.common.graph.operator_metatypes import OperatorMetatype
1919
from nncf.common.graph.transformations.commands import TargetPoint
2020
from nncf.common.graph.transformations.commands import TargetType
2121
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
22+
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
2223
from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator
2324
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
2425
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
@@ -234,6 +235,17 @@ def dump_parameters(
234235
:param path: Optional list of the paths.
235236
"""
236237

238+
@staticmethod
239+
@abstractmethod
240+
def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
241+
"""
242+
Returns backend-specific callable to filter statistic containers according to its statistic point.
243+
244+
:param activation_port_id: Activation port id for the statistic collection target node.
245+
:param algorithm_key: Current algorithm key.
246+
:return: Backend-specific callable to filter statistic containers according to its statistic point.
247+
"""
248+
237249

238250
class AWQAlgoBackend(WeightCompressionAlgoBackend):
239251
@staticmethod

nncf/quantization/algorithms/weight_compression/mixed_precision.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,18 @@ class DataBasedCriterion(DataFreeCriterion, ABC):
211211

212212
@property
213213
def available_backends(self) -> List[BackendType]:
214-
return [BackendType.OPENVINO]
214+
return [BackendType.OPENVINO, BackendType.TORCH]
215215

216216
def _set_backend_entity(self, model: TModel) -> None:
217217
model_backend = get_backend(model)
218218
if model_backend == BackendType.OPENVINO:
219219
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVMixedPrecisionAlgoBackend
220220

221221
self._backend_entity = OVMixedPrecisionAlgoBackend(model)
222+
elif model_backend == BackendType.TORCH:
223+
from nncf.quantization.algorithms.weight_compression.torch_backend import PTMixedPrecisionAlgoBackend
224+
225+
self._backend_entity = PTMixedPrecisionAlgoBackend()
222226
else:
223227
raise nncf.UnsupportedBackendError(
224228
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
@@ -303,21 +307,12 @@ def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -
303307
def _get_statistics_for_node(
304308
self, statistic_points: StatisticPointsContainer, node: NNCFNode, nncf_graph: NNCFGraph, stat_key: str
305309
) -> List[Tensor]:
306-
act_node, output_port_id = self._get_activation_node_and_port(node, nncf_graph)
307-
308-
def input_filter_func(point):
309-
# For the floating-point statistics collected in POST_LAYER style,
310-
# we also need to determine the output port id.
311-
# For the cases when the layer has more than one (0) output port.
312-
return (
313-
self._algorithm_key in point.algorithm_to_tensor_collectors
314-
and point.target_point.type == TargetType.POST_LAYER_OPERATION
315-
and point.target_point.port_id == output_port_id
316-
)
317-
310+
act_node, act_port_id = self._get_activation_node_and_port(node, nncf_graph)
318311
stats = []
319312
for tensor_collector in statistic_points.get_algo_statistics_for_node(
320-
act_node.node_name, input_filter_func, self._algorithm_key
313+
act_node.node_name,
314+
self._backend_entity.get_filter_fn_for_statistics(act_port_id, self._algorithm_key),
315+
self._algorithm_key,
321316
):
322317
statistics = tensor_collector.get_statistics()
323318
for data in statistics.get_data().values():

nncf/quantization/algorithms/weight_compression/openvino_backend.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11-
from typing import Dict, Iterable, List, Optional, Tuple
11+
from typing import Callable, Dict, Iterable, List, Optional, Tuple
1212

1313
import openvino as ov
1414
from openvino.runtime import opset13 as opset
@@ -19,6 +19,7 @@
1919
from nncf.common.graph.operator_metatypes import OperatorMetatype
2020
from nncf.common.graph.transformations.commands import TargetType
2121
from nncf.common.graph.utils import get_reduction_axes
22+
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
2223
from nncf.common.utils.caching import disable_results_caching
2324
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
2425
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
@@ -109,6 +110,8 @@ def mean_statistic_collector(
109110

110111
@staticmethod
111112
def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
113+
if node.layer_attributes.input_attributes["transpose"]:
114+
raise nncf.UnsupportedModelError("Transposed input is not supported")
112115
constant_ports = node.layer_attributes.get_const_port_ids()
113116
activation_ports = [
114117
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
@@ -348,6 +351,17 @@ def dump_parameters(
348351
) -> None:
349352
dump_parameters(model, parameters, algo_name, path)
350353

354+
@staticmethod
355+
def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
356+
def filter_func(point: StatisticPoint) -> bool:
357+
return (
358+
algorithm_key in point.algorithm_to_tensor_collectors
359+
and point.target_point.type == TargetType.POST_LAYER_OPERATION
360+
and point.target_point.port_id == activation_port_id
361+
)
362+
363+
return filter_func
364+
351365

352366
class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend):
353367
@staticmethod

nncf/quantization/algorithms/weight_compression/scale_estimation.py

+23-33
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,17 @@
1010
# limitations under the License.
1111

1212
from copy import deepcopy
13-
from typing import Any, Dict, List, Optional, Tuple, TypeVar
13+
from typing import Dict, List, Optional, Tuple, TypeVar
1414

1515
import nncf
16-
from nncf import Dataset
1716
from nncf.common.graph.graph import NNCFGraph
18-
from nncf.common.graph.graph import NNCFNode
1917
from nncf.common.logging.track_progress import track
20-
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
2118
from nncf.common.utils.backend import BackendType
2219
from nncf.common.utils.backend import get_backend
2320
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
2421
from nncf.parameters import CompressWeightsMode
2522
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
23+
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
2624
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
2725
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
2826
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_normalized_weight_and_fp4_scale
@@ -45,70 +43,57 @@ class ScaleEstimation:
4543

4644
def __init__(
4745
self,
48-
model: TModel,
49-
name_to_node_mapping: Dict[str, Any],
50-
all_weight_params: List[WeightCompressionParameters],
51-
nodes_to_compress: List[NNCFNode],
52-
statistics: Dict[str, WCTensorStatistic],
5346
subset_size: int = 32,
5447
initial_steps: int = 5,
5548
scale_steps: int = 10,
5649
weight_penalty: float = -1.0,
5750
):
5851
"""
59-
:param model: Model for applying algorithm.
60-
:param name_to_node_mapping: Name to node mapping for updating node weights.
61-
:param all_weight_params: List of all weight parameters.
62-
:param nodes_to_compress: List of nodes for processing.
63-
:param statistics: Input activation statistics for each node.
6452
:param subset_size: The number of samples for scale estimation.
6553
:param initial_steps: The number of the steps for absmax scale rectification.
6654
:param scale_steps: The number of the steps for grid search scale rectification
6755
from 1.0 to 1.0 - 0.05 * scale_step.
6856
:param weight_penalty: coefficient for penalty between fp and compressed weights. If -1 then doesn't apply.
6957
"""
7058
super().__init__()
71-
self.name_to_node_mapping = name_to_node_mapping
72-
self._all_weight_params = all_weight_params
73-
self._nodes_to_compress = nodes_to_compress
74-
self._statistics = statistics
7559
self._subset_size = subset_size
7660
self._initial_steps = initial_steps
7761
self._scale_steps = scale_steps
7862
self._weight_penalty = weight_penalty
7963

80-
self._set_backend_entity(model)
81-
8264
@property
8365
def available_backends(self) -> List[BackendType]:
84-
return [BackendType.OPENVINO]
66+
return [BackendType.OPENVINO, BackendType.TORCH]
8567

8668
def _set_backend_entity(self, model: TModel) -> None:
8769
"""
8870
Creates a helper class with a backed-specific logic of the algorithm.
8971
9072
:param model: Backend-specific input model.
91-
:param all_weight_params: List of all weight parameters.
92-
:param nodes_to_compress: List of nodes for processing.
93-
:param activations: The input activations of the layers considered for compression.
9473
"""
95-
9674
model_backend = get_backend(model)
9775
if model_backend == BackendType.OPENVINO:
9876
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend
9977

100-
self._backend_entity = OVWeightCompressionAlgoBackend(model, self.name_to_node_mapping)
78+
self._backend_entity = OVWeightCompressionAlgoBackend(model)
79+
elif model_backend == BackendType.TORCH:
80+
from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend
81+
82+
self._backend_entity = PTWeightCompressionAlgoBackend()
10183
else:
10284
raise nncf.UnsupportedBackendError(
103-
"Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value)
85+
"Cannot return backend-specific Scale Estimation entity because {} is not supported!".format(
86+
model_backend.value
87+
)
10488
)
10589

10690
def apply(
10791
self,
10892
model: TModel,
10993
graph: NNCFGraph,
110-
statistic_points: Optional[StatisticPointsContainer] = None,
111-
dataset: Optional[Dataset] = None,
94+
all_weight_params: List[WeightCompressionParameters],
95+
statistics: Dict[str, WCTensorStatistic],
96+
backend_entity: Optional[WeightCompressionAlgoBackend] = None,
11297
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
11398
"""
11499
Estimates better scale for the int4 nodes in the model.
@@ -119,23 +104,28 @@ def apply(
119104
120105
:param model: Model for applying algorithm.
121106
:param graph: Model graph.
107+
:param all_weight_params: List of all weight parameters.
108+
:param statistics: Input activation statistics for each node.
122109
:param statistic_points: Statistic points with collected statistics values.
123110
:param dataset: A representative dataset for the calibration process.
111+
:param backend_entity: Weight compression algorithm backend.
124112
:return: Two dictionaries for estimated scales and zero points for each weight name.
125113
"""
126-
114+
self._backend_entity = backend_entity
115+
if self._backend_entity is None:
116+
self._set_backend_entity(model)
127117
scales, zero_points = dict(), dict()
128118

129-
for wp in track(self._all_weight_params, description="Applying Scale Estimation"):
119+
for wp in track(all_weight_params, description="Applying Scale Estimation"):
130120
weight_name = wp.weight_name
131121
node_name = wp.node_with_weight.node_name
132122
config = wp.compression_config
133123

134-
if config.num_bits != 4 or node_name not in self._statistics:
124+
if config.num_bits != 4 or node_name not in statistics:
135125
scales[weight_name] = None
136126
continue
137127

138-
stats = self._statistics[node_name]
128+
stats = statistics[node_name]
139129

140130
weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
141131
if len(weight_data) != 1: # not supported by the algorithm

0 commit comments

Comments
 (0)