Commit bac19a7

andreyanufrljaljushkin authored and committed

Added the possibility to generate base text on GPU for text evaluation. (openvinotoolkit#1945)
Converts the input device type to match the model's device type for the text evaluator.
1 parent 3ce470c commit bac19a7
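
The device conversion the commit message describes follows the standard PyTorch pattern of moving inputs onto the model's device before generation. A minimal sketch, with hypothetical `model`/`batch` names rather than the evaluator's actual API:

    import torch

    def to_model_device(model: torch.nn.Module, batch: dict) -> dict:
        # Infer the device from the model's first parameter and move tensor inputs there,
        # so base-text generation can run on GPU when the model is on GPU.
        device = next(model.parameters()).device
        return {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}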

File tree

3 files changed: +553 -2 lines changed


init_range.py

+324
@@ -0,0 +1,324 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Dict, List, Tuple

import numpy as np
import torch

import nncf
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.common.quantization.initialization.range import RangeInitCollectorParams
from nncf.common.quantization.initialization.range import RangeInitConfig
from nncf.common.quantization.initialization.range import RangeInitParams
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import QuantizerSetupBase
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerGroup
from nncf.common.quantization.structs import QuantizerId
from nncf.common.quantization.structs import WeightQuantizerId
from nncf.common.scopes import should_consider_scope
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.config.schemata.algo.quantization import RANGE_INIT_TYPES_VS_DESCRIPTIONS
from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.initialization import DataLoaderBaseRunner
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import SymmetricQuantizer
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.quantization.translator import PTTargetPointTranslator
from nncf.torch.tensor_statistics.algo import TensorStatisticObservationPoint
from nncf.torch.tensor_statistics.algo import create_register_input_hook
from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_mixed_min_max_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_percentile_tensor_collector
from nncf.torch.tensor_statistics.statistics import pt_convert_stat_to_min_max_tensor_stat

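# PTRangeInitParams resolves which range-initialization config applies to a given
# quantization point: per-layer configs take precedence over the global config.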
class PTRangeInitParams(RangeInitParams):
    def get_max_num_init_steps(self) -> int:
        steps = []
        if self.global_init_config is not None:
            steps.append(self.global_init_config.num_init_samples)
        for pl_config in self.per_layer_range_init_configs:
            steps.append(pl_config.num_init_samples)
        batch_size = self.init_range_data_loader.batch_size
        return int(np.ceil(max(steps) / batch_size))

    def get_init_config_for_quantization_point(self, qp: QuantizationPointBase) -> RangeInitConfig:
        if qp.is_weight_quantization_point():
            qid = WeightQuantizerId(qp.insertion_point.target_node_name)
            group = QuantizerGroup.WEIGHTS
        else:
            qid = NonWeightQuantizerId(qp.insertion_point.target_node_name, qp.insertion_point.input_port_id)
            group = QuantizerGroup.ACTIVATIONS
        return self.get_init_config_for_scope_and_group(qid, group)

    def get_init_config_for_scope_and_group(self, qid: QuantizerId, group: QuantizerGroup) -> RangeInitConfig:
        matches: List[RangeInitConfig] = []
        for pl_config in self.per_layer_range_init_configs:
            should_be_considered = should_consider_scope(qid, pl_config.ignored_scopes, pl_config.target_scopes)
            if should_be_considered and (group == pl_config.target_group or pl_config.target_group is None):
                matches.append(
                    RangeInitConfig(
                        pl_config.init_type, pl_config.num_init_samples, pl_config.init_type_specific_params
                    )
                )
        if len(matches) > 1:
            raise ValueError(
                "Location {} matches more than one per-layer initialization parameter definition!".format(str(qid))
            )
        if len(matches) == 1:
            return matches[0]
        if not matches and self.global_init_config is not None:
            return deepcopy(self.global_init_config)

        raise ValueError(
            "Location {} does not match any per-layer initialization parameter definition!".format(str(qid))
        )

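# PTRangeInitCollectorParams augments the common collector params with the input
# shape and channel index needed to derive reduction/aggregation axes.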
class PTRangeInitCollectorParams(RangeInitCollectorParams):
    def __init__(
        self, is_weights: bool, scheme: QuantizationScheme, per_channel: bool, input_shape: tuple, channel_idx: int
    ):
        """
        :param is_weights: Boolean that defines tensor type. True for Weights, False for Activations.
        :param scheme: Quantization scheme: symmetric or asymmetric.
        :param per_channel: True for per-channel quantization, False for per-tensor.
        :param input_shape: Shape of the input tensor.
        :param channel_idx: Channel dimension.
        """
        super().__init__(is_weights, scheme, per_channel)
        self._input_shape = input_shape
        self._channel_idx = channel_idx

    def get_reduction_aggregation_axes(self, is_per_sample: bool) -> Tuple[ReductionAxes, AggregationAxes]:
        if self.is_per_channel:
            return super().get_reduction_aggregation_axes(self._input_shape, (self._channel_idx,), is_per_sample)
        return super().get_reduction_aggregation_axes(self._input_shape, (), is_per_sample)

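# StatCollectorGenerator builds, for every quantization point in a quantizer setup,
# the statistic collectors that match its range-init type and scale shape(s).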
class StatCollectorGenerator:
    @staticmethod
    def generate_collectors_for_range_init_statistics_collection(
        target_model_graph: PTNNCFGraph, quantizer_setup: QuantizerSetupBase, range_init_params: PTRangeInitParams
    ) -> Dict[TensorStatisticObservationPoint, Dict[ReductionAxes, TensorStatisticCollectorBase]]:
        retval = {}
        for qp in quantizer_setup.quantization_points.values():
            init_config = range_init_params.get_init_config_for_quantization_point(qp)
            is_weights = qp.is_weight_quantization_point()
            num_batches = int(
                np.ceil(init_config.num_init_samples / range_init_params.init_range_data_loader.batch_size)
            )
            if is_weights:
                # No need to store extra statistics in memory since weights won't change during range init
                num_batches = 1

            tp = PTTargetPointTranslator.translate(qp.insertion_point)
            scale_shapes_vs_params = StatCollectorGenerator.get_all_scale_shapes_with_params(qp, target_model_graph)

            obs_p = TensorStatisticObservationPoint(tp, reduction_shapes=set(scale_shapes_vs_params.keys()))

            retval[obs_p] = {}
            for scale_shape in obs_p.reduction_shapes:
                collector_params = scale_shapes_vs_params[scale_shape]
                collector = StatCollectorGenerator.generate_stat_collector_for_range_init_config(
                    init_config, scale_shape, collector_params, num_samples_to_collect_override=num_batches
                )
                retval[obs_p][scale_shape] = collector

        return retval

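    # Dispatches on init_config.init_type ("min_max", "mixed_min_max", "mean_min_max",
    # "threesigma", "percentile", "mean_percentile") to the matching collector factory.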
    @staticmethod
    def generate_stat_collector_for_range_init_config(
        init_config: RangeInitConfig,
        scale_shape: ReductionAxes = None,
        collector_params: PTRangeInitCollectorParams = None,
        num_samples_to_collect_override: int = None,
    ) -> TensorStatisticCollectorBase:
        num_samples = init_config.num_init_samples
        if num_samples_to_collect_override is not None:
            num_samples = num_samples_to_collect_override
        if init_config.init_type not in RANGE_INIT_TYPES_VS_DESCRIPTIONS:
            raise nncf.InternalError("Unknown range init type: {}".format(init_config.init_type))

        use_per_sample_stats = collector_params.use_per_sample_stats(init_config.init_type == "mixed_min_max")
        reduction_axes, aggregation_axes = collector_params.get_reduction_aggregation_axes(use_per_sample_stats)
        if init_config.init_type == "min_max":
            return get_min_max_statistic_collector(
                use_abs_max=collector_params.use_abs_max,
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        if init_config.init_type == "mixed_min_max":
            return get_mixed_min_max_statistic_collector(
                use_abs_max=collector_params.use_abs_max,
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                use_means_of_mins=collector_params.use_means_of_mins,
                use_means_of_maxs=collector_params.use_means_of_maxs,
                num_samples=num_samples,
            )
        if init_config.init_type == "mean_min_max":
            return get_mixed_min_max_statistic_collector(
                use_abs_max=collector_params.use_abs_max,
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                use_means_of_mins=True,
                use_means_of_maxs=True,
                num_samples=num_samples,
            )
        if init_config.init_type == "threesigma":
            return get_median_mad_statistic_collector(
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        if init_config.init_type == "percentile":
            min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
            max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
            return get_percentile_tensor_collector(
                percentiles_to_collect=(min_percentile, max_percentile),
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )

        if init_config.init_type == "mean_percentile":
            min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
            max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
            return get_mean_percentile_statistic_collector(
                percentiles_to_collect=(min_percentile, max_percentile),
                reduction_axes=reduction_axes,
                aggregation_axes=aggregation_axes,
                scale_shape=scale_shape,
                num_samples=num_samples,
            )
        raise ValueError("Range init type not handled!")

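    # Derives the scale shape(s) for a quantization point: from the weight shape for
    # weight quantizers, or from the activation input shape (channel dim 1) otherwise.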
    @classmethod
    def get_all_scale_shapes_with_params(
        cls, qp: QuantizationPointBase, target_nncf_graph: PTNNCFGraph
    ) -> Dict[ReductionAxes, PTRangeInitCollectorParams]:
        qconfigs = qp.get_all_configs_list()
        if qp.is_weight_quantization_point():
            module_node = target_nncf_graph.get_node_by_name(qp.insertion_point.target_node_name)
            layer_attributes = module_node.layer_attributes
            assert isinstance(layer_attributes, WeightedLayerAttributes)
            input_shape = layer_attributes.get_weight_shape()
            channel_idx = layer_attributes.get_target_dim_for_compression()
        else:
            input_shape = target_nncf_graph.get_input_shape_for_insertion_point(qp.insertion_point)
            channel_idx = 1  # channel dim for activations

        retval = {}
        for qconfig in qconfigs:
            is_weights = qp.is_weight_quantization_point()
            scale_shape = tuple(
                get_scale_shape(
                    input_shape, is_weights=is_weights, per_channel=qconfig.per_channel, channel_idx=channel_idx
                )
            )

            if scale_shape not in retval:
                retval[scale_shape] = PTRangeInitCollectorParams(
                    is_weights, qconfig.mode, qconfig.per_channel, input_shape, channel_idx
                )
        return retval

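# DataLoaderRangeInitializeRunner drives a calibration pass over the data loader:
# forward hooks feed each quantizer's inputs to its statistic collector, and the
# collected min/max statistics then initialize the quantizer's range parameters.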
class DataLoaderRangeInitializeRunner(DataLoaderBaseRunner):
    def __init__(
        self,
        model: NNCFNetwork,
        modules_to_init_vs_init_configs: Dict[str, Tuple[BaseQuantizer, RangeInitConfig, bool, Tuple[int]]],
        init_device: str,
        batch_size: int = None,
    ):
        super().__init__(model, init_device)
        self.modules_to_init = modules_to_init_vs_init_configs
        self.progressbar_description = "Range parameters initialization"

        self.collectors_and_modules_to_init: Dict[str, Tuple[TensorStatisticCollectorBase, BaseQuantizer]] = (
            OrderedDict()
        )
        self.hook_handles = []
        self.batch_size = batch_size

    def _get_fwd_hook(
        self, collector: TensorStatisticCollectorBase
    ) -> Callable[["torch.nn.Module", torch.Tensor, torch.Tensor], torch.Tensor]:
        hook = create_register_input_hook(collector=collector)

        def fwd_hook(module, input_, output):
            hook(input_[0])

        return fwd_hook

    def _prepare_initialization(self):
        for name, data in self.modules_to_init.items():
            quantizer_module, init_config, is_weights, input_shape = data
            num_samples_override = None
            if self.batch_size is not None:
                num_batches = np.ceil(init_config.num_init_samples / self.batch_size)
                num_samples_override = num_batches

            if isinstance(quantizer_module, SymmetricQuantizer):
                mode = QuantizationScheme.SYMMETRIC
            else:
                mode = QuantizationScheme.ASYMMETRIC

            shape = quantizer_module.scale_shape
            if shape == (1,):  # Per-tensor
                channel_idx = None
            elif len(shape) > 1 and all(item == 1 for item in shape):
                channel_idx = 0  # (1, 1, 1, 1) - does not matter which dim is channel_idx
            else:
                if not is_weights:
                    channel_idx = 1  # channel dim for activations
                else:
                    channel_idx = [i for i, val in enumerate(shape) if val != 1][0]

            collector_params = PTRangeInitCollectorParams(
                is_weights, mode, quantizer_module.per_channel, input_shape, channel_idx
            )

            collector = StatCollectorGenerator.generate_stat_collector_for_range_init_config(
                init_config, tuple(quantizer_module.scale_shape), collector_params, num_samples_override
            )

            self.collectors_and_modules_to_init[name] = collector, quantizer_module

            self.hook_handles.append(quantizer_module.register_forward_hook(self._get_fwd_hook(collector)))

    def _apply_initializers(self):
        for handle in self.hook_handles:
            handle.remove()
        for scope_str, collector_and_module in self.collectors_and_modules_to_init.items():
            collector, quantizer_module = collector_and_module
            target_stat = collector.get_statistics()
            minmax_stats = pt_convert_stat_to_min_max_tensor_stat(target_stat)
            quantizer_module.apply_minmax_init(
                minmax_stats.min_values.data, minmax_stats.max_values.data, log_module_name=scope_str
            )
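
For orientation, a minimal sketch of how the collector factory above could be driven; the shapes, sample count, and per-tensor scale shape below are illustrative assumptions, not values taken from this diff:

    from nncf.common.quantization.initialization.range import RangeInitConfig
    from nncf.common.quantization.structs import QuantizationScheme

    # Illustrative activation collector: per-tensor, symmetric, 256 init samples.
    params = PTRangeInitCollectorParams(
        is_weights=False,
        scheme=QuantizationScheme.SYMMETRIC,
        per_channel=False,
        input_shape=(1, 16, 8, 8),  # assumed activation shape
        channel_idx=1,
    )
    config = RangeInitConfig("min_max", 256)
    collector = StatCollectorGenerator.generate_stat_collector_for_range_init_config(
        config, scale_shape=(1,), collector_params=params
    )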
