Commit 09ccf3c

Fix hf_device_map setting for transformers-like api (#2122)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6bb52cf commit 09ccf3c

1 file changed: 5 additions, 0 deletions

neural_compressor/transformers/models/modeling_auto.py

@@ -226,6 +226,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         # add quantization_config and save_low_bit to pretrained model dynamically
         model.device_map = device_map
+
+        # StaticCache's device is initialized by `hf_device_map` in `from_pretrained` method.
+        if hasattr(model, "hf_device_map"):
+            device_map = torch.device(device_map) if isinstance(device_map, str) else device_map
+            model.hf_device_map = {"": device_map}
         model.quantization_config = quantization_config
 
         model.save_pretrained = types.MethodType(save_low_bit, model)
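
The effect of the added lines can be sketched in isolation. This is a minimal illustration, not the library's code: `override_hf_device_map`, `fake_device`, and the `SimpleNamespace` models are hypothetical stand-ins, and `to_device` takes the place of `torch.device` so the sketch runs without PyTorch installed.

```python
from types import SimpleNamespace
from typing import Any, Callable


def override_hf_device_map(model: Any, device_map: Any, to_device: Callable[[str], Any]) -> Any:
    """Mirror the added lines: map the whole model to a single device.

    `to_device` is an assumed stand-in for `torch.device`.
    """
    model.device_map = device_map
    # Only touch models that already expose `hf_device_map`.
    if hasattr(model, "hf_device_map"):
        # Strings such as "cpu" become device objects; other values are kept as-is.
        device = to_device(device_map) if isinstance(device_map, str) else device_map
        # The empty-string key refers to the root module, i.e. the entire model.
        model.hf_device_map = {"": device}
    return model


def fake_device(name: str) -> tuple:
    """Hypothetical replacement for torch.device, used only in this sketch."""
    return ("device", name)


# A model that already carries a stale map gets it overwritten ...
with_map = override_hf_device_map(SimpleNamespace(hf_device_map={"": "meta"}), "cpu", fake_device)
# ... while a model without the attribute is left without one.
without_map = override_hf_device_map(SimpleNamespace(), "cpu", fake_device)
```

The `hasattr` guard means the attribute is only overwritten, never created, so models that never had an `hf_device_map` keep behaving as before.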
