Support layerwise quantization #1018

Merged · 15 commits · Dec 5, 2024
6 changes: 6 additions & 0 deletions examples/neural_compressor/language-modeling/run_clm.py
@@ -215,6 +215,10 @@ class OptimizationArguments:
default="sym",
metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."},
)
use_layer_wise: bool = field(
default=False,
metadata={"help": "Use layer wise to do quantization to save memory."},
)
quantization_methodology: str = field(
default="rtn",
metadata={"help": "Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'."},
@@ -659,13 +663,15 @@ def compute_metrics(eval_preds):
"bits": optim_args.bits,
"sym": optim_args.weight_only_scheme == "sym",
"group_size": optim_args.group_size,
"use_layer_wise": optim_args.use_layer_wise,
}

if optim_args.quantization_methodology == "gptq":
quantization_config = GPTQConfig(
damp_percent=optim_args.damp_percent,
nsamples=optim_args.num_calibration_samples,
blocksize=optim_args.gptq_block_size,
tokenizer=tokenizer,
**algorithm_args,
)
else:
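The diff above stops just before the else branch, but both quantization paths receive the shared algorithm_args dict, so the new use_layer_wise flag applies to RTN as well as GPTQ. A minimal sketch of the resulting config construction, with illustrative values for everything the diff does not show (damp percent, sample count, block size, model id):

from transformers import AutoTokenizer
from neural_compressor.transformers import GPTQConfig, RtnConfig

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTNeoForCausalLM")

# Shared weight-only settings, mirroring the algorithm_args dict built in run_clm.py.
algorithm_args = {
    "bits": 4,
    "sym": True,
    "group_size": 128,
    "use_layer_wise": True,  # the flag added by this PR
}

quantization_methodology = "gptq"  # or "rtn"
if quantization_methodology == "gptq":
    quantization_config = GPTQConfig(
        damp_percent=0.01,
        nsamples=128,
        blocksize=128,
        tokenizer=tokenizer,
        **algorithm_args,
    )
else:
    quantization_config = RtnConfig(**algorithm_args)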
37 changes: 23 additions & 14 deletions optimum/intel/neural_compressor/quantization.py
@@ -374,22 +374,31 @@ def _weight_only_quantization(
}

low_cpu_mem_usage = True
if use_xpu:
try:
# TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on CPU device.
model = model_class.from_pretrained(
model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
)
except NotImplementedError:
logger.info(
"Failed to load models with `low_cpu_mem_usage=True`, will fall to traditional load method resulting in higher memory consumption."
)
low_cpu_mem_usage = False

if getattr(quantization_config, "use_layer_wise", False):
Collaborator

We can check the neural-compressor version here and disable layer-wise quantization if it's not supported (a warning should be added as well), so that support is ready for the next neural-compressor release.

Contributor Author

agree, improved.

from neural_compressor.torch import load_empty_model

model = load_empty_model(model_id, cls=model_class, trust_remote_code=trust_remote_code)
else:
if use_xpu:
try:
# TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on CPU device.
model = model_class.from_pretrained(
model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
)
except NotImplementedError:
logger.info(
"Failed to load models with `low_cpu_mem_usage=True`, will fall to traditional load method resulting in higher memory consumption."
)
low_cpu_mem_usage = False
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
else:
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
quantization_config.update(**{"device": "xpu"})
quantization_config.post_init_xpu()

if use_xpu:
quantization_config.update(**{"device": "xpu"})
quantization_config.post_init_xpu()
else:
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
quantization_config.post_init_cpu()

model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage})
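Regarding the review exchange above, a version guard around the layer-wise path might look roughly like the sketch below. It is written to slot into _weight_only_quantization and reuses that function's local names (model_id, model_class, trust_remote_code, loading_kwargs, low_cpu_mem_usage, logger); is_neural_compressor_version is assumed to be the version helper from optimum-intel's import utilities, and the "3.2" threshold is an assumption rather than a value taken from this PR:

from optimum.intel.utils.import_utils import is_neural_compressor_version

if getattr(quantization_config, "use_layer_wise", False):
    # Assumed minimum version: layer-wise loading relies on
    # neural_compressor.torch.load_empty_model being available.
    if is_neural_compressor_version("<", "3.2"):
        logger.warning(
            "`use_layer_wise` requires a newer neural-compressor release; "
            "falling back to the standard loading path."
        )
        quantization_config.use_layer_wise = False
        model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
    else:
        from neural_compressor.torch import load_empty_model

        model = load_empty_model(model_id, cls=model_class, trust_remote_code=trust_remote_code)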
6 changes: 5 additions & 1 deletion setup.py
@@ -63,7 +63,11 @@
EXTRAS_REQUIRE = {
"nncf": ["nncf>=2.14.0"],
"openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
"neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
"neural-compressor": [
"neural_compressor[pt]@git+https://github.com/intel/neural-compressor.git@5c72158a6799bdf0334ef36fbd493eeed3b62d9f",
"accelerate",
"transformers<4.46",
],
"ipex": ["intel-extension-for-pytorch>=2.2,<2.4", "transformers>=4.39,<4.45"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
12 changes: 8 additions & 4 deletions tests/neural_compressor/test_optimization.py
@@ -467,12 +467,14 @@ def _compute_metrics(pred):

class WeightOnlyQuantizationTest(INCTestMixin):
WEIGHT_ONLY_CONFIG = (
("rtn", 4),
("gptq", 4),
("rtn", 4, False),
("rtn", 4, True),
("gptq", 4, False),
("gptq", 4, True),
)

@parameterized.expand(WEIGHT_ONLY_CONFIG)
def test_weight_only_quantization(self, methodology, bits):
def test_weight_only_quantization(self, methodology, bits, use_layer_wise):
from neural_compressor.transformers import GPTQConfig, RtnConfig

model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"
@@ -489,9 +491,10 @@ def test_weight_only_quantization(self, methodology, bits):
batch_size=5,
seq_len=32,
block_size=16,
use_layer_wise=use_layer_wise,
)
else:
quantization_config = RtnConfig(bits=bits, group_size=8)
quantization_config = RtnConfig(bits=bits, group_size=8, use_layer_wise=use_layer_wise)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
@@ -503,6 +506,7 @@ def test_weight_only_quantization(self, methodology, bits, use_layer_wise):
with torch.no_grad():
quantizer_outputs = quantized_model(**tokens)
quantized_model.save_pretrained(tmp_dir)

loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir)
with torch.no_grad():
loaded_outputs = loaded_model(**tokens)
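Putting the pieces together, the new test exercises roughly the following end-to-end flow. The call that actually triggers quantization is elided in the diff above, so treat the INCModelForCausalLM.from_pretrained(..., quantization_config=...) step as an assumption based on how the existing weight-only tests are structured:

import torch
from transformers import AutoTokenizer
from neural_compressor.transformers import RtnConfig
from optimum.intel import INCModelForCausalLM

model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"

# Layer-wise RTN: weights are quantized layer by layer so the full-precision
# model never has to be fully materialized in memory.
quantization_config = RtnConfig(bits=4, group_size=8, use_layer_wise=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

quantized_model = INCModelForCausalLM.from_pretrained(
    model_name, quantization_config=quantization_config
)

tokens = tokenizer("This is a sample input", return_tensors="pt")
with torch.no_grad():
    outputs = quantized_model(**tokens)

quantized_model.save_pretrained("quantized_model")
loaded_model = INCModelForCausalLM.from_pretrained("quantized_model")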