
Commit 649e6b1

Support LayerWise for RTN/GPTQ (#1883)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: chensuyue <suyue.chen@intel.com>
1 parent de43d85 commit 649e6b1

File tree: 13 files changed (+440 -38 lines)

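This change wires layer-wise quantization (LWQ) into the PyTorch 3.x RTN and GPTQ paths: weights are streamed from the checkpoint one layer at a time, quantized, saved, and released, so the full-precision model never has to be resident in device memory at once. A minimal usage sketch follows; it is not part of this diff, and it assumes the GPTQConfig entry point exposes use_layer_wise and model_path, mirroring the quantizer arguments added below.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

model_name = "facebook/opt-125m"  # any Hugging Face causal-LM checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# use_layer_wise keeps the FP32 weights on disk; model_path tells the LWQ hooks
# where to reload each layer's tensors from (both kwargs assumed, per this diff).
quant_config = GPTQConfig(bits=4, group_size=128, use_layer_wise=True, model_path=model_name)

model = prepare(model, quant_config)
calib = tokenizer("Layer-wise quantization keeps peak memory low.", return_tensors="pt")
with torch.no_grad():
    model(calib["input_ids"])  # calibration pass feeding the GPTQ statistics
q_model = convert(model)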

.azure-pipelines/scripts/codeScan/pylint/pylint.sh (+1)

@@ -20,6 +20,7 @@ apt-get install -y --no-install-recommends --fix-missing \
     build-essential
 
 pip install -r /neural-compressor/requirements.txt
+pip install -r /neural-compressor/requirements_pt.txt
 pip install cmake
 
 pip install torch \

neural_compressor/torch/algorithms/layer_wise/load.py (+1 -1)

@@ -32,7 +32,7 @@
     _open_zipfile_reader,
 )
 
-from neural_compressor.adaptor.torch_utils.layer_wise_quant import modified_pickle as pickle
+from neural_compressor.torch.algorithms.layer_wise import modified_pickle as pickle
 
 from .utils import torch
 
neural_compressor/torch/algorithms/layer_wise/utils.py (+11 -1)

@@ -27,10 +27,11 @@
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from neural_compressor.common import options
+from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear
 
 from .load import load
 
-LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp")
+LWQ_WORKSPACE = os.path.join(options.workspace, "lwq_tmpdir")
 
 
 class QDQLayer(torch.nn.Module):
@@ -215,6 +216,9 @@ def _get_path(pretrained_model_name_or_path):
     return path
 
 
+get_path = _get_path
+
+
 def load_value(model, param_name, path):
     if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
         input_embeddings = model.get_input_embeddings()
@@ -281,6 +285,12 @@ def clean_module_weight(module):
     else:
         submodule = module
 
+    if isinstance(module, WeightOnlyLinear):
+        for n, m in submodule._buffers.items():
+            old_value = getattr(submodule, n)
+            with torch.no_grad():
+                submodule._buffers[n] = torch.zeros(old_value.shape, device="meta")
+
     for n, m in submodule.named_parameters():
         is_buffer = n in submodule._buffers
         old_value = getattr(submodule, n)
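The clean_module_weight change above is needed because WeightOnlyLinear stores its packed weight, scale, and zero-point tensors as buffers rather than parameters, so the existing parameter loop never released them. Below is a conceptual sketch, not library code, of the underlying trick: swapping a module's tensors for shape-only placeholders on the meta device so their storage is freed once the layer has been quantized and saved.

import torch

def offload_to_meta(module: torch.nn.Module) -> None:
    """Replace a module's buffers and parameters with storage-free meta tensors."""
    with torch.no_grad():
        for name, buf in list(module._buffers.items()):
            if buf is not None:
                module._buffers[name] = torch.empty_like(buf, device="meta")
        for name, param in list(module.named_parameters(recurse=False)):
            meta = torch.nn.Parameter(torch.empty_like(param, device="meta"), requires_grad=False)
            setattr(module, name, meta)

layer = torch.nn.Linear(4096, 4096)
offload_to_meta(layer)
print(layer.weight.device)  # meta -- the 64 MB of FP32 weights are released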

neural_compressor/torch/algorithms/weight_only/gptq.py (+36 -10)

@@ -230,11 +230,13 @@ def __init__(
 
         # device
         self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
-        self.model.to(self.device)
+        if not use_layer_wise:
+            self.model.to(self.device)
         self.is_ready = False
 
         self.use_layer_wise = use_layer_wise
-        self.model_path = model_path
+        if use_layer_wise:
+            self.prepare_layer_wise(model_path)
 
         # dataloader
         self.use_max_length = use_max_length
@@ -243,6 +245,20 @@
         self.dataloader = []
         self.nsamples = nsamples
 
+    def prepare_layer_wise(self, model_path):
+        import os
+
+        from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, get_path, register_weight_hooks
+
+        os.makedirs(LWQ_WORKSPACE, exist_ok=True)
+        if model_path == "":
+            model_path = self.model.path
+        assert model_path, "model_path should not be None."
+        self.model_path = get_path(model_path)
+        register_weight_hooks(
+            self.model, self.model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
+        )
+
     def get_full_layer_name(self, sub_layer_name, block_idx):
         transformer_name = self.gptq_related_blocks["transformers_name"]
         return ".".join([transformer_name, str(block_idx), sub_layer_name])
@@ -413,7 +429,6 @@ def execute_quantization(self, means=None, stds=None):
         # Step1: prepare quantization (calibration datasets)
 
         logger.info("Begin ====>")
-        model_path = self.model_path
 
         # Step2: run gptq quantization in a transformer block-wise manner.
         gptq_config = {}
@@ -450,7 +465,7 @@
                 if self.use_layer_wise:  # pragma: no cover
                     from neural_compressor.torch.algorithms.layer_wise import load_value
 
-                    W = load_value(self.model, full_layer_name + ".weight", model_path)
+                    W = load_value(self.model, full_layer_name + ".weight", self.model_path)
                 else:
                     W = sub_layers[layer_name].weight.data.clone()
 
@@ -489,7 +504,7 @@ def tmp(_, inp, out):
                     from neural_compressor.torch.algorithms.layer_wise import load_value
 
                     full_layer_name = self.get_full_layer_name(layer_name, block_idx)
-                    W = load_value(self.model, full_layer_name + ".weight", model_path)
+                    W = load_value(self.model, full_layer_name + ".weight", self.model_path)
                 else:
                     W = sub_layers[layer_name].weight.data.clone()
                 accelerator.mark_step()
@@ -518,7 +533,7 @@ def tmp(_, inp, out):
                        if n == "weight":
                            set_module_tensor_to_device(self.model, param_name, self.device, Q)
                        else:
-                            value = load_value(self.model, param_name, model_path)
+                            value = load_value(self.model, param_name, self.model_path)
                            set_module_tensor_to_device(self.model, param_name, self.device, value)
                    # sub_layer.weight.data = Q
                    torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
@@ -562,7 +577,13 @@ def tmp(_, inp, out):
                    gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"]
                else:
                    gptq_perm = None
-                Q = sub_layers[layer_name].weight.data
+                if self.use_layer_wise:
+                    state_dict = torch.load(LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt")
+                    Q = state_dict["weight"].data
+                    bias = state_dict["bias"] if "bias" in state_dict.keys() else None
+
+                else:
+                    Q = sub_layers[layer_name].weight.data
                if weight_config_this_layer["act_order"]:
                    Q.copy_(Q[:, gptq_perm])
                if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
@@ -591,18 +612,21 @@ def tmp(_, inp, out):
                    scale = scale.t_().contiguous()
                    zp = zp.t_().contiguous() if zp is not None else zp
 
+                if not self.use_layer_wise:
+                    bias = sub_layers[layer_name].bias
+
                new_module = WeightOnlyLinear(
                    in_features,
                    out_features,
                    dtype=weight_config_this_layer["dtype"],
                    bits=weight_config_this_layer["bits"],
                    group_size=weight_config_this_layer["group_size"],
                    zp=gptq_zp is not None,
-                    bias=sub_layers[layer_name].bias is not None,
+                    bias=bias is not None,
                    g_idx=gptq_perm is not None,
                    device=self.device,
                )
-                new_module.pack(int_weight, gptq_scale, gptq_zp, sub_layers[layer_name].bias, gptq_perm)
+                new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm)
                set_module(transformer_block, layer_name, new_module)
            del gptq_for_this_block
            torch.cuda.empty_cache()
@@ -1019,8 +1043,10 @@ def prepare(
     def convert(self, model, *args, **kwargs):
         self.gptq_quantizer.model = model
         self.gptq_quantizer.remove_prepare_for_calibration()
+
         q_model, gptq_config = self.gptq_quantizer.execute_quantization()
-        q_model = q_model.to(self.model_device)
+        if not self.gptq_quantizer.use_layer_wise:
+            q_model = q_model.to(self.model_device)
         q_model.gptq_config = gptq_config
         logger.info("GPTQ quantizing done.")
         return q_model
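In the layer-wise path, each sub-layer's updated tensors are written to LWQ_WORKSPACE with torch.save(sub_layer.state_dict(), ...), and the packing stage later reads Q and the optional bias back from that file instead of from the already-cleaned live module; the finished model is also left on its current device rather than moved back. A small round-trip sketch of that per-layer cache, with illustrative paths and layer names:

import os
import torch

LWQ_WORKSPACE = "/tmp/lwq_tmpdir"  # the real value is <workspace>/lwq_tmpdir
full_layer_name = "model.decoder.layers.0.self_attn.q_proj"  # hypothetical layer name

# quantization stage: persist the freshly quantized sub-layer
os.makedirs(LWQ_WORKSPACE, exist_ok=True)
sub_layer = torch.nn.Linear(16, 16)  # stand-in for the real transformer sub-layer
torch.save(sub_layer.state_dict(), os.path.join(LWQ_WORKSPACE, f"{full_layer_name}.pt"))

# packing stage: read the tensors back rather than touching the cleaned module
state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{full_layer_name}.pt"))
Q = state_dict["weight"].data
bias = state_dict.get("bias", None)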
