diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index 6e99e08aaf..7b8ab1cc7f 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -5,6 +5,8 @@ from intel_extension_for_pytorch.llm.modules import PagedAttention from transformers import Cache, PretrainedConfig +from optimum.intel.utils.import_utils import is_ipex_version + class IPEXPagedCache(Cache): """ @@ -43,10 +45,14 @@ def __init__( ) -> None: super().__init__() self.max_batch_size = max_batch_size + self.device = device + self._supports_flash_decoding = ( + is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99") + ) # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device) - default_block_size = 16 if device.type == "cpu" else 64 + default_block_size = 16 self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size))) self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( @@ -70,14 +76,44 @@ def __init__( key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) elif device.type == "xpu": - key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) - value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) + if self._supports_flash_decoding: + key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) + value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) + else: + key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) + value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) for i in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device) new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) + def reshape_and_cache( + self, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, + ): + # TODO: unify API definition between CPU and XPU in IPEX version > 2.6 + if self.device.type == "xpu" and self._supports_flash_decoding: + PagedAttention.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slots, + ) + else: + PagedAttention.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + ) + def update_for_prefill( self, key_states: torch.Tensor, @@ -95,7 +131,7 @@ def update_for_prefill( block_table = self.free_blocks.nonzero().view(-1)[0:nb] self.block_tables[i][0:nb] = block_table self.free_blocks[block_table] = 0 - slots_range = torch.arange(input_lens[i], device=key_states.device) + slots_range = torch.arange(input_lens[i], device=self.device) block_indices = slots_range // self.block_size slot_offsets = slots_range % self.block_size all_block_indices.append(self.block_tables[i][block_indices]) @@ -105,12 +141,8 @@ def update_for_prefill( all_slot_offsets = torch.cat(all_slot_offsets) self.slots = all_block_indices * self.block_size + all_slot_offsets # Update the cache - PagedAttention.reshape_and_cache( - key_states, - value_states, - self.key_cache[layer_idx], - self.value_cache[layer_idx], - self.slots, + self.reshape_and_cache( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) # Update the number of seen tokens @@ -128,7 +160,7 @@ def update_for_decode( if layer_idx == 0: start_block_idx = self._seen_tokens // self.block_size slot_offset_in_block = (self._seen_tokens) % self.block_size - self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) + self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32) for i in range(batch_size): if slot_offset_in_block[i] == 0: # need a new block: @@ -139,12 +171,8 @@ def update_for_decode( self.free_blocks[self.block_tables[i][b_idx]] = 0 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] # Update the cache - PagedAttention.reshape_and_cache( - key_states, - value_states, - self.key_cache[layer_idx], - self.value_cache[layer_idx], - self.slots, + self.reshape_and_cache( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) # Update the number of seen tokens @@ -194,16 +222,15 @@ def get_max_length(self) -> Optional[int]: def reset(self): """Resets the cache values while preserving the objects""" - self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device) + self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device) self.block_tables.fill_(-1) - self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device) + self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device) self.max_seq_len = 0 def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - device = self.block_tables.device origin_table = self.block_tables.clone() - updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) + updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device)) mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) num_blocks = mask.cumsum(-1)[:, -1] updated_table = torch.zeros_like(beam_idx) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 2b440aa91a..3ff807c17e 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -179,8 +179,8 @@ def _llama_model_forward( past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + device = input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) @@ -200,6 +200,9 @@ def _llama_model_forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -235,6 +238,9 @@ def _llama_model_forward( use_cache=use_cache, position_embeddings=position_embeddings, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, ) hidden_states = layer_outputs[0] @@ -303,11 +309,10 @@ def _falcon_model_forward( past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 batch_size, seq_length, _ = inputs_embeds.shape + device = input_ids.device if input_ids is not None else inputs_embeds.device if cache_position is None: - cache_position = torch.arange( - past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device - ) + cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device) if position_ids is None: position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) @@ -323,6 +328,9 @@ def _falcon_model_forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -365,6 +373,9 @@ def _falcon_model_forward( cache_position=cache_position, position_embeddings=position_embeddings, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, ) hidden_states = outputs[0] @@ -459,6 +470,9 @@ def _gpt2_model_forward( hidden_states = self.drop(hidden_states) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -494,6 +508,9 @@ def _gpt2_model_forward( use_cache=use_cache, output_attentions=output_attentions, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, ) hidden_states = outputs[0] @@ -636,6 +653,7 @@ def _qwen2_model_forward( inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = inputs_embeds.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: @@ -660,6 +678,9 @@ def _qwen2_model_forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -695,6 +716,9 @@ def _qwen2_model_forward( cache_position=cache_position, position_embeddings=position_embeddings, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, **kwargs, ) @@ -749,14 +773,26 @@ def postprocess_attention_output(self, attn_output): return attn_output # Maybe removed after torch 2.6 released - def has_flash_attn(self, query): - if query.device.type == "cpu": + def has_flash_attn(self): + if self.module_device.type == "cpu": return is_torch_version(">", "2.4.99") - elif query.device.type == "xpu": + elif self.module_device.type == "xpu": return is_torch_version(">", "2.5.99") def attention_interface( - self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + self, + query, + key_cache, + value_cache, + key, + value, + past_key_value, + attention_mask, + input_lens, + past_len, + seq_len_tensor, + query_len_tensor, + max_input_lens, ): if past_key_value is None: n_rep = query.shape[1] // key.shape[1] @@ -773,20 +809,19 @@ def attention_interface( is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query): + elif self.has_flash_attn(): attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() - query_max_len = input_lens.max() if past_len == 0 else 1 + query_len_tensor = seq_len_tensor if past_len == 0 else query_len_tensor + query_max_len = max_input_lens if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, - key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, - value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + key_cache, + value_cache, query_len_tensor, seq_len_tensor, query_max_len, - input_lens.max(), + max_input_lens, 1.0 / math.sqrt(self.head_dim), True, past_key_value.block_tables, @@ -795,7 +830,6 @@ def attention_interface( elif past_len == 0: # prefill, remove padding attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) varlen_attention( query.contiguous() if query.device.type == "xpu" else query, key.contiguous() if key.device.type == "xpu" else key, @@ -844,6 +878,9 @@ def forward( if past_key_value is None and kwargs.get("layer_past", None) is not None: past_key_value = kwargs.pop("layer_past", None) input_lens = kwargs.pop("input_lens", None) + seq_len_tensor = kwargs.pop("seq_len_tensor", None) + query_len_tensor = kwargs.pop("query_len_tensor", None) + max_input_lens = kwargs.pop("max_input_lens", 0) past_len = 0 if past_key_value is not None: past_len = past_key_value.get_seq_length() @@ -855,7 +892,18 @@ def forward( key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) attn_output = self.attention_interface( - query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + query, + key_cache, + value_cache, + key, + value, + past_key_value, + attention_mask, + input_lens, + past_len, + seq_len_tensor, + query_len_tensor, + max_input_lens, ) attn_output = self.postprocess_attention_output(attn_output) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index eda4f09614..dc15d161c8 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -291,7 +291,7 @@ def test_forward(self, model_arch): dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) self.assertIsInstance(ipex_model.config, PretrainedConfig) - input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long) + input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long).to(DEVICE) outputs = ipex_model(input_ids) self.assertIsInstance(outputs.logits, torch.Tensor)