Commit 64c2f34: Integrate AutoRound v0.4 [3x] (#2072)

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent a9958d0, commit 64c2f34

File tree: 10 files changed, +321 -73 lines

neural_compressor/torch/algorithms/weight_only/autoround.py (+193 -37)

```diff
@@ -33,8 +33,9 @@ def _is_auto_round_available():
 
 _is_auto_round_available()
 
-from auto_round import AutoRound  # pylint: disable=E0401
+from auto_round import AutoRound, AutoRoundMLLM  # pylint: disable=E0401
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
+from auto_round.mllm.template import Template, get_template
 
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_accelerator, logger
@@ -70,13 +71,24 @@ def __init__(
         dynamic_max_gap: int = -1,
         data_type: str = "int",
         scale_dtype: str = "fp16",
-        quant_block_list: list = None,
+        to_quant_block_names: list = None,
         act_bits: int = 32,
         act_group_size: int = None,
         act_sym: bool = None,
         act_dynamic: bool = True,
         low_cpu_mem_usage: bool = False,
         export_format: str = "itrex",
+        # v0.4
+        enable_norm_bias_tuning: bool = False,
+        enable_torch_compile: bool = None,
+        # mllm
+        is_mllm: bool = False,
+        quant_nontext_module: Union[str, list] = None,
+        extra_data_dir: str = None,
+        image_processor=None,
+        processor=None,
+        template: Union[str, Template] = None,
+        truncation: bool = False,
         **kwargs,
     ):
         """Init an AutoRoundQuantizer object.
@@ -130,11 +142,23 @@ def __init__(
             data_type (str): The data type to be used (default is "int").
             scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
                 have different choices.
-            quant_block_list (list): A list whose elements are list of block's layer names to be quantized.
+            to_quant_block_names (list): A list whose elements are lists of block layer names to be quantized.
             act_bits (int): Number of bits for activation quantization. Default is 32.
             act_group_size (int): Group size for activation quantization. Default is None.
             act_sym (bool): Whether to use symmetric activation quantization. Default is None.
             act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
+            enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
+            enable_torch_compile (bool): Whether to enable torch.compile to optimize quant blocks/layers;
+                defaults to True when torch>=2.6.
+            quant_nontext_module (Union[str, list]): Whether to quantize the non-text modules.
+            is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
+            extra_data_dir (str): The path for extra data such as images, audio or videos.
+            processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
+                decode the data that groups several modalities (among text, vision and audio). This is handled by
+                objects called processors, which group together two or more processing objects such as tokenizers
+                (for the text modality), image processors (for vision) and feature extractors (for audio).
+            image_processor (Processor): Image processor for special models like llava.
+            template (Template): The template to specify the process for different MLLMs.
+            truncation (bool): Activates truncation to cut input sequences longer than `max_length` to `max_length`.
 
         Returns:
             The quantized model.
@@ -162,13 +186,22 @@ def __init__(
         self.dynamic_max_gap = dynamic_max_gap
         self.data_type = data_type
         self.scale_dtype = scale_dtype
-        self.quant_block_list = quant_block_list
+        self.to_quant_block_names = to_quant_block_names
         self.act_bits = act_bits
         self.act_group_size = act_group_size
         self.act_sym = act_sym
         self.act_dynamic = act_dynamic
         self.low_cpu_mem_usage = low_cpu_mem_usage
         self.export_format = export_format
+        self.enable_norm_bias_tuning = enable_norm_bias_tuning
+        self.enable_torch_compile = enable_torch_compile
+        self.is_mllm = is_mllm
+        self.quant_nontext_module = quant_nontext_module
+        self.extra_data_dir = extra_data_dir
+        self.processor = processor
+        self.image_processor = image_processor
+        self.template = template
+        self.truncation = truncation
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -193,39 +226,83 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         """
         dataloader = CapturedDataloader(model.args_list, model.kwargs_list)
         model = model.orig_model
-        rounder = AutoRound(
-            model=model,
-            tokenizer=None,
-            dataset=dataloader,
-            layer_config=self.quant_config or {},
-            enable_full_range=self.enable_full_range,
-            batch_size=self.batch_size,
-            amp=self.amp,
-            device=self.device,
-            lr_scheduler=self.lr_scheduler,
-            enable_quanted_input=self.enable_quanted_input,
-            enable_minmax_tuning=self.enable_minmax_tuning,
-            lr=self.lr,
-            minmax_lr=self.minmax_lr,
-            low_gpu_mem_usage=self.low_gpu_mem_usage,
-            iters=self.iters,
-            seqlen=self.seqlen,
-            nsamples=self.nsamples,
-            sampler=self.sampler,
-            seed=self.seed,
-            nblocks=self.nblocks,
-            gradient_accumulate_steps=self.gradient_accumulate_steps,
-            not_use_best_mse=self.not_use_best_mse,
-            dynamic_max_gap=self.dynamic_max_gap,
-            data_type=self.data_type,
-            scale_dtype=self.scale_dtype,
-            quant_block_list=self.quant_block_list,
-            act_bits=self.act_bits,
-            act_group_size=self.act_group_size,
-            act_sym=self.act_sym,
-            act_dynamic=self.act_dynamic,
-            low_cpu_mem_usage=self.low_cpu_mem_usage,
-        )
+        if self.is_mllm:
+            rounder = AutoRoundMLLM(
+                model,
+                tokenizer=None,
+                processor=self.processor,
+                image_processor=self.image_processor,
+                layer_config=self.quant_config,
+                batch_size=self.batch_size,
+                amp=self.amp,
+                device=self.device,
+                lr_scheduler=self.lr_scheduler,
+                dataset=dataloader,
+                extra_data_dir=self.extra_data_dir,
+                template=self.template,
+                quant_nontext_module=self.quant_nontext_module,
+                enable_quanted_input=self.enable_quanted_input,
+                enable_minmax_tuning=self.enable_minmax_tuning,
+                lr=self.lr,
+                minmax_lr=self.minmax_lr,
+                low_gpu_mem_usage=self.low_gpu_mem_usage,
+                low_cpu_mem_usage=self.low_cpu_mem_usage,
+                iters=self.iters,
+                seqlen=self.seqlen,
+                nsamples=self.nsamples,
+                sampler=self.sampler,
+                seed=self.seed,
+                nblocks=self.nblocks,
+                gradient_accumulate_steps=self.gradient_accumulate_steps,
+                not_use_best_mse=self.not_use_best_mse,
+                dynamic_max_gap=self.dynamic_max_gap,
+                data_type=self.data_type,
+                scale_dtype=self.scale_dtype,
+                act_bits=self.act_bits,
+                act_group_size=self.act_group_size,
+                act_sym=self.act_sym,
+                act_dynamic=self.act_dynamic,
+                to_quant_block_names=self.to_quant_block_names,
+                enable_norm_bias_tuning=self.enable_norm_bias_tuning,
+                truncation=self.truncation,
+                enable_torch_compile=self.enable_torch_compile,
+            )
+        else:
+            rounder = AutoRound(
+                model=model,
+                tokenizer=None,
+                dataset=dataloader,
+                layer_config=self.quant_config or {},
+                enable_full_range=self.enable_full_range,
+                batch_size=self.batch_size,
+                amp=self.amp,
+                device=self.device,
+                lr_scheduler=self.lr_scheduler,
+                enable_quanted_input=self.enable_quanted_input,
+                enable_minmax_tuning=self.enable_minmax_tuning,
+                lr=self.lr,
+                minmax_lr=self.minmax_lr,
+                low_gpu_mem_usage=self.low_gpu_mem_usage,
+                iters=self.iters,
+                seqlen=self.seqlen,
+                nsamples=self.nsamples,
+                sampler=self.sampler,
+                seed=self.seed,
+                nblocks=self.nblocks,
+                gradient_accumulate_steps=self.gradient_accumulate_steps,
+                not_use_best_mse=self.not_use_best_mse,
+                dynamic_max_gap=self.dynamic_max_gap,
+                data_type=self.data_type,
+                scale_dtype=self.scale_dtype,
+                to_quant_block_names=self.to_quant_block_names,
+                act_bits=self.act_bits,
+                act_group_size=self.act_group_size,
+                act_sym=self.act_sym,
+                act_dynamic=self.act_dynamic,
+                low_cpu_mem_usage=self.low_cpu_mem_usage,
+                enable_norm_bias_tuning=self.enable_norm_bias_tuning,
+                enable_torch_compile=self.enable_torch_compile,
+            )
         model, weight_config = rounder.quantize()
         model.autoround_config = weight_config
         if "itrex" in self.export_format:
```
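For orientation, here is a minimal sketch of driving the reworked convert() path through the 3.x prepare/convert API. The checkpoint name and calibration settings are illustrative assumptions, not part of this commit:

```python
# Minimal usage sketch (checkpoint and sizes are illustrative assumptions).
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model_name = "facebook/opt-125m"  # any small causal LM works for a smoke test
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# enable_norm_bias_tuning and enable_torch_compile are the new v0.4 knobs.
quant_config = AutoRoundConfig(enable_norm_bias_tuning=True, enable_torch_compile=None)

model = prepare(model, quant_config)
dataloader = get_dataloader(tokenizer, seqlen=512, nsamples=128)
for batch in dataloader:
    # Calibration forward passes; prepare() records these inputs so that
    # convert() can replay them through CapturedDataloader.
    if isinstance(batch, dict):
        model(**batch)
    else:
        model(batch)
model = convert(model)  # dispatches to AutoRound, or AutoRoundMLLM when is_mllm=True
```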
The same file also adds a get_mllm_dataloader calibration helper:

```diff
@@ -259,3 +336,82 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
         tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
     )
     return dataloader
+
+
+def get_mllm_dataloader(
+    template,
+    model,
+    tokenizer,
+    processor=None,
+    image_processor=None,
+    dataset="liuhaotian/llava_conv_58k",
+    extra_data_dir=None,
+    seqlen=512,
+    bs=1,
+    split=None,
+    apply_template=None,
+    truncation=False,
+    seed=42,
+    nsamples=512,
+    gradient_accumulate_steps=1,
+    quant_nontext_module=False,
+):
+    """Generate a DataLoader for calibration using specified parameters.
+
+    Args:
+        template (Template): The template to specify the process for different MLLMs.
+        model (Model): The model to be quantized.
+        tokenizer (Tokenizer): The tokenizer to use for tokenization.
+        dataset (str): The name or path of the dataset.
+        extra_data_dir (str): The path for extra data such as images, audio or videos.
+        seqlen (int): The exact sequence length; samples shorter than seqlen are dropped,
+            samples longer than seqlen are truncated.
+        bs (int, optional): The batch size. Defaults to 1.
+        split (str, optional): The data split to use. Defaults to None.
+        apply_template: Whether to apply the chat template in tokenization.
+
+    Returns:
+        DataLoader: The DataLoader for the calibration dataset.
+    """
+    from auto_round.calib_dataset import CALIB_DATASETS
+    from auto_round.mllm.autoround_mllm import _only_text_test
+    from auto_round.mllm.mllm_dataset import get_mllm_dataloader  # pylint: disable=E0401
+
+    if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer)):
+        if quant_nontext_module:
+            logger.warning(
+                "Quantizing non-text modules is not supported for plain text datasets; "
+                "will use liuhaotian/llava_conv_58k with the default config as an alternative."
+            )
+        else:
+            logger.warning(
+                f"{model.config.model_type} is not supported for {dataset}; "
+                "will use liuhaotian/llava_conv_58k with the default config as an alternative."
+            )
+        dataset = "liuhaotian/llava_conv_58k"
+        truncation = False
+        bs = 1
+        gradient_accumulate_steps = 4
+        seqlen = 512
+
+    dataset = dataset.replace(" ", "")
+    template = template if template is not None else model.config.model_type
+    template = get_template(
+        template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor
+    )
+    dataloader, batch_size, gradient_accumulate_steps = get_mllm_dataloader(
+        template=template,
+        model=model,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        dataset=dataset,
+        extra_data_dir=extra_data_dir,
+        seqlen=seqlen,
+        bs=bs,
+        seed=seed,
+        truncation=truncation,
+        nsamples=nsamples,
+        gradient_accumulate_steps=gradient_accumulate_steps,
+        quant_nontext_module=quant_nontext_module,
+    )
+    return dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen
```
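A hedged sketch of calling the new helper for a multi-modal model follows; the checkpoint is an assumption for illustration, and only arguments that appear in this diff are passed:

```python
# Hypothetical MLLM calibration setup (checkpoint name is an assumption).
from transformers import AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration

from neural_compressor.torch.algorithms.weight_only.autoround import get_mllm_dataloader

model_name = "llava-hf/llava-1.5-7b-hf"  # illustrative multi-modal checkpoint
model = LlavaForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# The helper can silently fall back to liuhaotian/llava_conv_58k with default
# settings, so it returns the effective values alongside the dataloader.
dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen = get_mllm_dataloader(
    template=None,  # None falls back to model.config.model_type
    model=model,
    tokenizer=tokenizer,
    processor=processor,
    dataset="liuhaotian/llava_conv_58k",
    seqlen=512,
    nsamples=128,
)
```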

neural_compressor/torch/algorithms/weight_only/utility.py (+1 -1)

```diff
@@ -1172,7 +1172,7 @@ def _hook(module, inputs, outputs):
     return total_values
 
 
-class CapturedDataloader:
+class CapturedDataloader(torch.utils.data.DataLoader):
     def __init__(self, args_list, kwargs_list) -> None:
         self.args_list = args_list
         self.kwargs_list = kwargs_list
```
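Subclassing torch.utils.data.DataLoader is presumably what lets AutoRound v0.4's dataset handling accept the captured inputs as an ordinary dataloader. The hunk does not show the replay logic, so here is a hypothetical analogue of such a wrapper, under the assumption that it yields the recorded args/kwargs back as calibration batches:

```python
import torch


class ReplayDataloader(torch.utils.data.DataLoader):
    """Hypothetical analogue of CapturedDataloader (not the actual implementation)."""

    def __init__(self, args_list, kwargs_list) -> None:
        # DataLoader.__init__ is deliberately skipped: the base class is only
        # inherited so isinstance(dl, torch.utils.data.DataLoader) checks pass.
        self.args_list = args_list
        self.kwargs_list = kwargs_list

    def __iter__(self):
        for args, kwargs in zip(self.args_list, self.kwargs_list):
            if len(args) == 1 and not kwargs:
                yield args[0]  # common case: a single positional tensor
            else:
                yield args, kwargs  # otherwise replay the full call signature
```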

neural_compressor/torch/quantization/algorithm_entry.py (+20 -2)

```diff
@@ -607,9 +607,18 @@ def autoround_quantize_entry(
     not_use_best_mse = quant_config.not_use_best_mse
     dynamic_max_gap = quant_config.dynamic_max_gap
     scale_dtype = quant_config.scale_dtype
-    quant_block_list = quant_config.quant_block_list
+    to_quant_block_names = quant_config.to_quant_block_names
     low_cpu_mem_usage = quant_config.use_layer_wise
     export_format = quant_config.export_format
+    enable_norm_bias_tuning = quant_config.enable_norm_bias_tuning
+    enable_torch_compile = quant_config.enable_torch_compile
+    is_mllm = quant_config.is_mllm
+    quant_nontext_module = quant_config.quant_nontext_module
+    extra_data_dir = quant_config.extra_data_dir
+    processor = quant_config.processor
+    image_processor = quant_config.image_processor
+    template = quant_config.template
+    truncation = quant_config.truncation
 
     kwargs.pop("example_inputs")
 
@@ -635,9 +644,18 @@ def autoround_quantize_entry(
         not_use_best_mse=not_use_best_mse,
         dynamic_max_gap=dynamic_max_gap,
         scale_dtype=scale_dtype,
-        quant_block_list=quant_block_list,
+        to_quant_block_names=to_quant_block_names,
         low_cpu_mem_usage=low_cpu_mem_usage,
         export_format=export_format,
+        enable_norm_bias_tuning=enable_norm_bias_tuning,
+        enable_torch_compile=enable_torch_compile,
+        is_mllm=is_mllm,
+        quant_nontext_module=quant_nontext_module,
+        extra_data_dir=extra_data_dir,
+        processor=processor,
+        image_processor=image_processor,
+        template=template,
+        truncation=truncation,
     )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
     model.qconfig = configs_mapping
```

neural_compressor/torch/quantization/config.py (+35 -3)

```diff
@@ -943,8 +943,19 @@ def __init__(
         dynamic_max_gap: int = -1,
         scale_dtype: str = "fp16",
         use_layer_wise: bool = False,
-        quant_block_list: list = None,
+        to_quant_block_names: list = None,
         export_format: str = "itrex",
+        # v0.4
+        enable_norm_bias_tuning: bool = False,
+        enable_torch_compile: bool = None,
+        # mllm
+        is_mllm: bool = False,
+        quant_nontext_module: Union[str, list] = None,
+        extra_data_dir: str = None,
+        processor=None,
+        image_processor=None,
+        template=None,
+        truncation: bool = False,
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
         **kwargs,
     ):
@@ -979,8 +990,20 @@ def __init__(
             scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
                 have different choices.
             use_layer_wise (bool): Enables quantizing the model per layer. Defaults to False.
-            quant_block_list (list): A list whose elements are list of block's layer names to be quantized.
+            to_quant_block_names (list): A list whose elements are lists of block layer names to be quantized.
             export_format (str, optional): The format used for exporting the quantized model. Defaults to "itrex".
+            enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
+            enable_torch_compile (bool): Whether to enable torch.compile to optimize quant blocks/layers;
+                defaults to True when torch>=2.6.
+            quant_nontext_module (Union[str, list]): Whether to quantize the non-text modules.
+            extra_data_dir (str): The path for extra data such as images, audio or videos.
+            is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
+            processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
+                decode the data that groups several modalities (among text, vision and audio). This is handled by
+                objects called processors, which group together two or more processing objects such as tokenizers
+                (for the text modality), image processors (for vision) and feature extractors (for audio).
+            image_processor (Processor): Image processor for special models like llava.
+            template (Template): The template to specify the process for different MLLMs.
+            truncation (bool): Activates truncation to cut input sequences longer than `max_length` to `max_length`.
             white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
                 Default is DEFAULT_WHITE_LIST.
         """
@@ -1012,8 +1035,17 @@ def __init__(
         self.dynamic_max_gap = dynamic_max_gap
         self.scale_dtype = scale_dtype
         self.use_layer_wise = use_layer_wise
-        self.quant_block_list = quant_block_list
+        self.to_quant_block_names = to_quant_block_names
         self.export_format = export_format
+        self.enable_norm_bias_tuning = enable_norm_bias_tuning
+        self.enable_torch_compile = enable_torch_compile
+        self.is_mllm = is_mllm
+        self.quant_nontext_module = quant_nontext_module
+        self.extra_data_dir = extra_data_dir
+        self.processor = processor
+        self.image_processor = image_processor
+        self.template = template
+        self.truncation = truncation
         self._post_init()
 
     @classmethod
```
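Putting the config-side changes together, a sketch of constructing AutoRoundConfig with the renamed and newly added fields; every keyword shown appears in this diff, while the values are illustrative:

```python
from neural_compressor.torch.quantization import AutoRoundConfig

quant_config = AutoRoundConfig(
    use_layer_wise=False,
    export_format="itrex",
    to_quant_block_names=None,  # renamed from quant_block_list
    # v0.4
    enable_norm_bias_tuning=False,
    enable_torch_compile=None,  # per the docstring, defaults to True on torch>=2.6
    # mllm
    is_mllm=True,
    quant_nontext_module=None,
    extra_data_dir=None,
    processor=None,  # pass a transformers.AutoProcessor for a real MLLM
    image_processor=None,
    template=None,
    truncation=False,
)
```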
