@@ -179,8 +179,8 @@ def _llama_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
@@ -200,6 +200,9 @@ def _llama_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
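The three values introduced above depend only on `attention_mask`, so they can be computed once per forward pass and reused by every layer. A standalone sketch (not part of the patch) of what they evaluate to for a toy padded batch with sequence lengths 2, 4 and 3:

import torch

# Toy right-padded batch: 1 marks real tokens, 0 marks padding.
attention_mask = torch.tensor(
    [
        [1, 1, 0, 0],
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
device = attention_mask.device

# Number of real tokens per sequence.
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # tensor([2, 4, 3])
# Cumulative offsets of each sequence in the packed (varlen) layout.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))  # tensor([0, 2, 6, 9])
# Cumulative query offsets when each sequence contributes a single decode token.
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()  # tensor([0, 1, 2, 3])
# Longest sequence, used as the max_seqlen argument of the varlen kernels.
max_input_lens = input_lens.max().item()  # 4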
@@ -235,6 +238,9 @@ def _llama_model_forward(
             use_cache=use_cache,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = layer_outputs[0]
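Threading these tensors through the decoder-layer call turns a per-layer cost into a per-forward cost: the old code rebuilt them inside every attention module, including a `.item()` call that synchronizes with the device. A rough sketch of the difference, using a hypothetical helper name that is not part of the patch:

import torch

def build_varlen_metadata(input_lens):
    # Same three values as in the patch, bundled here for illustration.
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=input_lens.device).int()
    max_input_lens = input_lens.max().item()  # .item() forces a host/device sync
    return seq_len_tensor, query_len_tensor, max_input_lens

input_lens = torch.tensor([2, 4, 3], dtype=torch.int32)
num_layers = 32

# Before: metadata rebuilt inside every layer's attention call.
for _ in range(num_layers):
    _ = build_varlen_metadata(input_lens)

# After: built once in the model forward and passed to each layer via kwargs.
shared_metadata = build_varlen_metadata(input_lens)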
@@ -303,10 +309,11 @@ def _falcon_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     batch_size, seq_length, _ = inputs_embeds.shape
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     if cache_position is None:
         cache_position = torch.arange(
-            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+            past_key_values_length, past_key_values_length + seq_length, device=device
         )
 
     if position_ids is None:
@@ -323,6 +330,9 @@ def _falcon_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -365,6 +375,9 @@ def _falcon_model_forward(
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -459,6 +472,9 @@ def _gpt2_model_forward(
     hidden_states = self.drop(hidden_states)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -494,6 +510,9 @@ def _gpt2_model_forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -635,7 +654,19 @@ def has_flash_attn(self):
         return is_torch_version(">", "2.5.99")
 
     def attention_interface(
-        self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+        self,
+        query,
+        key_cache,
+        value_cache,
+        key,
+        value,
+        past_key_value,
+        attention_mask,
+        input_lens,
+        past_len,
+        seq_len_tensor,
+        query_len_tensor,
+        max_input_lens,
     ):
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
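The `past_key_value is None` branch that follows falls back to SDPA; `n_rep` is the grouped-query-attention ratio between query heads and key/value heads. A simplified sketch of that idea (the shapes and the `is_causal` flag here are assumptions for illustration, not the patched code):

import torch

batch, q_heads, kv_heads, seq, head_dim = 1, 8, 2, 5, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# With grouped-query attention each KV head serves n_rep query heads,
# so the KV heads are repeated before calling SDPA.
n_rep = query.shape[1] // key.shape[1]  # 8 // 2 == 4
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query,
    key.repeat_interleave(n_rep, dim=1),
    value.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)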
@@ -654,18 +685,13 @@ def attention_interface(
             self.use_sdpa = True
         elif self.has_flash_attn():
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = (
-                seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int()
-            )
-            max_input_lens = input_lens.max().item()
             query_max_len = max_input_lens if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache,
                 value_cache,
-                query_len_tensor,
+                seq_len_tensor if past_len == 0 else query_len_tensor,
                 seq_len_tensor,
                 query_max_len,
                 max_input_lens,
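With the metadata precomputed, the only decision left in this branch is which cumulative-length tensor describes the queries: during prefill (`past_len == 0`) every prompt token is a query, so the query offsets coincide with the KV offsets in `seq_len_tensor`; during decode each sequence contributes exactly one token, so the offsets are simply 0, 1, 2, … from `query_len_tensor`. A values-only sketch (no kernel call):

import torch

input_lens = torch.tensor([2, 4, 3], dtype=torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))  # [0, 2, 6, 9]
query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()                         # [0, 1, 2, 3]
max_input_lens = input_lens.max().item()                                               # 4

past_len = 0  # prefill: queries and keys share the packed layout
cu_seqlens_q = seq_len_tensor if past_len == 0 else query_len_tensor  # [0, 2, 6, 9]
query_max_len = max_input_lens if past_len == 0 else 1                # 4

past_len = 9  # decode: one new query token per sequence
cu_seqlens_q = seq_len_tensor if past_len == 0 else query_len_tensor  # [0, 1, 2, 3]
query_max_len = max_input_lens if past_len == 0 else 1                # 1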
@@ -677,7 +703,6 @@ def attention_interface(
         elif past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
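The prefill fallback feeds the same packed layout to `varlen_attention`. Conceptually that kernel is equivalent to running causal attention per sequence over the slices delimited by `seq_len_tensor`; a reference-only sketch under the simplifying assumption that query, key and value have the same number of heads:

import torch

def varlen_attention_reference(query, key, value, seq_len_tensor):
    # query/key/value: packed [total_tokens, num_heads, head_dim] tensors.
    attn_output = torch.empty_like(query)
    for i in range(seq_len_tensor.shape[0] - 1):
        start, end = seq_len_tensor[i].item(), seq_len_tensor[i + 1].item()
        q = query[start:end].transpose(0, 1)  # [num_heads, seq_i, head_dim]
        k = key[start:end].transpose(0, 1)
        v = value[start:end].transpose(0, 1)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_output[start:end] = out.transpose(0, 1)
    return attn_output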
@@ -726,6 +751,9 @@ def forward(
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
             past_key_value = kwargs.pop("layer_past", None)
         input_lens = kwargs.pop("input_lens", None)
+        seq_len_tensor = kwargs.pop("seq_len_tensor", None)
+        query_len_tensor = kwargs.pop("query_len_tensor", None)
+        max_input_lens = kwargs.pop("max_input_lens", 0)
         past_len = 0
         if past_key_value is not None:
             past_len = past_key_value.get_seq_length()
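`past_len` comes from the cache's `get_seq_length()`, which is what separates the prefill and decode paths above. A small illustration using transformers' `DynamicCache` rather than the IPEX paged cache (an assumption made purely to show the semantics):

import torch
from transformers import DynamicCache

cache = DynamicCache()
print(cache.get_seq_length())  # 0 -> nothing cached yet, i.e. the prefill case

key = torch.randn(1, 2, 5, 16)    # [batch, kv_heads, seq, head_dim]
value = torch.randn(1, 2, 5, 16)
cache.update(key, value, layer_idx=0)
print(cache.get_seq_length())  # 5 -> later calls take the decode path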
@@ -737,7 +765,18 @@ def forward(
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
 
         attn_output = self.attention_interface(
-            query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+            query,
+            key_cache,
+            value_cache,
+            key,
+            value,
+            past_key_value,
+            attention_mask,
+            input_lens,
+            past_len,
+            seq_len_tensor,
+            query_len_tensor,
+            max_input_lens,
         )
 
         attn_output = self.postprocess_attention_output(attn_output)