Skip to content

Commit dfc5d78

Browse files
Move activations_to_wc_statistics to SE class (#3022)
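### Changes

As in the PR title: the `activations_to_wc_statistics` helper is moved out of `gptq.py` and becomes a public static method of the `ScaleEstimation` class in `scale_estimation.py`.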
1 parent e9d2415 commit dfc5d78

File tree

2 files changed

+18 -14 lines changed

2 files changed

+18 -14 lines changed

nncf/quantization/algorithms/weight_compression/gptq.py

+1 -14
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from nncf.common.graph import NNCFNode
1919
from nncf.common.logging.track_progress import track
2020
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
21-
from nncf.common.tensor_statistics.statistics import WCTensorStatistic
2221
from nncf.common.utils.backend import BackendType
2322
from nncf.common.utils.backend import get_backend
2423
from nncf.parameters import CompressWeightsMode
@@ -266,7 +265,7 @@ def _quantize_weights(
266265
else:
267266
if self._scale_estimation and block_compression_config.num_bits == 4:
268267
activations = [inp.squeeze()[:, (i1 + i) : (i1 + i + group_size)] for inp in inputs]
269-
wc_statistics = self._activations_to_wc_statistics(activations)
268+
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
270269
scale, zero_point = ScaleEstimation.calculate_quantization_params(
271270
self._backend_entity,
272271
wc_statistics,
@@ -327,15 +326,3 @@ def _quantize_weights(
327326
else:
328327
zero_points = None
329328
return scales, zero_points
330-
331-
@staticmethod
332-
def _activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
333-
# The code below mimics the logic from WeightCompression.get_statistic_points
334-
mean_values = []
335-
shapes = []
336-
for act in activations:
337-
shapes.append(act.shape)
338-
reduction_shape = tuple(range(act.ndim - 1))
339-
mean_values.append(fns.mean(act, axis=reduction_shape))
340-
wc_statistics = WCTensorStatistic(mean_values, shapes)
341-
return wc_statistics

nncf/quantization/algorithms/weight_compression/scale_estimation.py

+17
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,23 @@ def calculate_quantization_params(
371371

372372
return result_scale, zp
373373

374+
@staticmethod
375+
def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
376+
"""
377+
Mimic the activation reducing logic from WeightCompression.get_statistic_points.
378+
379+
:param activations: List of raw activations.
380+
:return: Instance of WCTensorStatistic class containing reduced activations and shapes.
381+
"""
382+
mean_values = []
383+
shapes = []
384+
for act in activations:
385+
shapes.append(act.shape)
386+
reduction_shape = tuple(range(act.ndim - 1))
387+
mean_values.append(fns.mean(act, axis=reduction_shape))
388+
wc_statistics = WCTensorStatistic(mean_values, shapes)
389+
return wc_statistics
390+
374391

375392
def get_target_zero_mask(compressed_weights: Tensor, zp: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
376393
"""

0 commit comments

Comments
 (0)