Skip to content

Commit 3616dd2

Browse files
committed
refine code
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 23a3b54 commit 3616dd2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

optimum/exporters/ipex/cache_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
super().__init__()
4747
self.max_batch_size = max_batch_size
4848
self.device = device
49-
self.flash_decoding = (
49+
self._supports_flash_decoding = (
5050
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
5151
)
5252
# Used in `generate` to keep tally of how many tokens the cache has seen
@@ -76,7 +76,7 @@ def __init__(
7676
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
7777
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
7878
elif device.type == "xpu":
79-
if self.flash_decoding:
79+
if self._supports_flash_decoding:
8080
key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
8181
value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
8282
else:
@@ -96,7 +96,8 @@ def reshape_and_cache(
9696
value_cache: torch.Tensor,
9797
slots: torch.Tensor,
9898
):
99-
if self.device.type == "xpu" and self.flash_decoding:
99+
# TODO: unify API definition between CPU and XPU in IPEX version > 2.6
100+
if self.device.type == "xpu" and self._supports_flash_decoding:
100101
PagedAttention.reshape_and_cache_flash(
101102
key,
102103
value,

0 commit comments

Comments
 (0)