
Commit f82fbcd

Merge branch 'huggingface:main' into phi

2 parents 3a5b04e + 3ff8dc1
File tree

14 files changed (+219, -786 lines)

.github/workflows/test_generation.yml (-36)

This file was deleted.

.github/workflows/test_ipex.yml (+2, -2)

@@ -18,8 +18,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        transformers-version: ["4.47.0", "4.47.1"]
-        torch-version: ["2.4.0", "2.5.*"]
+        transformers-version: ["4.47.*"]
+        torch-version: ["2.6.0"]
 
     runs-on: ubuntu-22.04

examples/neural_compressor/text-generation/README.md (-2)

@@ -18,8 +18,6 @@ limitations under the License.
 
 Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).
 
-The original generation task only supported the PyTorch eager model. By calling the `TSModelForCausalLM` class, we can now support a TorchScript model for generation tasks.
-
 This example also allows us to apply different quantization approaches (such as dynamic, static, The example applies post-training static quantization on a gptj model).
 
 Example usage:

optimum/exporters/ipex/cache_utils.py (+48, -21)
@@ -5,6 +5,8 @@
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
 
+from optimum.intel.utils.import_utils import is_ipex_version
+
 
 class IPEXPagedCache(Cache):
     """
@@ -43,10 +45,14 @@ def __init__(
     ) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
+        self.device = device
+        self._supports_flash_decoding = (
+            is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
+        )
         # Used in `generate` to keep tally of how many tokens the cache has seen
 
         self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
-        default_block_size = 16 if device.type == "cpu" else 64
+        default_block_size = 16
         self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
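A quick aside on the block-count arithmetic in this hunk (an illustrative, standalone sketch, not part of the commit; the concrete numbers are made up, only the `OI_PAGED_ATTN_BLOCK_SIZE` override and the ceil-division formula mirror the code above):

import os

# Illustrative values; a real run takes these from the model config and generation kwargs.
max_cache_len = 50
max_batch_size = 4

# Same default and environment override as the hunk above (now 16 on both CPU and XPU).
default_block_size = 16
block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))

# Ceil division: 50 cache slots at block size 16 -> 4 blocks per sequence.
blocks_per_sequence = max_cache_len // block_size + (max_cache_len % block_size != 0)
num_blocks = blocks_per_sequence * max_batch_size
print(block_size, blocks_per_sequence, num_blocks)  # 16 4 16 with the defaults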
@@ -70,14 +76,44 @@ def __init__(
             key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
             value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
         elif device.type == "xpu":
-            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
-            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+            if self._supports_flash_decoding:
+                key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+                value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+            else:
+                key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+                value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
         for i in range(config.num_hidden_layers):
             new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
             new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
+    def reshape_and_cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slots: torch.Tensor,
+    ):
+        # TODO: unify API definition between CPU and XPU in IPEX version > 2.6
+        if self.device.type == "xpu" and self._supports_flash_decoding:
+            PagedAttention.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+        else:
+            PagedAttention.reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+
     def update_for_prefill(
         self,
         key_states: torch.Tensor,
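For orientation, a rough, hypothetical sketch of the layout choice this hunk encodes: when the IPEX build supports flash decoding, the XPU paged cache uses a (num_blocks, block_size, num_kv_heads, head_size) layout, otherwise the older transposed layouts. The helper function below is not from the commit; only the shapes and the `_supports_flash_decoding` idea mirror it.

import torch

def paged_cache_shapes(num_blocks, block_size, num_kv_heads, head_size, supports_flash_decoding):
    # Mirrors the branch added in __init__ above for device.type == "xpu".
    if supports_flash_decoding:
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return shape, shape
    key_shape = (num_blocks, num_kv_heads, head_size, block_size, 1)
    value_shape = (num_blocks, num_kv_heads, head_size, block_size)
    return key_shape, value_shape

# Example: 16 blocks of 16 slots, 8 KV heads, head size 64.
k_shape, v_shape = paged_cache_shapes(16, 16, 8, 64, supports_flash_decoding=True)
key_cache, value_cache = torch.zeros(k_shape), torch.zeros(v_shape)

The new `reshape_and_cache` wrapper then hides the matching kernel choice (`reshape_and_cache_flash` vs. `reshape_and_cache`) from the callers, which is why both the prefill and decode paths in the later hunks collapse to a single `self.reshape_and_cache(...)` call.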
@@ -95,7 +131,7 @@ def update_for_prefill(
                 block_table = self.free_blocks.nonzero().view(-1)[0:nb]
                 self.block_tables[i][0:nb] = block_table
                 self.free_blocks[block_table] = 0
-                slots_range = torch.arange(input_lens[i], device=key_states.device)
+                slots_range = torch.arange(input_lens[i], device=self.device)
                 block_indices = slots_range // self.block_size
                 slot_offsets = slots_range % self.block_size
                 all_block_indices.append(self.block_tables[i][block_indices])
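To make the prefill slot computation above concrete, a small self-contained sketch of the same index math (the block table and prompt length are made-up values; only the `// block_size` and `% block_size` arithmetic comes from the hunk):

import torch

block_size = 16
# Hypothetical block table for one sequence: physical blocks 3 and 7 were allocated to it.
block_table = torch.tensor([3, 7], dtype=torch.int32)
input_len = 20  # prompt length for this sequence

slots_range = torch.arange(input_len)
block_indices = slots_range // block_size  # logical block each token falls into
slot_offsets = slots_range % block_size    # position of the token inside that block

# Physical slot = physical block id * block_size + offset inside the block.
slots = block_table[block_indices] * block_size + slot_offsets
print(slots)  # token 0 -> slot 48 (block 3), token 16 -> slot 112 (block 7), ...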
@@ -105,12 +141,8 @@ def update_for_prefill(
             all_slot_offsets = torch.cat(all_slot_offsets)
             self.slots = all_block_indices * self.block_size + all_slot_offsets
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
@@ -128,7 +160,7 @@ def update_for_decode(
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
+            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
             for i in range(batch_size):
                 if slot_offset_in_block[i] == 0:
                     # need a new block:
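The allocation rule in this hunk can be stated in one line: during decode, a sequence claims a fresh block only when its next slot offset wraps to 0, i.e. the previous block just filled up. A tiny standalone check of that rule (the block size and helper below are illustrative, not from the commit):

def needs_new_block(seen_tokens, block_size=16):
    # The next token lands at offset seen_tokens % block_size;
    # offset 0 means the current block is full and a new one must come from the free list.
    return seen_tokens % block_size == 0

assert needs_new_block(0)       # very first token of the sequence
assert not needs_new_block(5)   # current block still has room
assert needs_new_block(16)      # a block of 16 slots just filled up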
@@ -139,12 +171,8 @@ def update_for_decode(
                     self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
@@ -194,16 +222,15 @@ def get_max_length(self) -> Optional[int]:
 
     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device)
+        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
         self.max_seq_len = 0
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.block_tables.device
         origin_table = self.block_tables.clone()
-        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
+        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
         updated_table = torch.zeros_like(beam_idx)
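For readers unfamiliar with `reorder_cache`, a minimal standalone illustration (made-up values, not the commit's code) of what `index_select` on the block table does during beam search: each beam's row of block ids is replaced by the row of the beam it was forked from; the remainder of the method, not shown in this hunk, then updates the cached blocks to match.

import torch

# Hypothetical block tables for 3 beams (one row per beam, -1 = unallocated).
block_tables = torch.tensor([[0, 1, -1],
                             [2, 3, -1],
                             [4, 5, -1]], dtype=torch.int32)

# Beam search decided the surviving beams descend from beams 1, 1 and 0.
beam_idx = torch.tensor([1, 1, 0])

updated = block_tables.index_select(0, beam_idx)
print(updated)  # rows: [2, 3, -1], [2, 3, -1], [0, 1, -1]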
