
Commit 4d73e51

changwangss, echarlaix, and sys-lpot-val authored
Add INC layerwise quantization support (#1018)
* support layerwise quantization
  Signed-off-by: changwangss <chang1.wang@intel.com>
* fix ut and example
  Signed-off-by: changwa1 <chang1.wang@intel.com>
* improve model init
  Signed-off-by: changwa1 <chang1.wang@intel.com>
* improve ut
  Signed-off-by: changwa1 <chang1.wang@intel.com>
* fix loading kwargs issue
  Signed-off-by: changwa1 <chang1.wang@intel.com>
* set neuralcompressor commit
  Signed-off-by: changwa1 <chang1.wang@intel.com>
* Update optimum/intel/neural_compressor/quantization.py
  Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
* fix lay-wise model init
  Signed-off-by: sys-lpot-val <sys_lpot_val@intel.com>
* fix quantization_config init
  Signed-off-by: changwangss <chang1.wang@intel.com>
* add limit for use_layer_wise
  Signed-off-by: changwangss <chang1.wang@intel.com>
* fix load_empty_model
  Signed-off-by: changwangss <chang1.wang@intel.com>
* Update setup.py
  Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
* add version check for layerwise feature
  Signed-off-by: changwangss <chang1.wang@intel.com>

---------

Signed-off-by: changwangss <chang1.wang@intel.com>
Signed-off-by: changwa1 <chang1.wang@intel.com>
Signed-off-by: sys-lpot-val <sys_lpot_val@intel.com>
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Co-authored-by: sys-lpot-val <sys_lpot_val@intel.com>
1 parent 3ba51f1 commit 4d73e51
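
A minimal usage sketch of the feature this commit adds, assuming the RtnConfig exposed by neural_compressor.transformers (as used in the test diff below) and the INCModelForCausalLM.from_pretrained(..., quantization_config=...) path that calls into the _weight_only_quantization helper modified here; the checkpoint is the tiny test model from the test diff, and the bit/group-size values are illustrative:

    from neural_compressor.transformers import RtnConfig
    from optimum.intel import INCModelForCausalLM

    # Round-to-nearest weight-only quantization; use_layer_wise=True asks INC (>= 3.2)
    # to start from an empty model and load weights layer by layer to reduce peak memory.
    quantization_config = RtnConfig(bits=4, group_size=8, use_layer_wise=True)

    model = INCModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-GPTNeoForCausalLM",
        quantization_config=quantization_config,
    )
    model.save_pretrained("./quantized_model")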

File tree

3 files changed: +30 -19 lines

examples/neural_compressor/language-modeling/run_clm.py (+6)

@@ -215,6 +215,10 @@ class OptimizationArguments:
         default="sym",
         metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."},
     )
+    use_layer_wise: bool = field(
+        default=False,
+        metadata={"help": "Use layer wise to do quantization to save memory."},
+    )
     quantization_methodology: str = field(
         default="rtn",
         metadata={"help": "Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'."},
@@ -659,13 +663,15 @@ def compute_metrics(eval_preds):
             "bits": optim_args.bits,
             "sym": optim_args.weight_only_scheme == "sym",
             "group_size": optim_args.group_size,
+            "use_layer_wise": optim_args.use_layer_wise,
         }

         if optim_args.quantization_methodology == "gptq":
             quantization_config = GPTQConfig(
                 damp_percent=optim_args.damp_percent,
                 nsamples=optim_args.num_calibration_samples,
                 blocksize=optim_args.gptq_block_size,
+                tokenizer=tokenizer,
                 **algorithm_args,
             )
         else:
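
The GPTQ branch of the example script now forwards the tokenizer and the new use_layer_wise flag into the config. Below is a hedged sketch of the same wiring outside the script, assuming the neural_compressor.transformers GPTQConfig used in this diff; the numeric values stand in for --damp_percent, --num_calibration_samples, and --gptq_block_size and are illustrative, not values read from the code:

    from neural_compressor.transformers import GPTQConfig
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTNeoForCausalLM")

    # Mirrors the algorithm_args dict built by run_clm.py, with use_layer_wise wired through.
    algorithm_args = {
        "bits": 4,
        "sym": True,
        "group_size": 128,
        "use_layer_wise": True,
    }
    quantization_config = GPTQConfig(
        damp_percent=0.01,   # stands in for --damp_percent
        nsamples=128,        # stands in for --num_calibration_samples
        blocksize=128,       # stands in for --gptq_block_size
        tokenizer=tokenizer, # now passed to the config directly (added in this commit)
        **algorithm_args,
    )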

optimum/intel/neural_compressor/quantization.py (+13 -14)

@@ -374,22 +374,21 @@ def _weight_only_quantization(
     }

     low_cpu_mem_usage = True
-    if use_xpu:
-        try:
-            # TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device.
-            model = model_class.from_pretrained(
-                model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
-            )
-        except NotImplementedError:
-            logger.info(
-                "Failed to load models with `low_cpu_mem_usage=True`, will fall to traditional load method resulting in higher memory consumption."
-            )
-            low_cpu_mem_usage = False
-            model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
-        quantization_config.update(**{"device": "xpu"})
-        quantization_config.post_init_xpu()
+
+    if getattr(quantization_config, "use_layer_wise", False):
+        if is_neural_compressor_version(">=", "3.2"):
+            from neural_compressor.torch import load_empty_model
+
+            model = load_empty_model(model_id, cls=model_class, **loading_kwargs)
+        else:
+            raise ValueError("INC version must be >= 3.2 when use_layer_wise is set to True in quantization_config.")
     else:
         model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
+
+    if use_xpu:
+        quantization_config.update(**{"device": "xpu"})
+        quantization_config.post_init_xpu()
+    else:
         quantization_config.post_init_cpu()

     model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage})
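
Because layer-wise loading is gated on neural-compressor >= 3.2 and otherwise raises a ValueError, a caller can run the same check before opting in. A small sketch, assuming the is_neural_compressor_version helper and the RtnConfig shown in these diffs:

    from neural_compressor.transformers import RtnConfig
    from optimum.intel.utils.import_utils import is_neural_compressor_version

    # Only request layer-wise quantization when the installed INC supports it (>= 3.2);
    # otherwise fall back to the regular loading path.
    use_layer_wise = is_neural_compressor_version(">=", "3.2")
    quantization_config = RtnConfig(bits=4, group_size=8, use_layer_wise=use_layer_wise)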

tests/neural_compressor/test_optimization.py (+11 -5)

@@ -45,7 +45,7 @@
     set_seed,
 )
 from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset
-from optimum.intel.utils.import_utils import is_torch_version
+from optimum.intel.utils.import_utils import is_neural_compressor_version

 from optimum.intel import (
     INCConfig,
@@ -467,12 +467,16 @@ def _compute_metrics(pred):

 class WeightOnlyQuantizationTest(INCTestMixin):
     WEIGHT_ONLY_CONFIG = (
-        ("rtn", 4),
-        ("gptq", 4),
+        ("rtn", 4, False),
+        ("rtn", 4, True),
+        ("gptq", 4, False),
+        ("gptq", 4, True),
     )

     @parameterized.expand(WEIGHT_ONLY_CONFIG)
-    def test_weight_only_quantization(self, methodology, bits):
+    def test_weight_only_quantization(self, methodology, bits, use_layer_wise):
+        if use_layer_wise and is_neural_compressor_version("<", "3.2"):
+            self.skipTest("INC version < 3.2 doesn't support layer-wise feature.")
         from neural_compressor.transformers import GPTQConfig, RtnConfig

         model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"
@@ -489,9 +493,10 @@ def test_weight_only_quantization(self, methodology, bits):
                 batch_size=5,
                 seq_len=32,
                 block_size=16,
+                use_layer_wise=use_layer_wise,
             )
         else:
-            quantization_config = RtnConfig(bits=bits, group_size=8)
+            quantization_config = RtnConfig(bits=bits, group_size=8, use_layer_wise=use_layer_wise)

         tokenizer = AutoTokenizer.from_pretrained(model_name)
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
@@ -503,6 +508,7 @@ def test_weight_only_quantization(self, methodology, bits):
             with torch.no_grad():
                 quantizer_outputs = quantized_model(**tokens)
             quantized_model.save_pretrained(tmp_dir)
+
             loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir)
             with torch.no_grad():
                 loaded_outputs = loaded_model(**tokens)
