enable qwen2 model #1107

Merged 41 commits on Feb 11, 2025
Changes from 1 commit
Commits (41)
6d21075  use real varlen attn (jiqing-feng, Dec 11, 2024)
b792875  optimize gpt2 by using linear instead of conv1D (jiqing-feng, Dec 12, 2024)
422134f  Merge branch 'huggingface:main' into varlen (jiqing-feng, Dec 12, 2024)
36884cb  fix usage without pkv (jiqing-feng, Dec 12, 2024)
d061e69  use sdpa for no cache forward (jiqing-feng, Dec 12, 2024)
31c635a  fix format (jiqing-feng, Dec 12, 2024)
73a5ef7  fix sdpa (jiqing-feng, Dec 12, 2024)
f9c021b  revert shape for sdpa (jiqing-feng, Dec 12, 2024)
d069407  fix sdpa precision, still have error (jiqing-feng, Dec 12, 2024)
2c54045  fix sdpa shape (jiqing-feng, Dec 13, 2024)
bce9aa9  upgrade minimum torch version to 2.5 (jiqing-feng, Dec 13, 2024)
72ac9e6  rm pdb (jiqing-feng, Dec 13, 2024)
3fdb3a5  fix non patch path (jiqing-feng, Dec 16, 2024)
7e20b86  Merge branch 'main' into varlen (jiqing-feng, Dec 18, 2024)
c1bd7f7  Merge branch 'huggingface:main' into varlen (jiqing-feng, Dec 25, 2024)
fb71c2e  Merge branch 'huggingface:main' into varlen (jiqing-feng, Jan 13, 2025)
6186aaf  use varlen if flash attn not available (jiqing-feng, Jan 14, 2025)
cbc232b  revert ipex version change (jiqing-feng, Jan 14, 2025)
4dd2e44  fix flash attn check (jiqing-feng, Jan 14, 2025)
372d3f8  prefill attn (jiqing-feng, Jan 14, 2025)
daddabf  fix cache (jiqing-feng, Jan 14, 2025)
8e8c95f  qwen2 model forward (jiqing-feng, Jan 14, 2025)
95b7043  refactor attention (jiqing-feng, Jan 14, 2025)
71aa6b0  use flash attn for decode (jiqing-feng, Jan 14, 2025)
9211803  fix dtype (jiqing-feng, Jan 14, 2025)
333bd86  Merge branch 'varlen' into qwen (jiqing-feng, Jan 14, 2025)
d3fbd65  enable qwen2 model (jiqing-feng, Jan 14, 2025)
06798e2  enable qwen2 test (jiqing-feng, Jan 14, 2025)
12dd802  set default block size (jiqing-feng, Jan 15, 2025)
c6d2d0f  decoding use single query (jiqing-feng, Jan 15, 2025)
00e6bf3  rebase (jiqing-feng, Jan 15, 2025)
acfd0ce  fix position_id init for qwen2 (jiqing-feng, Jan 15, 2025)
ccbe97a  add patched qwen2 test (jiqing-feng, Jan 15, 2025)
ee7dd81  fix format (jiqing-feng, Jan 15, 2025)
c86fd1c  fix pipeline test (jiqing-feng, Jan 15, 2025)
5b93036  set block size as an env parameter (jiqing-feng, Jan 16, 2025)
31accd2  set different default value for block size based on device (jiqing-feng, Jan 16, 2025)
e75b45b  Merge branch 'block_size' into qwen (jiqing-feng, Jan 16, 2025)
8656c26  Merge branch 'huggingface:main' into qwen (jiqing-feng, Jan 17, 2025)
4ddc352  Merge branch 'huggingface:main' into qwen (jiqing-feng, Jan 22, 2025)
59f381c  change new prompt (jiqing-feng, Feb 7, 2025)
optimize gpt2 by using linear instead of conv1D
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Dec 12, 2024
commit b792875be1e2ae97275afc0fe53f28f6b202190d
optimum/exporters/ipex/cache_utils.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 16
+        self.block_size = 64
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
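For context on the sizing in this hunk: the paged KV cache rounds max_cache_len up to whole blocks per sequence and reserves that many blocks for every sequence in the batch, so raising block_size from 16 to 64 means fewer, larger blocks. A minimal standalone sketch of that arithmetic (illustrative only; the helper name is mine, not from the PR, but the formula mirrors the diff):

import torch

def init_block_tables(max_cache_len: int, batch_size: int, block_size: int = 64):
    # Round the per-sequence cache length up to a whole number of blocks.
    blocks_per_seq = max_cache_len // block_size + (max_cache_len % block_size != 0)
    num_blocks = blocks_per_seq * batch_size
    # The table starts filled with -1, i.e. no physical block assigned yet.
    return -1 * torch.ones([num_blocks], dtype=torch.int32).reshape(batch_size, -1)

print(init_block_tables(max_cache_len=100, batch_size=2).shape)  # torch.Size([2, 2])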
optimum/exporters/ipex/modeling_utils.py (26 changes: 8 additions & 18 deletions)
@@ -614,22 +614,6 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            # varlen_attention(
-            # query.contiguous() if query.device.type == "xpu" else query,
-            # key.contiguous() if key.device.type == "xpu" else key,
-            # value.contiguous() if value.device.type == "xpu" else value,
-            # attn_output,
-            # seq_len_tensor,
-            # seq_len_tensor,
-            # input_lens.max(),
-            # input_lens.max(),
-            # 0.0,
-            # 1.0 / math.sqrt(self.head_dim),
-            # False,
-            # True,
-            # False,
-            # None,
-            # )
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
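A note on the prefill path above: in a varlen/paged layout there is no padding, so all prompts in the batch are packed back to back and the kernel is told where each sequence starts and ends through the cumulative-length tensor that seq_len_tensor builds. A small standalone illustration of that boundary tensor (not code from the PR):

import torch

# Per-sequence prompt lengths of a packed (padding-free) batch.
input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Same construction as seq_len_tensor in the hunk above: sequence i occupies
# rows cu_seqlens[i] : cu_seqlens[i + 1] of the packed query/key/value tensors.
cu_seqlens = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(cu_seqlens)        # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(input_lens.max())  # the maximum sequence length also passed to the kernel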
@@ -734,9 +718,16 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
+        _setattr_from_module(self, module)
+        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+        self.c_attn_linear.bias = self.c_attn.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias

     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
@@ -748,7 +739,6 @@ def rope(self, query, key, *args, **kwargs):
     def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
         return attn_output
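The GPT-2 change above replaces transformers' Conv1D projections with equivalent nn.Linear modules, presumably so that IPEX's linear-based kernels and fusions can pick them up. Conv1D stores its weight as (in_features, out_features), i.e. transposed relative to nn.Linear, so the conversion is just a weight transpose plus reusing the bias. A standalone equivalence check (illustrative only; assumes transformers is installed, with Conv1D imported from transformers.pytorch_utils):

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

torch.manual_seed(0)
hidden_size, out_size = 8, 24
conv = Conv1D(out_size, hidden_size)  # GPT-2 style projection, weight shape (hidden_size, out_size)

# Same construction as in the diff: transpose the weight, reuse the bias.
linear = nn.Linear(conv.weight.shape[0], conv.weight.shape[1])
linear.weight = nn.Parameter(conv.weight.t())
linear.bias = conv.bias

x = torch.randn(4, hidden_size)
assert torch.allclose(conv(x), linear(x), atol=1e-6)  # both modules compute the same projection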