Commit ba46b21

add quant_nontext_module

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent: 56e2caf

5 files changed: +58 -11 lines
neural_compressor/torch/algorithms/weight_only/autoround.py (+4 -4)

@@ -84,7 +84,7 @@ def __init__(
         enable_torch_compile: bool = None,
         # mllm
         is_mllm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         image_processor=None,
         processor=None,
@@ -150,7 +150,7 @@ def __init__(
             act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
-            quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
+            quant_nontext_module (bool): Whether to quantize nontext module.
             is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
@@ -383,7 +383,7 @@ def get_mllm_dataloader(
         template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor
     )
     dataset = template.default_dataset if dataset is None else dataset
-    if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer)):
+    if quant_nontext_module or (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer, "cpu", template.model_type)):
         if quant_nontext_module:
             logger.warning(
                 "Quantitative nontext module is not supported for plain text datasets,"
@@ -399,7 +399,7 @@ def get_mllm_dataloader(
         truncation = False
         gradient_accumulate_steps = batch_size * gradient_accumulate_steps
         batch_size = 1
-
+    seed = 42  # The seed is fixed to 42 in transformers
     seqlen = 2048 if seqlen is None else seqlen  # set text only calibration default args
     truncation = True if truncation is None else truncation
     dataset = dataset.replace(" ", "")
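
With quant_nontext_module now a plain bool and _only_text_test taking the device and model type explicitly, the dataloader decision in get_mllm_dataloader reduces to a simple predicate. Below is a self-contained sketch of that logic, using hypothetical stand-ins for the internal CALIB_DATASETS registry and _only_text_test helper (neither is defined in this diff):

# Illustrative sketch only; CALIB_DATASETS and _only_text_test are stand-ins
# for auto_round internals, not the real implementations.
CALIB_DATASETS = {"NeelNanda/pile-10k": None}  # hypothetical plain-text calibration registry


def _only_text_test(model, tokenizer, device, model_type):
    # Assumed to probe whether a text-only batch runs through the multimodal model.
    return False


def needs_multimodal_calibration(model, tokenizer, dataset, quant_nontext_module, model_type):
    # quant_nontext_module is a bool flag now; the device ("cpu" here) and the
    # template's model type are passed straight through to _only_text_test.
    return quant_nontext_module or (
        dataset in CALIB_DATASETS and not _only_text_test(model, tokenizer, "cpu", model_type)
    )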

neural_compressor/torch/quantization/config.py (+2 -2)

@@ -950,7 +950,7 @@ def __init__(
         enable_torch_compile: bool = None,
         # mllm
         is_mllm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         processor=None,
         image_processor=None,
@@ -994,7 +994,7 @@ def __init__(
             export_format (str, optional): The format used for exporting the quantized model. Defaults to "itrex".
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
-            quant_nontext_module (Union[str, list]): Whether to quantize nontext module.
+            quant_nontext_module (bool): Whether to quantize nontext module.
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
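
On the user-facing side, the same flag is exposed through AutoRoundConfig. A minimal usage sketch, assuming the neural_compressor.torch.quantization import path of the 3.x API; only is_mllm and quant_nontext_module come from this diff, everything else is left at its defaults:

from neural_compressor.torch.quantization import AutoRoundConfig

# Also quantize the non-text (e.g. vision) blocks of a multi-modal model.
config = AutoRoundConfig(
    is_mllm=True,
    quant_nontext_module=True,  # bool flag, previously Union[str, list]
)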

neural_compressor/transformers/quantization/utils.py (+40 -4)

@@ -18,6 +18,7 @@
 import math
 import os
 import types
+import re

 from datasets import load_dataset

@@ -40,6 +41,7 @@

 if is_package_available("auto_round"):
     import auto_round
+    import transformers
     from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear as auto_round_woq_linear


@@ -132,18 +134,18 @@ def _replace_linear(
            isinstance(module, torch.nn.Linear)
            or isinstance(module, INCWeightOnlyLinear)
            or (is_package_available("auto_round") and isinstance(module, auto_round_woq_linear))
-            or (is_ipex_available() and isinstance(module, ipex.nn.utils._weight_prepack._IPEXLinear))
        ) and (name not in modules_to_not_convert):
            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and \
+                not any(re.match(pattern, ".".join(current_key_name)) for pattern in modules_to_not_convert):
                in_features = module.in_features
                out_features = module.out_features
                if device == "cpu" or device == torch.device("cpu") or device == "auto":
                    from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear
                    from intel_extension_for_pytorch.utils.weight_only_quantization import (
                        _convert_optimum_format_to_desired,
                    )
-
+
                    qweight = module.qweight
                    scales = module.scales
                    qzeros = module.qzeros
@@ -550,7 +552,41 @@ def convert_to_quantized_model(model, config, device="cpu"):
            gradient_accumulate_steps=config.gradient_accumulate_steps,
            export_format=config.export_format,
        )
-
+
+        # vlm set non-text module config
+        if config.is_vlm is True:
+            from neural_compressor.torch.utils.utility import (
+                get_multimodal_block_names,
+                find_matching_blocks,
+                get_layer_names_in_block,
+            )
+            def set_nontext_module_config(model, to_quant_block_names, config):
+                all_block_list = get_multimodal_block_names(model, quant_vision=True)
+                all_block_set = set(tuple(block) for block in all_block_list)
+                quant_block_set = set(tuple(block) for block in to_quant_block_names)
+                set_to_full_prec = list(all_block_set - quant_block_set)
+                set_to_full_prec = get_layer_names_in_block(model, to_quant_block_names=set_to_full_prec)
+                for name in set_to_full_prec:
+                    config.modules_to_not_convert.append(name)
+
+                # skip layers not in blocks
+                config.modules_to_not_convert.append("model.vision_embed_tokens.img_projection*")
+                config.modules_to_not_convert.append("transformer.visual.attn_pool.*_proj")
+                config.modules_to_not_convert.append("model.mm_projector*")
+                config.modules_to_not_convert.append("multi_modal_projector")
+                config.modules_to_not_convert.append("visual.merger")
+
+            all_blocks = get_multimodal_block_names(model, quant_config.quant_nontext_module)
+            to_quant_block_names = find_matching_blocks(model, all_blocks, quant_config.to_quant_block_names)
+            set_nontext_module_config(model, to_quant_block_names, config)
+
+            for n, m in model.named_modules():
+                if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
+                    if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
+                        config.modules_to_not_convert.append(n)
+                        print(
+                            f"{n} will not be quantized due to its shape not being divisible by 32,"
+                            " resulting in an exporting issue to autogptq")
    if config.modules_to_not_convert != []:
        for module in config.modules_to_not_convert:
            module_name = ".*" + module
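
Two new exclusion mechanisms appear above: entries of modules_to_not_convert can now act as regex patterns (checked with re.match against the dotted module path) in addition to plain substrings, and any Linear/Conv1D whose weight shape is not divisible by 32 is skipped to avoid the autogptq export issue. Below is a toy, self-contained sketch of both filters; the model and pattern list are illustrative and not taken from the repository:

import re

from torch import nn


class ToyVLM(nn.Module):
    # Module names are illustrative only.
    def __init__(self):
        super().__init__()
        self.language_proj = nn.Linear(64, 64)          # shapes divisible by 32 -> convertible
        self.multi_modal_projector = nn.Linear(64, 64)  # excluded by name below
        self.odd_head = nn.Linear(64, 10)               # 10 % 32 != 0 -> excluded by shape


model = ToyVLM()
modules_to_not_convert = ["multi_modal_projector", "model.mm_projector*"]

# Shape filter, mirroring the divisible-by-32 exclusion added to convert_to_quantized_model.
for name, mod in model.named_modules():
    if isinstance(mod, nn.Linear) and (mod.weight.shape[0] % 32 or mod.weight.shape[1] % 32):
        modules_to_not_convert.append(name)


# Name filter: the original substring check plus the new re.match pattern check.
def is_excluded(dotted_name):
    return any(key in dotted_name for key in modules_to_not_convert) or any(
        re.match(pattern, dotted_name) for pattern in modules_to_not_convert
    )


for name, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        print(name, "-> kept in full precision" if is_excluded(name) else "-> quantized")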

neural_compressor/transformers/utils/quantization_config.py (+1 -1)

@@ -545,7 +545,7 @@ def __init__(
         quant_lm_head: bool = False,
         # vlm arguments
         is_vlm: bool = False,
-        quant_nontext_module: Union[str, list] = None,
+        quant_nontext_module: bool = False,
         truncation: bool = False,
         gradient_accumulate_steps: int = 1,
         export_format="itrex",

test/3x/torch/quantization/weight_only/test_transformers.py (+11)

@@ -249,6 +249,17 @@ def test_vlm(self):
         assert isinstance(loaded_model.model.layers[0].self_attn.k_proj, WeightOnlyQuantizedLinear), "loaing model failed."

         # phi-3-vision-128k-instruct
+        woq_config = AutoRoundConfig(
+            bits=4,
+            group_size=128,
+            is_vlm=True,
+            dataset="NeelNanda/pile-10k",
+            iters=2,
+            n_samples=5,
+            seq_len=64,
+            batch_size=1,
+        )
         model_name = "microsoft/Phi-3-vision-128k-instruct"
         woq_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True, attn_implementation='eager')
         assert isinstance(woq_model.model.layers[0].self_attn.o_proj, WeightOnlyQuantizedLinear), "quantizaion failed."
+
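
For a broader sanity check than the single-layer asserts in test_vlm, one can tally which linear modules actually got converted after a run like the Phi-3-vision case above. The WeightOnlyQuantizedLinear import path below is the one already used in utils.py; the helper itself is only a sketch and is not part of the test suite:

from collections import Counter

import torch
from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear


def summarize_conversion(model):
    # Count quantized vs. skipped linear layers and record the skipped names.
    counts, skipped = Counter(), []
    for name, mod in model.named_modules():
        if isinstance(mod, WeightOnlyQuantizedLinear):
            counts["quantized"] += 1
        elif isinstance(mod, torch.nn.Linear):
            counts["kept_fp"] += 1
            skipped.append(name)
    return counts, skipped


# e.g. after quantizing the VLM in test_vlm:
# counts, skipped = summarize_conversion(woq_model)
# print(counts, skipped[:10])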
