@@ -18,7 +18,6 @@
 from nncf.common.graph import NNCFNode
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
-from nncf.common.tensor_statistics.statistics import WCTensorStatistic
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.parameters import CompressWeightsMode
@@ -266,7 +265,7 @@ def _quantize_weights(
                     else:
                         if self._scale_estimation and block_compression_config.num_bits == 4:
                             activations = [inp.squeeze()[:, (i1 + i) : (i1 + i + group_size)] for inp in inputs]
-                            wc_statistics = self._activations_to_wc_statistics(activations)
+                            wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
                             scale, zero_point = ScaleEstimation.calculate_quantization_params(
                                 self._backend_entity,
                                 wc_statistics,
@@ -327,15 +326,3 @@ def _quantize_weights(
         else:
             zero_points = None
         return scales, zero_points
-
-    @staticmethod
-    def _activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
-        # The code below mimics the logic from WeightCompression.get_statistic_points
-        mean_values = []
-        shapes = []
-        for act in activations:
-            shapes.append(act.shape)
-            reduction_shape = tuple(range(act.ndim - 1))
-            mean_values.append(fns.mean(act, axis=reduction_shape))
-        wc_statistics = WCTensorStatistic(mean_values, shapes)
-        return wc_statistics
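Note: the diff removes the private _activations_to_wc_statistics helper from this file and instead calls a public static method on ScaleEstimation. Below is a minimal sketch of what that method presumably looks like on the ScaleEstimation class; the body is taken from the removed helper, while the class placement and the nncf.tensor import paths for Tensor and fns are assumptions and may differ from the actual scale estimation module.

from typing import List

from nncf.common.tensor_statistics.statistics import WCTensorStatistic
from nncf.tensor import Tensor
from nncf.tensor import functions as fns


class ScaleEstimation:
    @staticmethod
    def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
        # Mirrors the logic of WeightCompression.get_statistic_points: for each
        # activation, record its shape and its mean over all axes except the last one.
        mean_values = []
        shapes = []
        for act in activations:
            shapes.append(act.shape)
            reduction_shape = tuple(range(act.ndim - 1))
            mean_values.append(fns.mean(act, axis=reduction_shape))
        return WCTensorStatistic(mean_values, shapes)

Hosting the helper on ScaleEstimation appears to keep the WCTensorStatistic construction in one place, so the GPTQ path and the standalone scale-estimation path feed calculate_quantization_params with the same statistic layout.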