
Commit 18b2a6a

fix ipex attn init

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 9af46d1

File tree

1 file changed (+3, -3)


optimum/exporters/ipex/modeling_utils.py

@@ -682,7 +682,7 @@ def forward(
 
 class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
-        super().__init__(module, config)
+        super().__init__(module, device, config)
         if getattr(config, "quantization_config", None) is None:
             concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
             bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
@@ -725,7 +725,7 @@ def rope(self, query, key, **kwargs):
 class _IPEXFalconAttention(_IPEXAttention):
     def __init__(self, module, device, config):
         self.num_key_value_heads = config.num_key_value_heads
-        super().__init__(module, config)
+        super().__init__(module, device, config)
         self.q_slice = self.head_dim * config.num_kv_heads
         self.k_slice = self.q_slice + self.head_dim
         self.v_slice = self.k_slice + self.head_dim
@@ -752,7 +752,7 @@ def rope(self, query, key, **kwargs):
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
-        super().__init__(module, config)
+        super().__init__(module, device, config)
         if getattr(config, "quantization_config", None):
             _remove_hooks_for_ipex(self, True)
 
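Why the change matters: the `_IPEXAttention` base class's `__init__` takes `(module, device, config)`, but these three subclasses were still calling `super().__init__(module, config)`, so `config` was bound to the `device` parameter and the required `config` argument was never supplied. Below is a minimal sketch of the failure mode; the class names and signatures match the diff, but the base-class body is a hypothetical stand-in, not the real implementation.

import torch


class _IPEXAttention(torch.nn.Module):
    # Hypothetical stand-in for the real base class; only the signature
    # matches the diff.
    def __init__(self, module, device, config) -> None:
        super().__init__()
        self.module = module
        self.device = device
        self.config = config


class _IPEXLlamaAttention(_IPEXAttention):
    def __init__(self, module, device, config) -> None:
        # Before the fix: `config` lands on the `device` parameter and the
        # required `config` argument is missing, raising
        # TypeError: __init__() missing 1 required positional argument: 'config'
        # super().__init__(module, config)

        # After the fix, each argument binds to the intended parameter:
        super().__init__(module, device, config)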
