@@ -206,8 +206,8 @@ def _llama_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
@@ -227,6 +227,9 @@ def _llama_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -262,6 +265,9 @@ def _llama_model_forward(
                 use_cache=use_cache,
                 position_embeddings=position_embeddings,
                 input_lens=input_lens,
+                max_input_lens=max_input_lens,
+                seq_len_tensor=seq_len_tensor,
+                query_len_tensor=query_len_tensor,
             )
 
         hidden_states = layer_outputs[0]
@@ -330,11 +336,10 @@ def _falcon_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     batch_size, seq_length, _ = inputs_embeds.shape
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     if cache_position is None:
-        cache_position = torch.arange(
-            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
-        )
+        cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
 
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
@@ -350,6 +355,9 @@ def _falcon_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -392,6 +400,9 @@ def _falcon_model_forward(
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 input_lens=input_lens,
+                max_input_lens=max_input_lens,
+                seq_len_tensor=seq_len_tensor,
+                query_len_tensor=query_len_tensor,
             )
 
         hidden_states = outputs[0]
@@ -486,6 +497,9 @@ def _gpt2_model_forward(
     hidden_states = self.drop(hidden_states)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -521,6 +535,9 @@ def _gpt2_model_forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -591,6 +608,7 @@ def _qwen2_model_forward(
         inputs_embeds = self.embed_tokens(input_ids)
 
     batch_size, seq_length = inputs_embeds.shape[:2]
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if cache_position is None:
@@ -615,6 +633,9 @@ def _qwen2_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -650,6 +671,9 @@ def _qwen2_model_forward(
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 input_lens=input_lens,
+                max_input_lens=max_input_lens,
+                seq_len_tensor=seq_len_tensor,
+                query_len_tensor=query_len_tensor,
                 **kwargs,
             )
 
@@ -704,14 +728,26 @@ def postprocess_attention_output(self, attn_output):
         return attn_output
 
     # Maybe removed after torch 2.6 released
-    def has_flash_attn(self, query):
-        if query.device.type == "cpu":
+    def has_flash_attn(self):
+        if self.module_device.type == "cpu":
             return is_torch_version(">", "2.4.99")
-        elif query.device.type == "xpu":
+        elif self.module_device.type == "xpu":
             return is_torch_version(">", "2.5.99")
 
     def attention_interface(
-        self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+        self,
+        query,
+        key_cache,
+        value_cache,
+        key,
+        value,
+        past_key_value,
+        attention_mask,
+        input_lens,
+        past_len,
+        seq_len_tensor,
+        query_len_tensor,
+        max_input_lens,
     ):
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
@@ -728,20 +764,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn():
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
+            query_len_tensor = seq_len_tensor if past_len == 0 else query_len_tensor
+            query_max_len = max_input_lens if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
-                key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
-                value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                key_cache,
+                value_cache,
                 query_len_tensor,
                 seq_len_tensor,
                 query_max_len,
-                input_lens.max(),
+                max_input_lens,
                 1.0 / math.sqrt(self.head_dim),
                 True,
                 past_key_value.block_tables,
@@ -750,7 +785,6 @@ def attention_interface(
         elif past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
@@ -799,6 +833,9 @@ def forward(
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
             past_key_value = kwargs.pop("layer_past", None)
         input_lens = kwargs.pop("input_lens", None)
+        seq_len_tensor = kwargs.pop("seq_len_tensor", None)
+        query_len_tensor = kwargs.pop("query_len_tensor", None)
+        max_input_lens = kwargs.pop("max_input_lens", 0)
         past_len = 0
         if past_key_value is not None:
             past_len = past_key_value.get_seq_length()
@@ -810,7 +847,18 @@ def forward(
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
 
         attn_output = self.attention_interface(
-            query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+            query,
+            key_cache,
+            value_cache,
+            key,
+            value,
+            past_key_value,
+            attention_mask,
+            input_lens,
+            past_len,
+            seq_len_tensor,
+            query_len_tensor,
+            max_input_lens,
         )
 
         attn_output = self.postprocess_attention_output(attn_output)
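Not part of the commit: a minimal, self-contained sketch of the bookkeeping these hunks hoist out of attention_interface and into each model forward, assuming a padded attention_mask of shape (batch, seq). The helper name precompute_varlen_metadata is hypothetical; the tensor expressions mirror the added lines in the hunks above, which compute them once per forward instead of once per layer.

```python
import torch

def precompute_varlen_metadata(attention_mask: torch.Tensor, device: torch.device):
    # hypothetical helper; the diff inlines these four lines in each *_model_forward
    # valid (unpadded) length of each sequence, e.g. [3, 5] for a batch of two
    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
    # cumulative start offsets consumed by the varlen/paged-attention kernels, e.g. [0, 3, 8]
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    # offsets for single-token queries during decode, e.g. [0, 1, 2]
    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
    # longest sequence in the batch, read once on the host rather than per layer
    max_input_lens = input_lens.max().item()
    return input_lens, seq_len_tensor, query_len_tensor, max_input_lens
```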