Skip to content

Commit bb51139

Browse files
authored
fix autotp linear check (#1062)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 41d9a37 commit bb51139

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

optimum/exporters/ipex/modeling_utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -664,9 +664,9 @@ def __init__(self, module, config) -> None:
664664
if use_bias:
665665
concat_bias = torch.concat(bias_list, 0).contiguous()
666666
self.concat_linear.bias = nn.Parameter(concat_bias)
667-
self.q_slice = self.q_proj.out_features
668-
self.k_slice = self.q_slice + self.k_proj.out_features
669-
self.v_slice = self.k_slice + self.v_proj.out_features
667+
self.q_slice = self.q_proj.weight.shape[0]
668+
self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
669+
self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
670670
if self.module_device.type == "cpu":
671671
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
672672
self.mha_linear_add = LinearAdd(module.o_proj)

0 commit comments

Comments
 (0)