 from typing import Callable, Dict, List, Tuple, Union
 
 import torch
-import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
-from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 from typing_extensions import TypeAlias
 
 from neural_compressor.common import logger
@@ -120,11 +116,9 @@ def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) ->
     return filter_result
 
 
-def get_double_quant_config(double_quant_type):
+def get_double_quant_config_dict(double_quant_type="BNB_NF4"):
     from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
 
-    if double_quant_type is None:
-        return {}
     assert double_quant_type in DOUBLE_QUANT_CONFIGS, "Supported double quant configs: {}".format(
         list(DOUBLE_QUANT_CONFIGS.keys())
     )
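With this change, get_double_quant_config_dict always resolves a named scheme from DOUBLE_QUANT_CONFIGS (defaulting to "BNB_NF4") and asserts on unknown names instead of silently returning an empty dict for None. A minimal caller sketch, assuming the helper lives in neural_compressor.torch.utils.utility and returns the matching config dict (the return statement falls outside the visible hunk):

# Hypothetical usage; the module path and return value are assumptions, not shown in this diff.
from neural_compressor.torch.utils.utility import get_double_quant_config_dict

nf4_args = get_double_quant_config_dict("BNB_NF4")      # any key listed in DOUBLE_QUANT_CONFIGS is accepted
bad_args = get_double_quant_config_dict("NO_SUCH_TYPE")  # now raises AssertionError rather than yielding {}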
@@ -170,61 +164,3 @@ def postprocess_model(model, mode, quantizer):
     elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
         if getattr(model, "quantizer", False):
             del model.quantizer
-
-
-def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
-    dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
-    select_dtype = dtype_mapping[dtype]
-    min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
-    qscheme_mapping = {
-        "per_channel": {True: torch.per_channel_symmetric, False: torch.per_tensor_affine},
-        "per_tensor": {True: torch.per_tensor_symmetric, False: torch.per_tensor_affine},
-    }
-    observer_mapping = {
-        "placeholder": PlaceholderObserver,
-        "minmax": MinMaxObserver,
-        "kl": HistogramObserver,
-    }
-    # Force to use placeholder observer for dynamic quantization
-    if is_dynamic:
-        algo = "placeholder"
-    # algo
-    observer_or_fake_quant_ctr = observer_mapping[algo]
-    # qscheme
-    qscheme = qscheme_mapping[granularity][sym]
-    quantization_spec = QuantizationSpec(
-        dtype=select_dtype,
-        quant_min=min_max_mapping[select_dtype][0],
-        quant_max=min_max_mapping[select_dtype][1],
-        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
-        qscheme=qscheme,
-        is_dynamic=is_dynamic,
-    )
-    return quantization_spec
-
-
-def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
-    default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
-    input_act_quant_spec = create_quant_spec_from_config(
-        inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
-    )
-    weight_quant_spec = create_quant_spec_from_config(
-        inc_config.w_dtype, inc_config.w_sym, inc_config.w_granularity, inc_config.w_algo
-    )
-    quant_config = QuantizationConfig(
-        input_activation=input_act_quant_spec,
-        output_activation=default_quant_config.output_activation,
-        weight=weight_quant_spec,
-        bias=default_quant_config.bias,
-        is_qat=False,
-    )
-    return quant_config
-
-
-def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
-    quantizer = xiq.X86InductorQuantizer()
-    # set global
-    global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
-    quantizer.set_global(global_config)
-    # Skip the local config for now (need torch 2.4)
-    return quantizer
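The deleted helpers translated an INC config into a torch.ao X86InductorQuantizer for the PT2E path. For orientation only, a minimal sketch of the default quantizer setup using just the torch.ao calls that appear in the removed code; this is not INC's replacement path, which sits outside this diff:

import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

# Build the quantizer and apply the stock x86 inductor config globally,
# roughly what create_xiq_quantizer_from_pt2e_config did with its mapped config.
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_dynamic=False))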