From 12dd802859b7ce6667d343782ac8606d053585ff Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 15 Jan 2025 10:09:30 +0000
Subject: [PATCH 1/4] set default block size

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/cache_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index 7154c44491..b91da262f2 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -5,6 +5,10 @@
 from transformers import Cache, PretrainedConfig


+# May need to tune based on sequence length and different models but default to 16 currently.
+BLOCK_SIZE = 16
+
+
 class IPEXPagedCache(Cache):
     """
     A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
@@ -44,7 +48,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 64
+        self.block_size = BLOCK_SIZE
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1

From c6d2d0f9d140c401761fb67d62b63834c25bbace Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 15 Jan 2025 10:38:14 +0000
Subject: [PATCH 2/4] decoding use single query

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/modeling_utils.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 41dd5693df..4e8de10121 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -652,19 +652,17 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn(query) and past_len == 0:
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
-                query_len_tensor,
                 seq_len_tensor,
-                query_max_len,
+                seq_len_tensor,
+                input_lens.max(),
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
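
Note on PATCH 1/4: the sizing math in the unchanged context lines is a plain ceiling division, so shrinking the block size from 64 to 16 changes allocation granularity, not total cache capacity. A minimal standalone sketch of that arithmetic, with illustrative batch/length values (not taken from the patch):

    import torch

    # Illustrative values; the patch itself only changes block_size.
    batch_size, max_cache_len, block_size = 4, 100, 16

    # Ceiling division: enough blocks per sequence to cover max_cache_len.
    num_blocks = (max_cache_len // block_size + (max_cache_len % block_size != 0)) * batch_size

    # One row of block ids per sequence; -1 marks a not-yet-allocated block.
    block_tables = -1 * torch.ones([num_blocks], dtype=torch.int32).reshape(batch_size, -1)

    print(num_blocks)          # 28: ceil(100 / 16) = 7 blocks per sequence, times 4
    print(block_tables.shape)  # torch.Size([4, 7])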
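
Note on PATCH 2/4: with the flash-attention branch now gated on past_len == 0 (prefill only), the decode-specific query metadata becomes dead code and is removed; decode presumably falls through to the single-query path named in the subject. What that metadata encoded: during decode every sequence contributes exactly one query token, so the cumulative query offsets are simply 0, 1, 2, ... and the longest query is 1. A sketch of both cases, reusing the tensor names from the diff (the lengths and past_len values are illustrative):

    import torch

    # input_lens holds the key/value length of each sequence at this step.
    input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

    for past_len in (0, 2):  # 0 = prefill; > 0 = decode
        query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
        query_max_len = input_lens.max() if past_len == 0 else 1
        print(past_len, query_len_tensor.tolist(), int(query_max_len))
    # 0 [0, 5, 8, 15] 7  <- prefill: query offsets follow the full input lengths
    # 2 [0, 1, 2, 3] 1   <- decode: one query token per sequence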
From 5b930362210142d954d09de741140cbb8ff448b9 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 16 Jan 2025 11:31:31 +0000
Subject: [PATCH 3/4] set block size as an env parameter

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/cache_utils.py    | 5 +++--
 optimum/exporters/ipex/modeling_utils.py | 8 +++++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index b91da262f2..ded6964831 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple

 import torch
@@ -5,8 +6,8 @@
 from transformers import Cache, PretrainedConfig


-# May need to tune based on sequence length and different models but default to 16 currently.
-BLOCK_SIZE = 16
+# Recommend 16 on CPU and 64 on XPU.
+BLOCK_SIZE = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", 16))


 class IPEXPagedCache(Cache):

diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 4e8de10121..41dd5693df 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -652,17 +652,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,

From 31accd2134daa78e8e77cfc3721dbef9c865a7d2 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 16 Jan 2025 13:29:48 +0000
Subject: [PATCH 4/4] set different default value for block size based on
 device

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/cache_utils.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index ded6964831..f9df2cf69a 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -6,10 +6,6 @@
 from transformers import Cache, PretrainedConfig


-# Recommend 16 on CPU and 64 on XPU.
-BLOCK_SIZE = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", 16))
-
-
 class IPEXPagedCache(Cache):
     """
     A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
@@ -49,7 +45,8 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = BLOCK_SIZE
+        default_block_size = 16 if device.type == "cpu" else 64
+        self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
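
Note on PATCH 3/4 and 4/4: the environment override introduced in patch 3 survives patch 4, but resolution moves from import time into IPEXPagedCache.__init__ so the fallback can depend on the cache's device. A sketch of the final resolution order, assuming a CPU device (OI_PAGED_ATTN_BLOCK_SIZE is the variable from the diff; the device choice is illustrative):

    import os

    import torch

    device = torch.device("cpu")  # an XPU device would fall back to 64 instead
    default_block_size = 16 if device.type == "cpu" else 64
    block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
    print(block_size)  # 16 here, unless OI_PAGED_ATTN_BLOCK_SIZE is set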
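One practical consequence of reading the variable inside __init__: OI_PAGED_ATTN_BLOCK_SIZE must be set in the environment (exported in the shell, or via os.environ) before the cache object is constructed. Changing it afterwards has no effect on an existing cache, whose block tables stay sized for the old value.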