Skip to content

Commit dd811f9

Browse files
committed
fix mlp class init
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 6d8a969 commit dd811f9

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

optimum/exporters/ipex/modeling_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,8 @@ class _IPEXLlamaDecoderLayer(nn.Module):
847847
def __init__(self, module, device, config):
848848
super().__init__()
849849
_setattr_from_module(self, module)
850-
self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
851-
self.mlp = _IPEXLlamaMLP(module.mlp, config)
850+
self.self_attn = _IPEXLlamaAttention(module.self_attn, device, config)
851+
self.mlp = _IPEXLlamaMLP(module.mlp, device, config)
852852
if getattr(config, "quantization_config", None):
853853
_remove_hooks_for_ipex(self, True)
854854

@@ -882,8 +882,8 @@ class _IPEXFalconDecoderLayer(nn.Module):
882882
def __init__(self, module, device, config):
883883
super().__init__()
884884
_setattr_from_module(self, module)
885-
self.self_attention = _IPEXFalconAttention(module.self_attention, config)
886-
self.mlp = _IPEXFalconMLP(module.mlp, config)
885+
self.self_attention = _IPEXFalconAttention(module.self_attention, device, config)
886+
self.mlp = _IPEXFalconMLP(module.mlp, device, config)
887887
if getattr(config, "quantization_config", None):
888888
_remove_hooks_for_ipex(self, True)
889889

0 commit comments

Comments
 (0)