@@ -46,7 +46,7 @@ def __init__(
46
46
super ().__init__ ()
47
47
self .max_batch_size = max_batch_size
48
48
self .device = device
49
- self .flash_decoding = (
49
+ self ._supports_flash_decoding = (
50
50
is_ipex_version (">" , "2.4.99" ) if device .type == "cpu" else is_ipex_version (">" , "2.5.99" )
51
51
)
52
52
# Used in `generate` to keep tally of how many tokens the cache has seen
@@ -76,7 +76,7 @@ def __init__(
76
76
key_cache_shape = (self .num_blocks , self .num_kv_heads , self .block_size , head_size )
77
77
value_cache_shape = (self .num_blocks , self .num_kv_heads , self .block_size , head_size )
78
78
elif device .type == "xpu" :
79
- if self .flash_decoding :
79
+ if self ._supports_flash_decoding :
80
80
key_cache_shape = (self .num_blocks , self .block_size , self .num_kv_heads , head_size )
81
81
value_cache_shape = (self .num_blocks , self .block_size , self .num_kv_heads , head_size )
82
82
else :
@@ -96,7 +96,8 @@ def reshape_and_cache(
96
96
value_cache : torch .Tensor ,
97
97
slots : torch .Tensor ,
98
98
):
99
- if self .device .type == "xpu" and self .flash_decoding :
99
+ # TODO: unify API definition between CPU and XPU in IPEX version > 2.6
100
+ if self .device .type == "xpu" and self ._supports_flash_decoding :
100
101
PagedAttention .reshape_and_cache_flash (
101
102
key ,
102
103
value ,
0 commit comments