Commit 7b4044d

set paged attn block size as a env parameter (#1109)
* set default block size
* decoding use single query
* set block size as a env parameter
* set different default value for block size based on device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 726191f · commit 7b4044d

File tree

1 file changed: +3 -1 lines changed


optimum/exporters/ipex/cache_utils.py (+3 -1)
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
@@ -44,7 +45,8 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 64
+        default_block_size = 16 if device.type == "cpu" else 64
+        self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
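
With this change, the paged-attention block size can be overridden through the OI_PAGED_ATTN_BLOCK_SIZE environment variable; when it is unset, the code falls back to 16 on CPU and 64 on other devices. Below is a minimal usage sketch, assuming the cache (and therefore the environment lookup) is only created once generation runs, so the variable just needs to be set beforehand; the IPEXModelForCausalLM entry point and the model id are illustrative and not part of this commit.

import os

# Assumption: the override must be set before the paged KV cache is constructed.
# If unset, the diff above defaults to 16 on CPU and 64 on other devices.
os.environ["OI_PAGED_ATTN_BLOCK_SIZE"] = "32"

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM  # illustrative entry point, not part of this diff

model_id = "gpt2"  # illustrative model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Paged attention block size demo:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))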
