@@ -653,6 +653,7 @@ def _qwen2_model_forward(
         inputs_embeds = self.embed_tokens(input_ids)

     batch_size, seq_length = inputs_embeds.shape[:2]
+    device = input_ids.device if input_ids is not None else inputs_embeds.device

     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if cache_position is None:
@@ -677,6 +678,9 @@ def _qwen2_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)

     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -712,6 +716,9 @@ def _qwen2_model_forward(
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
             **kwargs,
         )
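For context, a minimal standalone sketch (not part of the patch) of what the three new tensors hold for a toy right-padded batch. Only the three assignments are copied from the hunk above; the example batch, the device fallback, and the interpretation comments are assumptions.

# Sketch: illustrate seq_len_tensor / query_len_tensor / max_input_lens for a
# hypothetical batch of 2 sequences with real lengths 3 and 5 (right-padded to 5).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Same computations as in the diff above (the patch resolves `device` from
# input_ids/inputs_embeds; here we just reuse the mask's device).
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)   # tensor([3, 5])
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
# -> tensor([0, 3, 8]): cumulative start offsets of each sequence once padding is
#    removed, i.e. the cu_seqlens-style layout that varlen attention kernels expect.
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=attention_mask.device).int()
# -> tensor([0, 1, 2]): presumably the query-side offsets when every sequence
#    contributes a single new token (the decode step).
max_input_lens = input_lens.max().item()                        # 5

print(seq_len_tensor, query_len_tensor, max_input_lens)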