Commit b792875

optimize gpt2 by using linear instead of conv1D
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 6d21075 commit b792875

File tree: 2 files changed (+9, -19 lines)

optimum/exporters/ipex/cache_utils.py (+1, -1)
@@ -44,7 +44,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 16
+        self.block_size = 64
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
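
For context, num_blocks above is a per-sequence ceiling division of max_cache_len by block_size, scaled by batch_size; a small sketch with made-up sizes (not from the diff) shows the effect of moving from a block size of 16 to 64:

# Illustrative only: the num_blocks formula from the hunk above, evaluated
# with made-up sizes to show how the larger block size shrinks the block table.
max_cache_len, batch_size = 1000, 4

for block_size in (16, 64):
    num_blocks = (max_cache_len // block_size + (max_cache_len % block_size != 0)) * batch_size
    print(block_size, num_blocks)  # block_size 16 -> 252 blocks, block_size 64 -> 64 blocks
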

optimum/exporters/ipex/modeling_utils.py (+8, -18)
@@ -614,22 +614,6 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            # varlen_attention(
-            #     query.contiguous() if query.device.type == "xpu" else query,
-            #     key.contiguous() if key.device.type == "xpu" else key,
-            #     value.contiguous() if value.device.type == "xpu" else value,
-            #     attn_output,
-            #     seq_len_tensor,
-            #     seq_len_tensor,
-            #     input_lens.max(),
-            #     input_lens.max(),
-            #     0.0,
-            #     1.0 / math.sqrt(self.head_dim),
-            #     False,
-            #     True,
-            #     False,
-            #     None,
-            # )
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
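
For context, the seq_len_tensor kept in this prefill path is the cumulative-sequence-length (cu_seqlens) vector that variable-length attention kernels consume once padding has been removed; a minimal sketch with made-up prompt lengths:

# Illustrative only: builds the cumulative sequence-length tensor exactly as the
# prefill path above does, using made-up per-sequence prompt lengths.
import torch

input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)  # [0, 3, 8, 10]: cumulative lengths
# tokens of sequence i occupy rows seq_len_tensor[i]:seq_len_tensor[i+1] of the packed query/key/value
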
@@ -734,9 +718,16 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
+        _setattr_from_module(self, module)
+        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+        self.c_attn_linear.bias = self.c_attn.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias

     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
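
The __init__ additions above mirror GPT-2's Conv1D weights into nn.Linear modules, which store the weight transposed relative to Conv1D; a standalone sketch (sizes and variable names illustrative, not part of the diff) of why the transposed copy computes the same result while dispatching to the regular Linear path:

# Illustrative only: GPT-2's Conv1D keeps its weight as (in_features, out_features),
# i.e. transposed relative to nn.Linear, so an equivalent Linear is built by
# transposing the weight and reusing the bias.
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

hidden, out = 768, 3 * 768               # e.g. c_attn maps hidden -> concatenated q, k, v
conv = Conv1D(out, hidden)               # weight shape: (hidden, out)

linear = nn.Linear(hidden, out)          # weight shape: (out, hidden)
linear.weight = nn.Parameter(conv.weight.t())
linear.bias = conv.bias

x = torch.randn(4, hidden)
torch.testing.assert_close(conv(x), linear(x))  # numerically equivalent outputs
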
@@ -748,7 +739,6 @@ def rope(self, query, key, *args, **kwargs):
     def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
         return attn_output
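
The resid_dropout call removed above is a no-op at inference time anyway, since nn.Dropout is the identity in eval mode; a quick illustrative check:

# Illustrative only: dropout layers do nothing outside of training mode, so removing
# the call from an inference-only path leaves the outputs unchanged.
import torch
from torch import nn

resid_dropout = nn.Dropout(p=0.1)
resid_dropout.eval()                     # inference mode, as in the exported model
x = torch.randn(2, 8)
assert torch.equal(resid_dropout(x), x)  # identity: nothing is zeroed or rescaled
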