@@ -599,7 +599,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
@@ -779,7 +779,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
@@ -812,7 +812,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
@@ -911,7 +911,7 @@ class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
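For context, a minimal sketch of the two device-resolution patterns this diff swaps: reading the device from the module's first parameter versus reading it from the config object. The function and class names below are illustrative only and are not code from this PR.

import torch
import torch.nn as nn


def device_from_parameters(module: nn.Module) -> torch.device:
    # Old pattern removed by the diff: raises StopIteration if the module has no
    # parameters, and reports whatever device its first parameter happens to be on.
    return next(module.parameters()).device


def device_from_config(config) -> torch.device:
    # New pattern added by the diff: trusts a `device` attribute that the caller
    # is assumed to have set on the config object beforehand.
    return torch.device(config.device)


class DummyConfig:
    device = "cpu"


print(device_from_parameters(nn.Linear(2, 2)))  # cpu (device of the weight)
print(device_from_config(DummyConfig()))        # cpu (device recorded on config)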