# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

import nncf
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.common.quantization.initialization.range import RangeInitCollectorParams
from nncf.common.quantization.initialization.range import RangeInitConfig
from nncf.common.quantization.initialization.range import RangeInitParams
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import QuantizerSetupBase
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerGroup
from nncf.common.quantization.structs import QuantizerId
from nncf.common.quantization.structs import WeightQuantizerId
from nncf.common.scopes import should_consider_scope
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.config.schemata.algo.quantization import RANGE_INIT_TYPES_VS_DESCRIPTIONS
from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.initialization import DataLoaderBaseRunner
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import SymmetricQuantizer
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.quantization.translator import PTTargetPointTranslator
from nncf.torch.tensor_statistics.algo import TensorStatisticObservationPoint
from nncf.torch.tensor_statistics.algo import create_register_input_hook
from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_mixed_min_max_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_percentile_tensor_collector
from nncf.torch.tensor_statistics.statistics import pt_convert_stat_to_min_max_tensor_stat


class PTRangeInitParams(RangeInitParams):
    def get_max_num_init_steps(self) -> int:
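        """
        Returns the maximum number of data loader steps (batches) required to cover
        the largest `num_init_samples` among the global and per-layer range init configs.
        """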
        steps = []
        if self.global_init_config is not None:
            steps.append(self.global_init_config.num_init_samples)
        for pl_config in self.per_layer_range_init_configs:
            steps.append(pl_config.num_init_samples)
        batch_size = self.init_range_data_loader.batch_size
        return int(np.ceil(max(steps) / batch_size))

    def get_init_config_for_quantization_point(self, qp: QuantizationPointBase) -> RangeInitConfig:
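        """
        Resolves the range init config that applies to the given quantization point,
        dispatching on whether the point quantizes weights or activations.
        """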
        if qp.is_weight_quantization_point():
            qid = WeightQuantizerId(qp.insertion_point.target_node_name)
            group = QuantizerGroup.WEIGHTS
        else:
            qid = NonWeightQuantizerId(qp.insertion_point.target_node_name, qp.insertion_point.input_port_id)
            group = QuantizerGroup.ACTIVATIONS
        return self.get_init_config_for_scope_and_group(qid, group)

    def get_init_config_for_scope_and_group(self, qid: QuantizerId, group: QuantizerGroup) -> RangeInitConfig:
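        """
        Returns the per-layer range init config matching the given quantizer ID and
        group, falling back to the global config if no per-layer config matches.
        """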
        matches: List[RangeInitConfig] = []
        for pl_config in self.per_layer_range_init_configs:
            should_be_considered = should_consider_scope(qid, pl_config.ignored_scopes, pl_config.target_scopes)
            if should_be_considered and (group == pl_config.target_group or pl_config.target_group is None):
                matches.append(
                    RangeInitConfig(
                        pl_config.init_type, pl_config.num_init_samples, pl_config.init_type_specific_params
                    )
                )
        if len(matches) > 1:
            raise ValueError(
                "Location {} matches more than one per-layer initialization parameter definition!".format(str(qid))
            )
        if len(matches) == 1:
            return matches[0]
        if not matches and self.global_init_config is not None:
            return deepcopy(self.global_init_config)

        raise ValueError(
            "Location {} does not match any per-layer initialization parameter definition, "
            "and no global range initialization config was provided!".format(str(qid))
        )


class PTRangeInitCollectorParams(RangeInitCollectorParams):
    def __init__(
        self, is_weights: bool, scheme: QuantizationScheme, per_channel: bool, input_shape: tuple, channel_idx: int
    ):
        """
        :param is_weights: Boolean that defines tensor type. True for Weights, False for Activations.
        :param scheme: Quantization scheme: symmetric or asymmetric.
        :param per_channel: True for per-channel quantization, False for per-tensor.
        :param input_shape: Shape of the input tensor.
        :param channel_idx: Index of the channel dimension in the input tensor shape.
        """
        super().__init__(is_weights, scheme, per_channel)
        self._input_shape = input_shape
        self._channel_idx = channel_idx

    def get_reduction_aggregation_axes(self, is_per_sample: bool) -> Tuple[ReductionAxes, AggregationAxes]:
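        """
        Returns the reduction and aggregation axes for statistics collection,
        accounting for the per-channel setting of this quantizer.

        :param is_per_sample: Whether statistics should be collected per-sample,
            i.e. keeping the batch dimension un-reduced.
        :return: A (reduction axes, aggregation axes) pair.
        """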
        if self.is_per_channel:
            return super().get_reduction_aggregation_axes(self._input_shape, (self._channel_idx,), is_per_sample)
        return super().get_reduction_aggregation_axes(self._input_shape, (), is_per_sample)


class StatCollectorGenerator:
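    """Creates tensor statistic collectors for quantizer range initialization."""
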
    @staticmethod
    def generate_collectors_for_range_init_statistics_collection(
        target_model_graph: PTNNCFGraph, quantizer_setup: QuantizerSetupBase, range_init_params: PTRangeInitParams
    ) -> Dict[TensorStatisticObservationPoint, Dict[ReductionAxes, TensorStatisticCollectorBase]]:
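        """
        For each quantization point in the setup, creates one statistic collector per
        distinct scale shape and groups the collectors by tensor observation point.
        """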
        retval = {}
        for qp in quantizer_setup.quantization_points.values():
            init_config = range_init_params.get_init_config_for_quantization_point(qp)
            is_weights = qp.is_weight_quantization_point()
            num_batches = int(
                np.ceil(init_config.num_init_samples / range_init_params.init_range_data_loader.batch_size)
            )
            if is_weights:
                # No need to store extra statistics in memory since weights won't change during range init
                num_batches = 1

            tp = PTTargetPointTranslator.translate(qp.insertion_point)
            scale_shapes_vs_params = StatCollectorGenerator.get_all_scale_shapes_with_params(qp, target_model_graph)

            obs_p = TensorStatisticObservationPoint(tp, reduction_shapes=set(scale_shapes_vs_params.keys()))

            retval[obs_p] = {}
            for scale_shape in obs_p.reduction_shapes:
                collector_params = scale_shapes_vs_params[scale_shape]
                collector = StatCollectorGenerator.generate_stat_collector_for_range_init_config(
                    init_config, scale_shape, collector_params, num_samples_to_collect_override=num_batches
                )
                retval[obs_p][scale_shape] = collector

        return retval

    @staticmethod
    def generate_stat_collector_for_range_init_config(
        init_config: RangeInitConfig,
        scale_shape: Optional[ReductionAxes] = None,
        collector_params: Optional[PTRangeInitCollectorParams] = None,
        num_samples_to_collect_override: Optional[int] = None,
    ) -> TensorStatisticCollectorBase:
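        """
        Builds a statistic collector for the range init type specified in `init_config`
        ("min_max", "mixed_min_max", "mean_min_max", "threesigma", "percentile" or
        "mean_percentile").
        """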
        num_samples = init_config.num_init_samples
        if num_samples_to_collect_override is not None:
            num_samples = num_samples_to_collect_override
        if init_config.init_type not in RANGE_INIT_TYPES_VS_DESCRIPTIONS:
            raise nncf.InternalError("Unknown range init type: {}".format(init_config.init_type))

        use_per_sample_stats = collector_params.use_per_sample_stats(init_config.init_type == "mixed_min_max")
        reduction_axes, aggregation_axes = collector_params.get_reduction_aggregation_axes(use_per_sample_stats)
        if init_config.init_type == "min_max":
            return get_min_max_statistic_collector(
                use_abs_max=collector_params.use_abs_max,
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        if init_config.init_type == "mixed_min_max":
            return get_mixed_min_max_statistic_collector(
                use_abs_max=collector_params.use_abs_max,
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                use_means_of_mins=collector_params.use_means_of_mins,
                use_means_of_maxs=collector_params.use_means_of_maxs,
                num_samples=num_samples,
            )
        if init_config.init_type == "mean_min_max":
            return get_mixed_min_max_statistic_collector(
                use_abs_max=collector_params.use_abs_max,
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                use_means_of_mins=True,
                use_means_of_maxs=True,
                num_samples=num_samples,
            )
        if init_config.init_type == "threesigma":
            return get_median_mad_statistic_collector(
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        if init_config.init_type == "percentile":
            min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
            max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
            return get_percentile_tensor_collector(
                percentiles_to_collect=(min_percentile, max_percentile),
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        if init_config.init_type == "mean_percentile":
            min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
            max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
            return get_mean_percentile_statistic_collector(
                percentiles_to_collect=(min_percentile, max_percentile),
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        raise ValueError("Range init type {} is not handled!".format(init_config.init_type))

    @classmethod
    def get_all_scale_shapes_with_params(
        cls, qp: QuantizationPointBase, target_nncf_graph: PTNNCFGraph
    ) -> Dict[ReductionAxes, PTRangeInitCollectorParams]:
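        """
        Maps each distinct scale shape that the quantization point's configs may
        produce to the collector parameters required to gather statistics for it.
        """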
        qconfigs = qp.get_all_configs_list()
        if qp.is_weight_quantization_point():
            module_node = target_nncf_graph.get_node_by_name(qp.insertion_point.target_node_name)
            layer_attributes = module_node.layer_attributes
            assert isinstance(layer_attributes, WeightedLayerAttributes)
            input_shape = layer_attributes.get_weight_shape()
            channel_idx = layer_attributes.get_target_dim_for_compression()
        else:
            input_shape = target_nncf_graph.get_input_shape_for_insertion_point(qp.insertion_point)
            channel_idx = 1  # channel dim for activations

        retval = {}
        is_weights = qp.is_weight_quantization_point()
        for qconfig in qconfigs:
            scale_shape = tuple(
                get_scale_shape(
                    input_shape, is_weights=is_weights, per_channel=qconfig.per_channel, channel_idx=channel_idx
                )
            )

            if scale_shape not in retval:
                retval[scale_shape] = PTRangeInitCollectorParams(
                    is_weights, qconfig.mode, qconfig.per_channel, input_shape, channel_idx
                )
        return retval


class DataLoaderRangeInitializeRunner(DataLoaderBaseRunner):
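    """
    Initializes quantizer ranges from data: attaches forward hooks that feed each
    quantizer's input into a statistic collector, then applies the collected
    min/max statistics to the quantizer parameters.
    """
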
    def __init__(
        self,
        model: NNCFNetwork,
        modules_to_init_vs_init_configs: Dict[str, Tuple[BaseQuantizer, RangeInitConfig, bool, Tuple[int, ...]]],
        init_device: str,
        batch_size: Optional[int] = None,
    ):
        super().__init__(model, init_device)
        self.modules_to_init = modules_to_init_vs_init_configs
        self.progressbar_description = "Range parameters initialization"

        self.collectors_and_modules_to_init: Dict[str, Tuple[TensorStatisticCollectorBase, BaseQuantizer]] = (
            OrderedDict()
        )
        self.hook_handles = []
        self.batch_size = batch_size

    def _get_fwd_hook(
        self, collector: TensorStatisticCollectorBase
    ) -> Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], None]:
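        """
        Wraps the statistic collector into a module forward hook that registers
        the module's first positional input with the collector.
        """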
        hook = create_register_input_hook(collector=collector)

        def fwd_hook(module, input_, output):
            hook(input_[0])

        return fwd_hook

    def _prepare_initialization(self):
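        """
        For every quantizer to be initialized, builds a matching statistic collector
        and attaches it to the quantizer module via a forward hook.
        """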
        for name, data in self.modules_to_init.items():
            quantizer_module, init_config, is_weights, input_shape = data
            num_samples_override = None
            if self.batch_size is not None:
                num_batches = int(np.ceil(init_config.num_init_samples / self.batch_size))
                num_samples_override = num_batches

            if isinstance(quantizer_module, SymmetricQuantizer):
                mode = QuantizationScheme.SYMMETRIC
            else:
                mode = QuantizationScheme.ASYMMETRIC

            shape = quantizer_module.scale_shape
            if shape == (1,):  # Per-tensor
                channel_idx = None
            elif len(shape) > 1 and all(item == 1 for item in shape):
                channel_idx = 0  # (1, 1, 1, 1) - does not matter which dim is channel_idx
            else:
                if not is_weights:
                    channel_idx = 1  # channel dim for activations
                else:
                    channel_idx = [i for i, val in enumerate(shape) if val != 1][0]

            collector_params = PTRangeInitCollectorParams(
                is_weights, mode, quantizer_module.per_channel, input_shape, channel_idx
            )

            collector = StatCollectorGenerator.generate_stat_collector_for_range_init_config(
                init_config, tuple(quantizer_module.scale_shape), collector_params, num_samples_override
            )

            self.collectors_and_modules_to_init[name] = collector, quantizer_module

            self.hook_handles.append(quantizer_module.register_forward_hook(self._get_fwd_hook(collector)))

    def _apply_initializers(self):
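        """
        Removes the registered forward hooks and applies the collected min/max
        statistics to each quantizer module.
        """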
        for handle in self.hook_handles:
            handle.remove()
        for scope_str, collector_and_module in self.collectors_and_modules_to_init.items():
            collector, quantizer_module = collector_and_module
            target_stat = collector.get_statistics()
            minmax_stats = pt_convert_stat_to_min_max_tensor_stat(target_stat)
            quantizer_module.apply_minmax_init(
                minmax_stats.min_values.data, minmax_stats.max_values.data, log_module_name=scope_str
            )
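

# A minimal usage sketch (illustrative only, not part of the module): how a caller
# might wire up the runner. The names `nncf_network`, `weight_quantizer` and
# `data_loader` are assumed placeholders, the dict key is only an example of an
# NNCF module scope string, and the call assumes the base runner's
# `run(data_loader, num_init_steps)` entry point.
#
#     modules_to_init = {
#         "MyModel/NNCFConv2d[conv1]": (
#             weight_quantizer,                 # BaseQuantizer attached to the conv weight
#             RangeInitConfig("min_max", 100),  # collect min/max over 100 samples
#             True,                             # is_weights
#             (16, 3, 3, 3),                    # weight tensor shape
#         ),
#     }
#     runner = DataLoaderRangeInitializeRunner(nncf_network, modules_to_init, "cuda")
#     runner.run(data_loader, num_init_steps=1)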