@@ -682,7 +682,7 @@ def forward(
 
 class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
-        super().__init__(module, config)
+        super().__init__(module, device, config)
         if getattr(config, "quantization_config", None) is None:
             concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
             bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
@@ -725,7 +725,7 @@ def rope(self, query, key, **kwargs):
 class _IPEXFalconAttention(_IPEXAttention):
     def __init__(self, module, device, config):
         self.num_key_value_heads = config.num_key_value_heads
-        super().__init__(module, config)
+        super().__init__(module, device, config)
         self.q_slice = self.head_dim * config.num_kv_heads
         self.k_slice = self.q_slice + self.head_dim
         self.v_slice = self.k_slice + self.head_dim
@@ -752,7 +752,7 @@ def rope(self, query, key, **kwargs):
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
-        super().__init__(module, config)
+        super().__init__(module, device, config)
         if getattr(config, "quantization_config", None):
             _remove_hooks_for_ipex(self, True)
 
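All three hunks apply the same fix: each subclass __init__ already receives a `device` argument but was not forwarding it to `_IPEXAttention.__init__`. Below is a minimal, hypothetical sketch of the calling convention the change enforces; the base-class body is an assumption for illustration only, not the actual optimum-intel implementation.

# Hypothetical sketch, not the real optimum-intel code: illustrates why each
# subclass must forward `device` to the base constructor.
class _IPEXAttention:
    def __init__(self, module, device, config) -> None:
        self.module = module
        self.device = device  # assumed: kept so later ops can target this device
        self.config = config


class _IPEXLlamaAttention(_IPEXAttention):
    def __init__(self, module, device, config) -> None:
        # Passing only (module, config) would shift `config` into the `device`
        # slot and raise a TypeError for the missing third argument.
        super().__init__(module, device, config)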