|
| 1 | +from typing import List, Optional, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | +from intel_extension_for_pytorch.llm.modules import PagedAttention |
| 5 | +from transformers import Cache, PretrainedConfig |
| 6 | + |
| 7 | + |
| 8 | +class IPEXPagedCache(Cache): |
| 9 | + """ |
| 10 | + A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout. |
| 11 | + ipex-xpu: |
| 12 | + ipex-cpu: |
| 13 | +
|
| 14 | + Example: |
| 15 | +
|
| 16 | + ```python |
| 17 | + >>> from transformers import AutoTokenizer |
| 18 | + >>> from optimum.intel import IPEXModelForCausalLM |
| 19 | + >>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache |
| 20 | +
|
| 21 | + >>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True) |
| 22 | + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") |
| 23 | +
|
| 24 | + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") |
| 25 | +
|
| 26 | + >>> # Prepare a cache class and pass it to model's forward |
| 27 | + >>> past_key_values = IPEXPagedCache() |
| 28 | + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) |
| 29 | + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation |
| 30 | + ``` |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + config: PretrainedConfig, |
| 36 | + batch_size: int, |
| 37 | + max_cache_len: int, |
| 38 | + device, |
| 39 | + dtype=None, |
| 40 | + layer_device_map=None, |
| 41 | + **kwargs, |
| 42 | + ) -> None: |
| 43 | + super().__init__() |
| 44 | + self.batch_size = batch_size |
| 45 | + # Used in `generate` to keep tally of how many tokens the cache has seen |
| 46 | + self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) |
| 47 | + self.block_size = 16 |
| 48 | + self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size |
| 49 | + self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( |
| 50 | + batch_size, -1 |
| 51 | + ) |
| 52 | + self.free_blocks = torch.arange(self.num_blocks, device=device) |
| 53 | + self.max_cache_len = max_cache_len |
| 54 | + self.num_kv_heads = config.num_key_value_heads |
| 55 | + self.num_hidden_layers = config.num_hidden_layers |
| 56 | + if hasattr(config, "head_dim"): |
| 57 | + head_size = config.head_dim |
| 58 | + else: |
| 59 | + head_size = config.hidden_size // config.num_attention_heads |
| 60 | + self.head_size = head_size |
| 61 | + self.max_seq_len = 0 |
| 62 | + |
| 63 | + self.key_cache: List[torch.Tensor] = [] |
| 64 | + self.value_cache: List[torch.Tensor] = [] |
| 65 | + |
| 66 | + if device.type == "cpu": |
| 67 | + key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) |
| 68 | + value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) |
| 69 | + elif device.type == "xpu": |
| 70 | + key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) |
| 71 | + value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) |
| 72 | + for i in range(config.num_hidden_layers): |
| 73 | + new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device) |
| 74 | + new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device) |
| 75 | + self.key_cache.append(new_layer_key_cache) |
| 76 | + self.value_cache.append(new_layer_value_cache) |
| 77 | + |
| 78 | + def update_for_prefill( |
| 79 | + self, |
| 80 | + key_states: torch.Tensor, |
| 81 | + value_states: torch.Tensor, |
| 82 | + layer_idx: int, |
| 83 | + batch_size: int, |
| 84 | + input_lens: torch.Tensor, |
| 85 | + ): |
| 86 | + if layer_idx == 0: |
| 87 | + all_block_indices = [] |
| 88 | + all_slot_offsets = [] |
| 89 | + num_blocks = (input_lens + self.block_size - 1) // self.block_size |
| 90 | + for i in range(batch_size): |
| 91 | + for b_idx in range(num_blocks[i]): |
| 92 | + if self.block_tables[i][b_idx] == -1: |
| 93 | + # need a free block |
| 94 | + self.block_tables[i][b_idx] = self.free_blocks[0] |
| 95 | + self.free_blocks = self.free_blocks[1:] |
| 96 | + |
| 97 | + slots_range = torch.arange(input_lens[i], device=key_states.device) |
| 98 | + block_indices = slots_range // self.block_size |
| 99 | + slot_offsets = slots_range % self.block_size |
| 100 | + all_block_indices.append(self.block_tables[i][block_indices]) |
| 101 | + all_slot_offsets.append(slot_offsets) |
| 102 | + |
| 103 | + all_block_indices = torch.cat(all_block_indices) |
| 104 | + all_slot_offsets = torch.cat(all_slot_offsets) |
| 105 | + self.slots = all_block_indices * self.block_size + all_slot_offsets |
| 106 | + |
| 107 | + # Update the cache |
| 108 | + PagedAttention.reshape_and_cache( |
| 109 | + key_states, |
| 110 | + value_states, |
| 111 | + self.key_cache[layer_idx], |
| 112 | + self.value_cache[layer_idx], |
| 113 | + self.slots, |
| 114 | + ) |
| 115 | + |
| 116 | + # Update the number of seen tokens |
| 117 | + if layer_idx == self.num_hidden_layers - 1: |
| 118 | + self._seen_tokens = self._seen_tokens + input_lens |
| 119 | + self.max_seq_len, _ = self._seen_tokens.max(dim=0) |
| 120 | + |
| 121 | + def update_for_decode( |
| 122 | + self, |
| 123 | + key_states: torch.Tensor, |
| 124 | + value_states: torch.Tensor, |
| 125 | + layer_idx: int, |
| 126 | + batch_size: int, |
| 127 | + ): |
| 128 | + if layer_idx == 0: |
| 129 | + start_block_idx = self._seen_tokens // self.block_size |
| 130 | + num_blocks = (self._seen_tokens + self.block_size) // self.block_size |
| 131 | + slot_offset_in_block = (self._seen_tokens) % self.block_size |
| 132 | + self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) |
| 133 | + for i in range(batch_size): |
| 134 | + for b_idx in range(start_block_idx[i], num_blocks[i]): |
| 135 | + if self.block_tables[i][b_idx] == -1: |
| 136 | + # need a free block |
| 137 | + self.block_tables[i][b_idx] = self.free_blocks[0] |
| 138 | + self.free_blocks = self.free_blocks[1:] |
| 139 | + |
| 140 | + self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] |
| 141 | + # Update the cache |
| 142 | + PagedAttention.reshape_and_cache( |
| 143 | + key_states, |
| 144 | + value_states, |
| 145 | + self.key_cache[layer_idx], |
| 146 | + self.value_cache[layer_idx], |
| 147 | + self.slots, |
| 148 | + ) |
| 149 | + |
| 150 | + # Update the number of seen tokens |
| 151 | + if layer_idx == self.num_hidden_layers - 1: |
| 152 | + self._seen_tokens = self._seen_tokens + 1 |
| 153 | + self.max_seq_len = self.max_seq_len + 1 |
| 154 | + |
| 155 | + def update( |
| 156 | + self, |
| 157 | + key_states: torch.Tensor, |
| 158 | + value_states: torch.Tensor, |
| 159 | + layer_idx: int, |
| 160 | + attention_mask: torch.Tensor, |
| 161 | + input_lens: torch.Tensor, |
| 162 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 163 | + """ |
| 164 | + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
| 165 | +
|
| 166 | + Parameters: |
| 167 | + key_states (`torch.Tensor`): |
| 168 | + The new key states to cache. |
| 169 | + value_states (`torch.Tensor`): |
| 170 | + The new value states to cache. |
| 171 | + layer_idx (`int`): |
| 172 | + The index of the layer to cache the states for. |
| 173 | + Return: |
| 174 | + A tuple containing the updated key and value states. |
| 175 | + """ |
| 176 | + |
| 177 | + batch_size = input_lens.shape[-1] |
| 178 | + if self.get_seq_length() == 0: |
| 179 | + # prefill |
| 180 | + self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens) |
| 181 | + else: |
| 182 | + # decode |
| 183 | + self.update_for_decode(key_states, value_states, layer_idx, batch_size) |
| 184 | + |
| 185 | + return self.key_cache[layer_idx], self.value_cache[layer_idx] |
| 186 | + |
| 187 | + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| 188 | + """Returns the sequence length of the cached states that were seen by the model.""" |
| 189 | + return self.max_seq_len |
| 190 | + |
| 191 | + def get_max_length(self) -> Optional[int]: |
| 192 | + """Returns the maximum sequence length of the cached states.""" |
| 193 | + return self.max_cache_len |
| 194 | + |
| 195 | + def reset(self): |
| 196 | + """Resets the cache values while preserving the objects""" |
| 197 | + self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device) |
| 198 | + self.block_tables.fill_(-1) |
| 199 | + self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device) |
| 200 | + self.max_seq_len = 0 |
| 201 | + |
| 202 | + def reorder_cache(self, beam_idx: torch.LongTensor): |
| 203 | + """Reorders the cache for beam search, given the selected beam indices.""" |
| 204 | + device = self.block_tables.device |
| 205 | + origin_table = self.block_tables.clone() |
| 206 | + updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) |
| 207 | + mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) |
| 208 | + num_blocks = mask.cumsum(-1)[:, -1] |
| 209 | + updated_table = [] |
| 210 | + for i in range(beam_idx.shape[0]): |
| 211 | + self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1] |
| 212 | + updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]]) |
| 213 | + updated_table = torch.cat(tuple(updated_table), dim=0) |
| 214 | + for layer_idx in range(self.num_hidden_layers): |
| 215 | + self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]] |
| 216 | + self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]] |
| 217 | + free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) |
| 218 | + self.free_blocks = torch.cat((self.free_blocks, free_table)) |
| 219 | + |
| 220 | + def crop(self, maximum_length: int): |
| 221 | + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be |
| 222 | + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" |
| 223 | + |
| 224 | + max_seq_len = self.get_seq_length() |
| 225 | + if maximum_length < 0: |
| 226 | + maximum_length = max_seq_len - abs(maximum_length) |
| 227 | + |
| 228 | + if max_seq_len <= maximum_length: |
| 229 | + return |
| 230 | + origin_table = self.block_tables.clone() |
| 231 | + for bs in range(self._seen_tokens.shape[0]): |
| 232 | + new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len |
| 233 | + num_blocks = (new_tokens + self.block_size - 1) // self.block_size |
| 234 | + self.block_tables[bs, num_blocks:] = -1 |
| 235 | + self._seen_tokens[bs] = new_tokens |
| 236 | + self.max_seq_len, _ = self._seen_tokens.max(dim=0) |
| 237 | + free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) |
| 238 | + self.free_blocks = torch.cat((self.free_blocks, free_table)) |
0 commit comments