Skip to content

Commit 5dcde91

Browse files
Fix bug in LWQ GPTQ (#2128)
Signed-off-by: n1ck-guo <heng.guo@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a0977e2 commit 5dcde91

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

neural_compressor/adaptor/torch_utils/gptq.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -718,10 +718,12 @@ def tmp(_, inp, out):
718718
for n, p in sub_layer.named_parameters():
719719
param_name = full_layer_name + "." + n
720720
if n == "weight":
721-
set_module_tensor_to_device(self.model, param_name, self.device, Q)
721+
set_module_tensor_to_device(self.model, param_name, self.device, Q, dtype=Q.dtype)
722722
else:
723723
value = load_value(self.model, param_name, model_path)
724-
set_module_tensor_to_device(self.model, param_name, self.device, value)
724+
set_module_tensor_to_device(
725+
self.model, param_name, self.device, value, dtype=value.dtype
726+
)
725727
# sub_layer.weight.data = Q
726728
torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
727729
clean_module_weight(sub_layer)
@@ -745,6 +747,8 @@ def tmp(_, inp, out):
745747
for j in range(len(self.dataloader)):
746748
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
747749
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
750+
# breakpoint()
751+
# transformer_block = transformer_block.to(getattr(torch, self.model.config.torch_dtype))
748752
out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
749753
out = self.track_hidden_states(out)
750754
outs.append(out)

neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def load_module(model, module_name, path, device="cpu"):
221221
for n, p in module.named_parameters():
222222
param_name = module_name + "." + n
223223
value = load_value(model, param_name, path)
224-
set_module_tensor_to_device(model, param_name, device, value)
224+
set_module_tensor_to_device(model, param_name, device, value, dtype=value.dtype)
225225

226226

227227
def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None):
@@ -239,7 +239,7 @@ def hook(module, input):
239239
value = state_dict[n]
240240
else:
241241
value = load_value(model, param_name, path)
242-
set_module_tensor_to_device(model, param_name, device, value)
242+
set_module_tensor_to_device(model, param_name, device, value, dtype=value.dtype)
243243

244244
return hook
245245

0 commit comments

Comments
 (0)