
Commit 6bb52cf

Kaihui-intel, pre-commit-ci[bot], and XuehaoSun authored
Add VLM quantization & loading into transformers-like API (#2116)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sun, Xuehao <xuehao.sun@intel.com>
1 parent 9c3d4a1 commit 6bb52cf

File tree

11 files changed: +226, −41 lines changed
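Taken together, the changes below extend the existing transformers-like API to vision-language models (VLMs): AutoRound calibration gains a multi-modal dataloader path, and dedicated wrapper classes let quantized VLM checkpoints be saved and reloaded. A minimal sketch of the intended flow — checkpoint name and save path are illustrative, not part of this commit:

    from neural_compressor.transformers import AutoRoundConfig, Qwen2VLForConditionalGeneration

    # 4-bit weight-only AutoRound quantization; is_vlm and quant_nontext_module
    # are the new VLM arguments introduced in this commit.
    quantization_config = AutoRoundConfig(
        bits=4,
        group_size=128,
        is_vlm=True,                 # route calibration through the multi-modal dataloader
        quant_nontext_module=False,  # keep vision blocks in full precision
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",  # illustrative checkpoint
        quantization_config=quantization_config,
    )
    model.save_pretrained("./qwen2-vl-4bit")

    # Reload the saved low-bit checkpoint through the same wrapper class.
    model = Qwen2VLForConditionalGeneration.load_low_bit("./qwen2-vl-4bit")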

.azure-pipelines/scripts/ut/3x/run_3x_pt.sh (+1)

@@ -13,6 +13,7 @@ echo "##[section]import check pass"
 echo "##[group]set up UT env..."
 export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
 pip install -r /neural-compressor/test/3x/torch/requirements.txt
+pip install torch==2.5.1 torchvision==0.20.1 # For auto-round
 pip install pytest-cov
 pip install pytest-html
 echo "##[endgroup]"

neural_compressor/torch/algorithms/weight_only/autoround.py (+6, −4)

@@ -84,7 +84,7 @@ def __init__(
         enable_torch_compile: bool = None,
         # mllm
         is_mllm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         image_processor=None,
         processor=None,
@@ -150,7 +150,7 @@ def __init__(
             act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
-            quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
+            quant_nontext_module (bool): Whether to quantize nontext module.
             is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
@@ -383,7 +383,9 @@ def get_mllm_dataloader(
         template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor
     )
     dataset = template.default_dataset if dataset is None else dataset
-    if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer)):
+    if quant_nontext_module or (
+        dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer, "cpu", template.model_type)
+    ):
         if quant_nontext_module:
             logger.warning(
                 "Quantitative nontext module is not supported for plain text datasets,"
@@ -399,7 +401,7 @@ def get_mllm_dataloader(
             truncation = False
             gradient_accumulate_steps = batch_size * gradient_accumulate_steps
             batch_size = 1
-
+    seed = 42  # The seed is fixed to 42 in transformers
     seqlen = 2048 if seqlen is None else seqlen  # set text only calibration default args
     truncation = True if truncation is None else truncation
     dataset = dataset.replace(" ", "")

neural_compressor/torch/quantization/config.py (+2, −2)

@@ -950,7 +950,7 @@ def __init__(
         enable_torch_compile: bool = None,
         # mllm
         is_mllm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
        extra_data_dir: str = None,
        processor=None,
        image_processor=None,
@@ -994,7 +994,7 @@ def __init__(
             export_format (str, optional): The format used for exporting the quantized model. Defaults to "itrex".
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
-            quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
+            quant_nontext_module (bool): Whether to quantize nontext module.
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
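Since quant_nontext_module is now a plain boolean in the torch-level config as well, enabling non-text quantization is a single flag rather than a module list. A hedged sketch of the call, mirroring the signature above:

    from neural_compressor.torch.quantization import AutoRoundConfig

    # Assumed usage: one switch instead of enumerating non-text module names.
    quant_config = AutoRoundConfig(
        is_mllm=True,
        quant_nontext_module=True,  # also quantize vision/non-text blocks
    )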

neural_compressor/transformers/__init__.py (+1)

@@ -23,4 +23,5 @@
     AutoModelForCausalLM,
     AutoModel,
     AutoModelForSeq2SeqLM,
+    Qwen2VLForConditionalGeneration,
 )

neural_compressor/transformers/models/__init__.py (+8, −1)

@@ -13,4 +13,11 @@
 # limitations under the License.

 from .modeling_auto import _BaseINCAutoModelClass
-from .modeling_auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from .modeling_auto import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    Qwen2VLForConditionalGeneration,
+    MllamaForConditionalGeneration,
+    LlavaForConditionalGeneration,
+)

neural_compressor/transformers/models/modeling_auto.py (+33, −18)

@@ -354,24 +354,27 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         else:
             commit_hash = getattr(config, "_commit_hash", None)

-        has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
-
-        has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
-        trust_remote_code = resolve_trust_remote_code(
-            trust_remote_code,
-            pretrained_model_name_or_path,
-            has_local_code,
-            has_remote_code,
-        )
-        if has_remote_code and trust_remote_code:
-            class_ref = config.auto_map[cls.ORIG_MODEL.__name__]
-            model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig)
-            if os.path.isdir(pretrained_model_name_or_path):
-                model_class.register_for_auto_class(cls.ORIG_MODEL.__name__)
-            else:
-                cls.ORIG_MODEL.register(config.__class__, model_class, exist_ok=True)
-        elif type(config) in cls.ORIG_MODEL._model_mapping.keys():
-            model_class = _get_model_class(config, cls.ORIG_MODEL._model_mapping)
+        if "AutoModel" in cls.ORIG_MODEL.__name__:
+            has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
+            has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
+
+            trust_remote_code = resolve_trust_remote_code(
+                trust_remote_code,
+                pretrained_model_name_or_path,
+                has_local_code,
+                has_remote_code,
+            )
+            if has_remote_code and trust_remote_code:
+                class_ref = config.auto_map[cls.ORIG_MODEL.__name__]
+                model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs_orig)
+                if os.path.isdir(pretrained_model_name_or_path):
+                    model_class.register_for_auto_class(cls.ORIG_MODEL.__name__)
+                else:
+                    cls.ORIG_MODEL.register(config.__class__, model_class, exist_ok=True)
+            elif type(config) in cls.ORIG_MODEL._model_mapping.keys():
+                model_class = _get_model_class(config, cls.ORIG_MODEL._model_mapping)
+        else:
+            model_class = cls.ORIG_MODEL

         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
@@ -747,3 +750,15 @@ class AutoModel(_BaseINCAutoModelClass):

 class AutoModelForSeq2SeqLM(_BaseINCAutoModelClass):
     ORIG_MODEL = transformers.AutoModelForSeq2SeqLM
+
+
+class Qwen2VLForConditionalGeneration(_BaseINCAutoModelClass):
+    ORIG_MODEL = transformers.Qwen2VLForConditionalGeneration
+
+
+class MllamaForConditionalGeneration(_BaseINCAutoModelClass):
+    ORIG_MODEL = transformers.MllamaForConditionalGeneration
+
+
+class LlavaForConditionalGeneration(_BaseINCAutoModelClass):
+    ORIG_MODEL = transformers.LlavaForConditionalGeneration
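Each new wrapper only pins ORIG_MODEL; quantized loading and low-bit reload are inherited from _BaseINCAutoModelClass. Because their names do not contain "AutoModel", they take the new else branch above and use the concrete transformers class directly, skipping auto_map/_model_mapping resolution. A short sketch (save path illustrative):

    from neural_compressor.transformers.models import LlavaForConditionalGeneration

    # No dynamic class resolution here: model_class is simply
    # transformers.LlavaForConditionalGeneration.
    model = LlavaForConditionalGeneration.load_low_bit("./llava-4bit")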

neural_compressor/transformers/quantization/utils.py (+103, −12)

@@ -17,6 +17,7 @@
 import json
 import math
 import os
+import re
 import types

 from datasets import load_dataset
@@ -33,11 +34,16 @@
     convert,
     prepare,
 )
-from neural_compressor.torch.utils import is_ipex_available
+from neural_compressor.torch.utils import is_ipex_available, is_package_available

 if is_ipex_available():
     import intel_extension_for_pytorch as ipex

+if is_package_available("auto_round"):
+    import auto_round
+    import transformers
+    from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear as auto_round_woq_linear
+
 from typing import Union

 torch = LazyImport("torch")
@@ -126,10 +132,12 @@ def _replace_linear(
         if (
             isinstance(module, torch.nn.Linear)
             or isinstance(module, INCWeightOnlyLinear)
-            or (is_ipex_available() and isinstance(module, ipex.nn.utils._weight_prepack._IPEXLinear))
+            or (is_package_available("auto_round") and isinstance(module, auto_round_woq_linear))
         ) and (name not in modules_to_not_convert):
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and not any(
+                re.match(pattern, ".".join(current_key_name)) for pattern in modules_to_not_convert
+            ):
                 in_features = module.in_features
                 out_features = module.out_features
                 if device == "cpu" or device == torch.device("cpu") or device == "auto":
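With the added re.match branch, entries in modules_to_not_convert can act as regex patterns as well as plain substrings. A small standalone illustration of the skip logic (pattern and key are made up):

    import re

    modules_to_not_convert = [r"model\.mm_projector.*"]
    key = ".".join(["model", "mm_projector", "linear_1"])  # "model.mm_projector.linear_1"

    # A module is left unconverted if either check fires.
    skip = any(s in key for s in modules_to_not_convert) or any(
        re.match(p, key) for p in modules_to_not_convert
    )
    print(skip)  # True: the regex matches from the start of the key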
@@ -475,6 +483,54 @@ def convert_to_quantized_model(model, config, device="cpu"):
         run_fn(model, *run_args)
         model = convert(model)
     elif config.quant_method.value == "autoround":
+        if config.is_vlm is True:
+            from transformers import AutoProcessor, AutoTokenizer
+
+            from neural_compressor.torch.algorithms.weight_only.autoround import (
+                get_mllm_dataloader as get_autoround_dataloader,
+            )
+
+            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+            processor = AutoProcessor.from_pretrained(model.config._name_or_path, trust_remote_code=True)
+            (
+                dataloader,
+                template,
+                config.truncation,
+                config.batch_size,
+                config.gradient_accumulate_steps,
+                config.seq_len,
+                config.n_samples,
+            ) = get_autoround_dataloader(
+                template=None,
+                model=model,
+                tokenizer=tokenizer,
+                image_processor=None,
+                dataset=config.dataset,
+                extra_data_dir=None,
+                seqlen=config.seq_len,
+                batch_size=config.batch_size,
+                split=None,
+                apply_template=None,
+                truncation=False,
+                nsamples=config.n_samples,
+                seed=42,
+                gradient_accumulate_steps=config.gradient_accumulate_steps,
+                quant_nontext_module=config.quant_nontext_module,
+                processor=processor,
+            )
+        else:
+            from neural_compressor.torch.algorithms.weight_only.autoround import (
+                get_dataloader as get_autoround_dataloader,
+            )
+
+            dataloader = get_autoround_dataloader(
+                tokenizer=config.tokenizer,
+                seqlen=config.seq_len,
+                dataset_name=config.dataset,
+                seed=42,
+                bs=config.batch_size,
+                nsamples=config.n_samples,
+            )
         quant_config = AutoRoundConfig(
             dtype=dtype,
             bits=config.bits,
@@ -486,24 +542,59 @@ def convert_to_quantized_model(model, config, device="cpu"):
             seqlen=config.seq_len,
             nsamples=config.n_samples,
             iters=config.iters,
+            batch_size=config.batch_size,
             scale_dtype=config.scale_dtype,
             use_layer_wise=config.use_layer_wise,
+            # vlm arguments
+            is_mllm=config.is_vlm,
+            quant_nontext_module=config.quant_nontext_module,
+            truncation=config.truncation,
+            gradient_accumulate_steps=config.gradient_accumulate_steps,
+            export_format=config.export_format,
         )
+
+        # vlm set non-text module config
+        if config.is_vlm is True:
+            from neural_compressor.torch.utils.utility import (
+                find_matching_blocks,
+                get_layer_names_in_block,
+                get_multimodal_block_names,
+            )
+
+            def set_nontext_module_config(model, to_quant_block_names, config):
+                all_block_list = get_multimodal_block_names(model, quant_vision=True)
+                all_block_set = set(tuple(block) for block in all_block_list)
+                quant_block_set = set(tuple(block) for block in to_quant_block_names)
+                set_to_full_prec = list(all_block_set - quant_block_set)
+                set_to_full_prec = get_layer_names_in_block(model, to_quant_block_names=set_to_full_prec)
+                for name in set_to_full_prec:
+                    config.modules_to_not_convert.append(name)
+
+                # skip layers not in blocks
+                config.modules_to_not_convert.append("model.vision_embed_tokens.img_projection*")
+                config.modules_to_not_convert.append("transformer.visual.attn_pool.*_proj")
+                config.modules_to_not_convert.append("model.mm_projector*")
+                config.modules_to_not_convert.append("multi_modal_projector")
+                config.modules_to_not_convert.append("visual.merger")
+
+            all_blocks = get_multimodal_block_names(model, quant_config.quant_nontext_module)
+            to_quant_block_names = find_matching_blocks(model, all_blocks, quant_config.to_quant_block_names)
+            set_nontext_module_config(model, to_quant_block_names, config)
+
+            for n, m in model.named_modules():
+                if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
+                    if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
+                        config.modules_to_not_convert.append(n)
+                        print(
+                            f"{n} will not be quantized due to its shape not being divisible by 32,"
+                            " resulting in an exporting issue to autogptq"
+                        )
         if config.modules_to_not_convert != []:
             for module in config.modules_to_not_convert:
                 module_name = ".*" + module
                 quant_config.set_local(module_name, AutoRoundConfig(dtype="fp32"))
         logger.info(f"Do AutoRound algorithm with config {quant_config}")
-        from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader as get_autoround_dataloader

-        dataloader = get_autoround_dataloader(
-            tokenizer=config.tokenizer,
-            seqlen=config.seq_len,
-            dataset_name=config.dataset,
-            seed=42,
-            bs=config.batch_size,
-            nsamples=config.n_samples,
-        )
         run_fn = run_fn_for_autoround
         run_args = (dataloader,)
         model = prepare(model=model, quant_config=quant_config)
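The final named_modules loop guards the auto-gptq export path: any Linear (or Conv1D) whose weight dimensions are not both multiples of 32 is appended to modules_to_not_convert and kept in full precision. A toy check of the rule:

    import torch

    w = torch.nn.Linear(in_features=1000, out_features=4096).weight  # shape (4096, 1000)
    # 4096 % 32 == 0, but 1000 % 32 == 8, so this layer would be skipped.
    print(w.shape[0] % 32 != 0 or w.shape[1] % 32 != 0)  # True -> not quantized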

neural_compressor/transformers/utils/quantization_config.py (+13)

@@ -543,6 +543,12 @@ def __init__(
         iters: int = 200,
         use_layer_wise: bool = None,
         quant_lm_head: bool = False,
+        # vlm arguments
+        is_vlm: bool = False,
+        quant_nontext_module: bool = False,
+        truncation: bool = False,
+        gradient_accumulate_steps: int = 1,
+        export_format="itrex",
         **kwargs,
     ):
@@ -594,6 +600,13 @@ def __init__(
         self.use_layer_wise = use_layer_wise
         self.model_path = kwargs.get("model_path", "")

+        # vlm arguments
+        self.is_vlm = is_vlm
+        self.quant_nontext_module = quant_nontext_module
+        self.truncation = truncation
+        self.gradient_accumulate_steps = gradient_accumulate_steps
+        self.export_format = export_format
+
     def to_diff_dict(self) -> Dict[str, Any]:
         """Removes all attributes from config which correspond to the default config attributes
         for better readability and serializes to a Python dictionary.
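All new constructor arguments default to the text-only behavior, so existing configs serialize unchanged; only explicitly set VLM fields should surface via to_diff_dict. A hedged example (exact output depends on the other defaults):

    from neural_compressor.transformers import AutoRoundConfig

    config = AutoRoundConfig(is_vlm=True, gradient_accumulate_steps=4)
    # Expected to contain only non-default fields, e.g.
    # {'is_vlm': True, 'gradient_accumulate_steps': 4, ...}
    print(config.to_diff_dict())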

test/3x/torch/quantization/weight_only/test_autoround.py (+4, −4)

@@ -238,13 +238,13 @@ def test_mllm(self):
             image_processor=None,
             dataset="liuhaotian/llava_conv_58k",
             extra_data_dir=None,
-            seqlen=512,
+            seqlen=32,
             batch_size=1,
             split=None,
             apply_template=None,
             truncation=False,
             seed=42,
-            nsamples=5,
+            nsamples=1,
             gradient_accumulate_steps=1,
             quant_nontext_module=False,
             processor=processor,
@@ -253,9 +253,9 @@ def test_mllm(self):
             bits=4,
             group_size=128,
             is_mllm=True,
-            nsamples=5,
+            nsamples=1,
             batch_size=batch_size,
-            iters=2,
+            iters=1,
             seqlen=seqlen,
             quant_nontext_module=False,
             truncation=truncation,
