2 files changed: 8 additions and 5 deletions.

First changed file:
@@ -1,12 +1,13 @@
+import os
 from typing import List, Optional, Tuple

 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig


-# May need to tune based on sequence length and different models but default to 16 currently.
-BLOCK_SIZE = 16
+# Recommend 16 on CPU and 64 on XPU.
+BLOCK_SIZE = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", 16))


 class IPEXPagedCache(Cache):
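Because BLOCK_SIZE is evaluated once at import time, the override has to be present in the environment before the cache module is imported. The sketch below only illustrates that pattern and is not code from the PR: the OI_PAGED_ATTN_BLOCK_SIZE name comes from the diff, while the resolve_block_size helper and its validation are assumptions.

import os

def resolve_block_size(default: int = 16) -> int:
    # Hypothetical helper (not in the PR): read the paged-attention block size
    # from OI_PAGED_ATTN_BLOCK_SIZE, falling back to `default` when unset.
    raw = os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default))
    try:
        value = int(raw)
    except ValueError as exc:
        raise ValueError(f"OI_PAGED_ATTN_BLOCK_SIZE must be an integer, got {raw!r}") from exc
    if value <= 0:
        raise ValueError(f"OI_PAGED_ATTN_BLOCK_SIZE must be positive, got {value}")
    return value

# Typical use on XPU: export OI_PAGED_ATTN_BLOCK_SIZE=64 in the shell (or set it via
# os.environ) before importing the module that defines BLOCK_SIZE.
BLOCK_SIZE = resolve_block_size()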
Second changed file:

@@ -652,17 +652,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
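Dropping the `and past_len == 0` guard means the flash-attention path now also serves decode steps, so the call needs separate cumulative offsets for the queries and for the keys/values. The following self-contained illustration (not the PR's code) shows how the two tensors differ between prefill and decode; the tensor names mirror the diff, and treating input_lens as the per-sequence KV lengths is an assumption.

import torch

# Per-sequence KV lengths for a batch of 3 (assumed meaning of input_lens).
input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Cumulative KV offsets, used in both phases: [0, 5, 8, 15].
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

for past_len in (0, 4):
    if past_len == 0:
        # Prefill: every input token is a query token, so the query offsets
        # equal the KV offsets and the longest query equals the longest input.
        query_len_tensor = seq_len_tensor
        query_max_len = int(input_lens.max())
    else:
        # Decode: each sequence contributes exactly one new query token, so the
        # cumulative query offsets are 0, 1, 2, ..., batch_size and the max is 1.
        query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()
        query_max_len = 1
    print(past_len, query_len_tensor.tolist(), query_max_len)
    # past_len=0 -> [0, 5, 8, 15], 5
    # past_len=4 -> [0, 1, 2, 3], 1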