Skip to content

Commit 070a0dc

Browse files
jiqing-fengecharlaix
andauthoredMar 8, 2024
Update optimum/intel/ipex/modeling_base.py
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent e3a7024 commit 070a0dc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed
 

‎optimum/intel/ipex/modeling_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def forward(
501501
position_ids = attention_mask.long().cumsum(-1) - 1
502502
position_ids.masked_fill_(attention_mask == 0, 1)
503503
if past_key_values:
504-
position_ids = position_ids[:, -1].unsqueeze(0)
504+
position_ids = position_ids[:, -1].unsqueeze(-1)
505505

506506
if "position_ids" in self.input_names or not self.input_names:
507507
inputs["position_ids"] = position_ids

0 commit comments

Comments
 (0)