Commit d5100b4 (parent: 081ff45), committed Feb 19, 2025

fix conflict CI

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

1 file changed: optimum/exporters/ipex/modeling_utils.py (+7 −0 lines)
```diff
@@ -653,6 +653,7 @@ def _qwen2_model_forward(
         inputs_embeds = self.embed_tokens(input_ids)
 
     batch_size, seq_length = inputs_embeds.shape[:2]
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if cache_position is None:
@@ -677,6 +678,9 @@ def _qwen2_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -712,6 +716,9 @@ def _qwen2_model_forward(
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
             **kwargs,
         )
 
```
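For context, the three tensors added here look like the standard bookkeeping that varlen (packed, mask-free) attention kernels expect: cumulative key/value sequence offsets, cumulative query offsets, and the maximum sequence length in the batch. Below is a minimal sketch of what they hold for a toy attention mask; the tensor names mirror the diff, but the example itself is illustrative and not part of the commit:

```python
import torch

# Toy attention mask: batch of 2, valid lengths 3 and 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])

# Per-sequence valid-token counts, as in the diff: [3, 4].
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

# Cumulative offsets [0, 3, 7]: where each packed sequence starts/ends
# once padding is removed (the cu_seqlens layout varlen kernels consume).
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

# During decode each sequence contributes exactly one query token,
# so the cumulative query offsets are simply [0, 1, 2].
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=input_lens.device).int()

# Longest sequence in the batch, typically passed as a max_seqlen bound.
max_input_lens = input_lens.max().item()

print(seq_len_tensor)    # tensor([0, 3, 7], dtype=torch.int32)
print(query_len_tensor)  # tensor([0, 1, 2], dtype=torch.int32)
print(max_input_lens)    # 4
```

Computing these once in `_qwen2_model_forward` and threading them through to every decoder layer as keyword arguments presumably avoids recomputing the same offsets inside each layer's attention call.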