Skip to content

Commit 2f5beeb

Browse files
authored
Merge branch 'master' into cherry_pick_v1.20.0
2 parents a09f187 + 09ccf3c commit 2f5beeb

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

neural_compressor/transformers/models/modeling_auto.py

+5
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
226 226

227 227
# add quantization_config and save_low_bit to pretrained model dynamically
228 228
model.device_map = device_map
229+
230+
# StaticCache's device is initialized by `hf_device_map` in `from_pretrained` method.
231+
if hasattr(model, "hf_device_map"):
232+
device_map = torch.device(device_map) if isinstance(device_map, str) else device_map
233+
model.hf_device_map = {"": device_map}
229 234
model.quantization_config = quantization_config
230 235

231 236
model.save_pretrained = types.MethodType(save_low_bit, model)

0 commit comments

Comments
 (0)