@@ -681,7 +681,7 @@ def forward(
 
 
 class _IPEXLlamaAttention(_IPEXAttention):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__(module, config)
         if getattr(config, "quantization_config", None) is None:
             concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
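The hunks in this commit change only the constructor signatures; how the new `device` argument is consumed falls outside this excerpt. Below is a minimal, self-contained sketch of one plausible use, mirroring the QKV fusion visible in the Llama hunk's context lines. The class name `FusedQKVSketch` and the `module_device` attribute are illustrative assumptions, not this commit's code.

# Hypothetical sketch (not this commit's code): one plausible use of the new
# `device` argument: materializing a fused QKV weight on the target device.
import torch
from torch import nn

class FusedQKVSketch(nn.Module):
    def __init__(self, module: nn.Module, device, config) -> None:
        super().__init__()
        self.module_device = device  # assumed attribute; remembers placement
        if getattr(config, "quantization_config", None) is None:
            # Fuse the three projection weights so one GEMM yields Q, K and V,
            # and build the fused matrix directly on `device`.
            concat_weight = torch.concat(
                [module.q_proj.weight, module.k_proj.weight, module.v_proj.weight]
            ).contiguous().to(device)
            self.qkv_proj = nn.Linear(
                concat_weight.shape[1], concat_weight.shape[0], bias=False, device=device
            )
            with torch.no_grad():
                self.qkv_proj.weight.copy_(concat_weight)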
@@ -723,7 +723,7 @@ def rope(self, query, key, **kwargs):
 
 
 class _IPEXFalconAttention(_IPEXAttention):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
         self.q_slice = self.head_dim * config.num_kv_heads
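The Falcon context lines also compute `q_slice = head_dim * config.num_kv_heads`, the offset where the query block ends inside Falcon's fused `query_key_value` output. A toy illustration of slicing such a fused projection follows; the `[query block, one key head, one value head]` layout and the `k_slice`/`v_slice` offsets are assumptions made for illustration and are not shown in this diff.

# Toy illustration (assumed multi-query layout: query block, then a single
# shared K head and a single shared V head). The q_slice offset mirrors the
# Falcon hunk above; k_slice and v_slice are assumed companions.
import torch

head_dim, num_kv_heads = 64, 8
q_slice = head_dim * num_kv_heads   # end of the query block
k_slice = q_slice + head_dim        # end of the shared key head
v_slice = k_slice + head_dim        # end of the shared value head

qkv_out = torch.randn(4, v_slice)   # [tokens, fused hidden]
query = qkv_out[:, :q_slice]
key = qkv_out[:, q_slice:k_slice]
value = qkv_out[:, k_slice:v_slice]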
@@ -750,7 +750,7 @@ def rope(self, query, key, **kwargs):
 
 
 class _IPEXGPT2Attention(_IPEXAttention):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
         if getattr(config, "quantization_config", None):
@@ -844,7 +844,7 @@ def forward(
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
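The decoder layers call `_setattr_from_module(self, module)` to copy the wrapped layer's state onto the IPEX wrapper. Its body is not part of this excerpt; a hedged sketch of such a helper, assuming it simply mirrors the wrapped module's instance attributes:

# Hypothetical sketch of a _setattr_from_module-style helper (the real
# implementation is not shown in this diff): mirror every instance
# attribute of the wrapped module onto the wrapper.
def _setattr_from_module(new_module, module):
    for name, value in vars(module).items():
        setattr(new_module, name, value)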
@@ -879,7 +879,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
 
 
 class _IPEXFalconDecoderLayer(nn.Module):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.self_attention = _IPEXFalconAttention(module.self_attention, config)
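Note that the decoder-layer context lines above still construct the attention modules with `(module, config)`, so the matching call-site updates presumably live in hunks outside this excerpt. For orientation, a hedged sketch of a call site threading the device through the new three-argument signature; the `_patch_falcon_layers` helper and the device inference are assumptions, not part of this commit.

# Hypothetical call-site sketch (not from this commit): infer the model's
# device from its weights and pass it to the patched decoder layers.
def _patch_falcon_layers(model, config):
    device = next(model.parameters()).device  # assumption: read off the weights
    for idx, layer in enumerate(model.transformer.h):
        model.transformer.h[idx] = _IPEXFalconDecoderLayer(layer, device, config)
    return model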