@@ -33,8 +33,9 @@ def _is_auto_round_available():
33
33
34
34
_is_auto_round_available ()
35
35
36
- from auto_round import AutoRound # pylint: disable=E0401
36
+ from auto_round import AutoRound , AutoRoundMLLM # pylint: disable=E0401
37
37
from auto_round .export .export_to_itrex .export import pack_model # pylint: disable=E0401
38
+ from auto_round .mllm .template import Template , get_template
38
39
39
40
from neural_compressor .torch .algorithms import Quantizer
40
41
from neural_compressor .torch .utils import get_accelerator , logger
@@ -70,13 +71,24 @@ def __init__(
70
71
dynamic_max_gap : int = - 1 ,
71
72
data_type : str = "int" ,
72
73
scale_dtype : str = "fp16" ,
73
- quant_block_list : list = None ,
74
+ to_quant_block_names : list = None ,
74
75
act_bits : int = 32 ,
75
76
act_group_size : int = None ,
76
77
act_sym : bool = None ,
77
78
act_dynamic : bool = True ,
78
79
low_cpu_mem_usage : bool = False ,
79
80
export_format : str = "itrex" ,
81
+ # v0.4
82
+ enable_norm_bias_tuning : bool = False ,
83
+ enable_torch_compile : bool = None ,
84
+ # mllm
85
+ is_mllm : bool = False ,
86
+ quant_nontext_module : Union [str , list ] = None ,
87
+ extra_data_dir : str = None ,
88
+ image_processor = None ,
89
+ processor = None ,
90
+ template : Union [str , Template ] = None ,
91
+ truncation : bool = False ,
80
92
** kwargs ,
81
93
):
82
94
"""Init a AutQRoundQuantizer object.
@@ -130,11 +142,23 @@ def __init__(
130
142
data_type (str): The data type to be used (default is "int").
131
143
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
132
144
have different choices.
133
- quant_block_list (list): A list whose elements are list of block's layer names to be quantized.
145
+ to_quant_block_names (list): A list whose elements are list of block's layer names to be quantized.
134
146
act_bits (int): Number of bits for activation quantization. Default is 32.
135
147
act_group_size (int): Group size for activation quantization. Default is None.
136
148
act_sym (bool): Whether to use symmetric activation quantization. Default is None.
137
149
act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
150
+ enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
151
+ enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
152
+ quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
153
+ is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
154
+ extra_data_dir (str): The path for extra data such as images, audio or videos.
155
+ processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
156
+ decode the data that groups several modalities (among text, vision and audio).
157
+ This is handled by objects called processors, which group together two or more processing objects such
158
+ as tokenizers (for the text modality), image processors (for vision) and feature extractors (for audio).
159
+ image_processor (Processor): Image processor for special model like llava.
160
+ template (Template): The template to specify process for different mllms.
161
+ truncation (bool): Activates truncation to cut input sequences longer than `max_length` to `max_length`.
138
162
139
163
Returns:
140
164
The quantized model.
@@ -162,13 +186,22 @@ def __init__(
162
186
self .dynamic_max_gap = dynamic_max_gap
163
187
self .data_type = data_type
164
188
self .scale_dtype = scale_dtype
165
- self .quant_block_list = quant_block_list
189
+ self .to_quant_block_names = to_quant_block_names
166
190
self .act_bits = act_bits
167
191
self .act_group_size = act_group_size
168
192
self .act_sym = act_sym
169
193
self .act_dynamic = act_dynamic
170
194
self .low_cpu_mem_usage = low_cpu_mem_usage
171
195
self .export_format = export_format
196
+ self .enable_norm_bias_tuning = enable_norm_bias_tuning
197
+ self .enable_torch_compile = enable_torch_compile
198
+ self .is_mllm = is_mllm
199
+ self .quant_nontext_module = quant_nontext_module
200
+ self .extra_data_dir = extra_data_dir
201
+ self .processor = processor
202
+ self .image_processor = image_processor
203
+ self .template = template
204
+ self .truncation = truncation
172
205
173
206
def prepare (self , model : torch .nn .Module , * args , ** kwargs ):
174
207
"""Prepares a given model for quantization.
@@ -193,39 +226,83 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
193
226
"""
194
227
dataloader = CapturedDataloader (model .args_list , model .kwargs_list )
195
228
model = model .orig_model
196
- rounder = AutoRound (
197
- model = model ,
198
- tokenizer = None ,
199
- dataset = dataloader ,
200
- layer_config = self .quant_config or {},
201
- enable_full_range = self .enable_full_range ,
202
- batch_size = self .batch_size ,
203
- amp = self .amp ,
204
- device = self .device ,
205
- lr_scheduler = self .lr_scheduler ,
206
- enable_quanted_input = self .enable_quanted_input ,
207
- enable_minmax_tuning = self .enable_minmax_tuning ,
208
- lr = self .lr ,
209
- minmax_lr = self .minmax_lr ,
210
- low_gpu_mem_usage = self .low_gpu_mem_usage ,
211
- iters = self .iters ,
212
- seqlen = self .seqlen ,
213
- nsamples = self .nsamples ,
214
- sampler = self .sampler ,
215
- seed = self .seed ,
216
- nblocks = self .nblocks ,
217
- gradient_accumulate_steps = self .gradient_accumulate_steps ,
218
- not_use_best_mse = self .not_use_best_mse ,
219
- dynamic_max_gap = self .dynamic_max_gap ,
220
- data_type = self .data_type ,
221
- scale_dtype = self .scale_dtype ,
222
- quant_block_list = self .quant_block_list ,
223
- act_bits = self .act_bits ,
224
- act_group_size = self .act_group_size ,
225
- act_sym = self .act_sym ,
226
- act_dynamic = self .act_dynamic ,
227
- low_cpu_mem_usage = self .low_cpu_mem_usage ,
228
- )
229
+ if self .is_mllm :
230
+ rounder = AutoRoundMLLM (
231
+ model ,
232
+ tokenizer = None ,
233
+ processor = self .processor ,
234
+ image_processor = self .image_processor ,
235
+ layer_config = self .quant_config ,
236
+ batch_size = self .batch_size ,
237
+ amp = self .amp ,
238
+ device = self .device ,
239
+ lr_scheduler = self .lr_scheduler ,
240
+ dataset = dataloader ,
241
+ extra_data_dir = self .extra_data_dir ,
242
+ template = self .template ,
243
+ quant_nontext_module = self .quant_nontext_module ,
244
+ enable_quanted_input = self .enable_quanted_input ,
245
+ enable_minmax_tuning = self .enable_minmax_tuning ,
246
+ lr = self .lr ,
247
+ minmax_lr = self .minmax_lr ,
248
+ low_gpu_mem_usage = self .low_gpu_mem_usage ,
249
+ low_cpu_mem_usage = self .low_gpu_mem_usage ,
250
+ iters = self .iters ,
251
+ seqlen = self .seqlen ,
252
+ nsamples = self .nsamples ,
253
+ sampler = self .sampler ,
254
+ seed = self .seed ,
255
+ nblocks = self .nblocks ,
256
+ gradient_accumulate_steps = self .gradient_accumulate_steps ,
257
+ not_use_best_mse = self .not_use_best_mse ,
258
+ dynamic_max_gap = self .dynamic_max_gap ,
259
+ data_type = self .data_type ,
260
+ scale_dtype = self .scale_dtype ,
261
+ act_bits = self .act_bits ,
262
+ act_group_size = self .act_group_size ,
263
+ act_sym = self .act_sym ,
264
+ act_dynamic = self .act_dynamic ,
265
+ to_quant_block_names = self .to_quant_block_names ,
266
+ enable_norm_bias_tuning = self .enable_norm_bias_tuning ,
267
+ truncation = self .truncation ,
268
+ enable_torch_compile = self .enable_torch_compile ,
269
+ )
270
+ else :
271
+ rounder = AutoRound (
272
+ model = model ,
273
+ tokenizer = None ,
274
+ dataset = dataloader ,
275
+ layer_config = self .quant_config or {},
276
+ enable_full_range = self .enable_full_range ,
277
+ batch_size = self .batch_size ,
278
+ amp = self .amp ,
279
+ device = self .device ,
280
+ lr_scheduler = self .lr_scheduler ,
281
+ enable_quanted_input = self .enable_quanted_input ,
282
+ enable_minmax_tuning = self .enable_minmax_tuning ,
283
+ lr = self .lr ,
284
+ minmax_lr = self .minmax_lr ,
285
+ low_gpu_mem_usage = self .low_gpu_mem_usage ,
286
+ iters = self .iters ,
287
+ seqlen = self .seqlen ,
288
+ nsamples = self .nsamples ,
289
+ sampler = self .sampler ,
290
+ seed = self .seed ,
291
+ nblocks = self .nblocks ,
292
+ gradient_accumulate_steps = self .gradient_accumulate_steps ,
293
+ not_use_best_mse = self .not_use_best_mse ,
294
+ dynamic_max_gap = self .dynamic_max_gap ,
295
+ data_type = self .data_type ,
296
+ scale_dtype = self .scale_dtype ,
297
+ to_quant_block_names = self .to_quant_block_names ,
298
+ act_bits = self .act_bits ,
299
+ act_group_size = self .act_group_size ,
300
+ act_sym = self .act_sym ,
301
+ act_dynamic = self .act_dynamic ,
302
+ low_cpu_mem_usage = self .low_cpu_mem_usage ,
303
+ enable_norm_bias_tuning = self .enable_norm_bias_tuning ,
304
+ enable_torch_compile = self .enable_torch_compile ,
305
+ )
229
306
model , weight_config = rounder .quantize ()
230
307
model .autoround_config = weight_config
231
308
if "itrex" in self .export_format :
@@ -259,3 +336,82 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
259
336
tokenizer , seqlen , dataset_name = "NeelNanda/pile-10k" , seed = seed , bs = bs , nsamples = nsamples
260
337
)
261
338
return dataloader
339
+
340
+
341
def get_mllm_dataloader(
    template,
    model,
    tokenizer,
    processor=None,
    image_processor=None,
    dataset="liuhaotian/llava_conv_58k",
    extra_data_dir=None,
    seqlen=512,
    bs=1,
    split=None,
    apply_template=None,
    truncation=False,
    seed=42,
    nsamples=512,
    gradient_accumulate_steps=1,
    quant_nontext_module=False,
):
    """Generate a DataLoader for multi-modal (MLLM) calibration using specified parameters.

    When ``quant_nontext_module`` is truthy, or ``dataset`` is one of auto_round's
    ``CALIB_DATASETS`` and ``_only_text_test(model, tokenizer)`` is falsy, the request
    falls back to ``liuhaotian/llava_conv_58k`` with a fixed default config
    (``truncation=False``, ``bs=1``, ``gradient_accumulate_steps=4``, ``seqlen=512``)
    and a warning is logged.

    Args:
        template (Template or str): The template to specify process for different mllms.
            If None, ``model.config.model_type`` is used to look the template up.
        model (Model): The model to be quantized.
        tokenizer (Tokenizer): The tokenizer to use for tokenization.
        processor: Multi-modal processor, passed to ``get_template``.
        image_processor: Image processor, passed to ``get_template`` and to the
            auto_round dataloader builder.
        dataset (str): The name or path of the dataset. Spaces are stripped from it.
        extra_data_dir (str): The path for extra data such as images, audio or videos.
        seqlen (int): The sequence length forwarded to the auto_round dataloader builder.
        bs (int, optional): The batch size. Defaults to 1.
        split (str, optional): The data split to use. Defaults to None.
        apply_template: Whether to apply chat template in tokenization.
        truncation (bool): Whether to truncate input sequences. Defaults to False.
        seed (int): Random seed forwarded to the auto_round dataloader builder.
        nsamples (int): Number of calibration samples. Defaults to 512.
        gradient_accumulate_steps (int): Gradient accumulation steps. Defaults to 1.
        quant_nontext_module (bool): Whether the non-text module is to be quantized.

    Returns:
        tuple: ``(dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen)``
        where ``dataloader``, ``batch_size`` and ``gradient_accumulate_steps`` are the
        values returned by auto_round's dataloader builder, ``template`` is the resolved
        template object, and ``truncation`` / ``seqlen`` are the effective settings
        (possibly overridden by the fallback described above).
    """
    from auto_round.calib_dataset import CALIB_DATASETS
    from auto_round.mllm.autoround_mllm import _only_text_test

    # Aliased so the import does not shadow this wrapper's own name inside its body.
    from auto_round.mllm.mllm_dataset import (  # pylint: disable=E0401
        get_mllm_dataloader as _get_mllm_dataloader,
    )

    if quant_nontext_module or (dataset in CALIB_DATASETS and not _only_text_test(model, tokenizer)):
        if quant_nontext_module:
            logger.warning(
                "Quantitative nontext module is not supported for plain text datasets, "
                "will use liuhaotian/llava_conv_58k with default config as an alternative."
            )
        else:
            logger.warning(
                f"{model.config.model_type} not support for {dataset},"
                " will use liuhaotian/llava_conv_58k with default config as an alternative."
            )
        dataset = "liuhaotian/llava_conv_58k"
        truncation = False
        # The batch size forwarded below is `bs`; assigning a separate `batch_size`
        # local here (as before) was overwritten by the builder's return value and
        # therefore never took effect.
        bs = 1
        gradient_accumulate_steps = 4
        seqlen = 512

    dataset = dataset.replace(" ", "")
    if template is None:
        template = model.config.model_type
    template = get_template(
        template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor
    )
    # NOTE(review): `processor`, `split` and `apply_template` are accepted by this wrapper
    # but are not forwarded to the auto_round builder — confirm against the installed
    # auto_round signature whether they should be.
    dataloader, batch_size, gradient_accumulate_steps = _get_mllm_dataloader(
        template=template,
        model=model,
        tokenizer=tokenizer,
        image_processor=image_processor,
        dataset=dataset,
        extra_data_dir=extra_data_dir,
        seqlen=seqlen,
        bs=bs,
        seed=seed,
        truncation=truncation,
        nsamples=nsamples,
        gradient_accumulate_steps=gradient_accumulate_steps,
        quant_nontext_module=quant_nontext_module,
    )
    return dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen
0 commit comments