
Commit 9af46d1

fix class init
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent: 94cf35d

File tree

1 file changed: +5, -5 lines

optimum/exporters/ipex/modeling_utils.py

@@ -681,7 +681,7 @@ def forward(
 
 
 class _IPEXLlamaAttention(_IPEXAttention):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__(module, config)
         if getattr(config, "quantization_config", None) is None:
             concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
@@ -723,7 +723,7 @@ def rope(self, query, key, **kwargs):
 
 
 class _IPEXFalconAttention(_IPEXAttention):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
         self.q_slice = self.head_dim * config.num_kv_heads
@@ -750,7 +750,7 @@ def rope(self, query, key, **kwargs):
 
 
 class _IPEXGPT2Attention(_IPEXAttention):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
         if getattr(config, "quantization_config", None):
@@ -844,7 +844,7 @@ def forward(
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
@@ -879,7 +879,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
 
 
 class _IPEXFalconDecoderLayer(nn.Module):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.self_attention = _IPEXFalconAttention(module.self_attention, config)
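
For context, the change threads an explicit device argument through each wrapper class's __init__. A minimal sketch of that pattern follows; _WrapperAttention is a hypothetical stand-in for the _IPEX*Attention classes above, not the actual optimum-intel implementation, though the q/k/v weight concatenation mirrors the context lines in the first hunk.

import torch
from torch import nn


class _WrapperAttention(nn.Module):
    # Hypothetical stand-in for the _IPEX*Attention classes changed above.
    def __init__(self, module, device, config) -> None:
        super().__init__()
        # Share the wrapped layer's projections instead of copying them.
        self.q_proj = module.q_proj
        self.k_proj = module.k_proj
        self.v_proj = module.v_proj
        # The new `device` argument lets derived tensors be placed
        # explicitly rather than inherited implicitly from the weights.
        if getattr(config, "quantization_config", None) is None:
            self.concat_qkv_weight = (
                torch.cat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight])
                .contiguous()
                .to(device)
            )


# Hypothetical usage: the device is now passed in explicitly at construction.
# attn = _WrapperAttention(layer.self_attn, torch.device("cpu"), config)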
