Skip to content

Commit acfd0ce

Browse files
committed
fix position_id init for qwen2
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 00e6bf3 commit acfd0ce

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

optimum/exporters/ipex/modeling_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,11 @@ def _qwen2_model_forward(
644644
)
645645

646646
if position_ids is None:
647-
position_ids = cache_position.unsqueeze(0)
647+
device = input_ids.device if input_ids is not None else inputs_embeds.device
648+
position_ids = torch.arange(
649+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
650+
)
651+
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
648652

649653
causal_mask = self._update_causal_mask(
650654
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

0 commit comments

Comments
 (0)