Skip to content

Commit 4092311

Browse files
authored
refine autoround export (#1711)
Signed-off-by: changwangss <chang1.wang@intel.com>
1 parent 7ee7215 commit 4092311

File tree

1 file changed

+48
-14
lines changed

1 file changed

+48
-14
lines changed

neural_compressor/model/torch_model.py

+48-14
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,7 @@ def export_compressed_model(
498498
gptq_config = self.gptq_config if hasattr(self, "gptq_config") else {}
499499

500500
autoround_config = self.autoround_config if hasattr(self, "autoround_config") else {}
501-
502-
if gptq_config or (autoround_config and device == "xpu"):
501+
if gptq_config:
503502
for k, v in weight_config.items():
504503
logger.debug(f"Compressing {k} on device {device}")
505504
if v["dtype"] == "fp32":
@@ -558,19 +557,54 @@ def export_compressed_model(
558557
)
559558
new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
560559
set_module(self.model, k, new_module)
561-
elif autoround_config and (device == "cpu" or device == "auto"):
562-
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401
560+
elif autoround_config:
561+
if device == "xpu":
562+
for k, v in weight_config.items():
563+
logger.debug(f"Compressing {k} on device {device}")
564+
if v["dtype"] == "fp32":
565+
continue
566+
else:
567+
dtype = v["dtype"]
568+
num_bits = v["bits"]
569+
group_size = v["group_size"]
570+
scheme = v["scheme"]
571+
m = fetch_module(self.model, k)
572+
autoround_conf = autoround_config[k]
573+
fp32_weight = m.weight.data
574+
autoround_scale = torch.tensor(autoround_conf["scale"], dtype=torch.float32)
575+
autoround_zp = None if scheme == "sym" else torch.tensor(autoround_conf["zero"], dtype=torch.int32)
576+
int_weight = quant_weight_w_scale(fp32_weight, autoround_scale, autoround_zp, group_size)
577+
int_weight = int_weight.type(torch.int32)
578+
new_module = WeightOnlyLinear(
579+
m.in_features,
580+
m.out_features,
581+
num_bits,
582+
group_size,
583+
dtype=dtype,
584+
zp=autoround_zp is not None,
585+
bias=m.bias is not None,
586+
g_idx=None,
587+
compression_dtype=compression_dtype,
588+
compression_dim=compression_dim,
589+
scale_dtype=scale_dtype,
590+
device=device,
591+
use_optimum_format=use_optimum_format,
592+
)
593+
new_module.pack(int_weight, autoround_scale, autoround_zp, m.bias, None)
594+
set_module(self.model, k, new_module)
595+
else:
596+
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401
563597

564-
self.model = pack_model(
565-
self.model,
566-
weight_config=autoround_config,
567-
enable_full_range=enable_full_range,
568-
compression_dtype=compression_dtype,
569-
compression_dim=compression_dim,
570-
device=device,
571-
use_optimum_format=use_optimum_format,
572-
inplace=True,
573-
)
598+
self.model = pack_model(
599+
self.model,
600+
weight_config=autoround_config,
601+
enable_full_range=enable_full_range,
602+
compression_dtype=compression_dtype,
603+
compression_dim=compression_dim,
604+
device=device,
605+
use_optimum_format=use_optimum_format,
606+
inplace=True,
607+
)
574608
else:
575609
for k, v in weight_config.items():
576610
logger.debug(f"Compressing {k} on device {device}")

0 commit comments

Comments (0)