
Commit b0cec9c

set actual device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

1 parent 9a7e931

2 files changed: +5 -4 lines changed

optimum/exporters/ipex/model_patcher.py (+1)

@@ -133,6 +133,7 @@ def _patch_vit_model(model):
 
 
 def _patch_model(model):
+    setattr(model.config, "device", model.device)
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
     if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
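Before any submodule wrappers are built, the patch records the model's actual device on its config, so every wrapper constructed afterwards can read one authoritative value. A minimal sketch of the effect, assuming a standard transformers model (the model name and the standalone setattr call here are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any supported model

# What _patch_model now does first: cache the model's actual device on the
# config so downstream wrappers no longer infer it from parameters.
setattr(model.config, "device", model.device)

assert model.config.device == model.device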

optimum/exporters/ipex/modeling_utils.py (+4 -4)

@@ -599,7 +599,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
@@ -779,7 +779,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
@@ -812,7 +812,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
@@ -911,7 +911,7 @@ class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
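One plausible reading of "set actual device": the old parameter-based lookup is fragile, since a wrapped submodule may own no parameters at all, and weights created under deferred initialization report the meta device rather than the device the model will actually run on. A hedged sketch of both failure modes in plain PyTorch (not code from this repo):

import torch
import torch.nn as nn

# Case 1: a parameter-free submodule leaves nothing to infer a device from,
# so next(module.parameters()) raises StopIteration.
act = nn.GELU()
try:
    device = next(act.parameters()).device
except StopIteration:
    print("no parameters, no device to read")

# Case 2: weights materialized under deferred initialization live on the
# meta device, which is not where the model will actually execute.
with torch.device("meta"):
    proj = nn.Linear(16, 16)
print(next(proj.parameters()).device)  # prints: meta

# Reading the device cached on the config, as the wrappers above now do,
# sidesteps both cases.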
