@@ -847,8 +847,8 @@ class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
-        self.mlp = _IPEXLlamaMLP(module.mlp, config)
+        self.self_attn = _IPEXLlamaAttention(module.self_attn, device, config)
+        self.mlp = _IPEXLlamaMLP(module.mlp, device, config)
         if getattr(config, "quantization_config", None):
             _remove_hooks_for_ipex(self, True)

@@ -882,8 +882,8 @@ class _IPEXFalconDecoderLayer(nn.Module):
     def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.self_attention = _IPEXFalconAttention(module.self_attention, config)
-        self.mlp = _IPEXFalconMLP(module.mlp, config)
+        self.self_attention = _IPEXFalconAttention(module.self_attention, device, config)
+        self.mlp = _IPEXFalconMLP(module.mlp, device, config)
         if getattr(config, "quantization_config", None):
             _remove_hooks_for_ipex(self, True)

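Both hunks make the same change: the decoder-layer wrapper now forwards the `device` argument it already receives down to the attention and MLP wrappers, instead of letting those submodules infer the device on their own. A minimal sketch of that pattern follows; the _Sketch* classes are hypothetical stand-ins for _IPEXLlamaAttention, _IPEXLlamaMLP, and the decoder layer itself, and only the (module, device, config) constructor signature is taken from the diff above (the real classes also copy attributes via _setattr_from_module and implement device-specific forwards, which this sketch omits).

import torch
import torch.nn as nn
from types import SimpleNamespace


class _SketchAttention(nn.Module):
    # Hypothetical stand-in; only the constructor signature mirrors the diff.
    def __init__(self, module, device, config):
        super().__init__()
        self.module = module
        # Storing the device lets the wrapper choose device-specific paths
        # (e.g. CPU vs. XPU kernels) without re-deriving it per call.
        self.device = device

    def forward(self, hidden_states):
        return self.module(hidden_states)


class _SketchMLP(nn.Module):
    def __init__(self, module, device, config):
        super().__init__()
        self.module = module
        self.device = device

    def forward(self, hidden_states):
        return self.module(hidden_states)


class _SketchDecoderLayer(nn.Module):
    def __init__(self, module, device, config):
        super().__init__()
        # As in the diff: the layer threads its own `device` argument
        # through to each submodule wrapper.
        self.self_attn = _SketchAttention(module.self_attn, device, config)
        self.mlp = _SketchMLP(module.mlp, device, config)

    def forward(self, hidden_states):
        return self.mlp(self.self_attn(hidden_states))


# Usage sketch with a toy layer exposing the expected attribute names.
toy = SimpleNamespace(self_attn=nn.Linear(8, 8), mlp=nn.Linear(8, 8))
layer = _SketchDecoderLayer(toy, torch.device("cpu"), config=None)
print(layer.mlp.device)  # device(type='cpu')

Passing the device explicitly at construction time keeps the wrappers deterministic: they no longer have to guess from parameter placement, which matters when quantized modules have no parameters on the target device yet.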