
Commit 8dacb0a

optimize the performance
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 850195e commit 8dacb0a

1 file changed: +50 -11 lines changed
optimum/exporters/ipex/modeling_utils.py

+50 -11
@@ -179,8 +179,8 @@ def _llama_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
@@ -200,6 +200,9 @@ def _llama_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -235,6 +238,9 @@ def _llama_model_forward(
             use_cache=use_cache,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = layer_outputs[0]
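
The three tensors added to the llama forward above are the varlen-attention metadata that used to be rebuilt further down, inside every decoder layer's attention call: seq_len_tensor is the cumulative-sequence-length vector over the packed batch, query_len_tensor is its one-query-token-per-sequence counterpart used at decode time, and max_input_lens is the longest sequence in the batch. A minimal sketch of what they contain, using a toy right-padded batch and plain PyTorch (values are illustrative only, not taken from the commit):

import torch

# Toy batch: three right-padded sequences with 5, 3 and 4 valid tokens.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
    ]
)

# The same expressions the patched forwards now evaluate once per model call:
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)                          # tensor([5, 3, 4])
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))  # tensor([0, 5, 8, 12])
query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()                         # tensor([0, 1, 2, 3])
max_input_lens = input_lens.max().item()                                               # 5

print(input_lens, seq_len_tensor, query_len_tensor, max_input_lens)

Note that input_lens.max().item() is a blocking device-to-host copy on GPU/XPU, so evaluating it once per forward pass instead of once per layer is presumably a large part of the "optimize the performance" intent.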
@@ -303,10 +309,11 @@ def _falcon_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     batch_size, seq_length, _ = inputs_embeds.shape
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     if cache_position is None:
         cache_position = torch.arange(
-            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+            past_key_values_length, past_key_values_length + seq_length, device=device
         )
 
     if position_ids is None:
@@ -323,6 +330,9 @@ def _falcon_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -365,6 +375,9 @@ def _falcon_model_forward(
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -459,6 +472,9 @@ def _gpt2_model_forward(
     hidden_states = self.drop(hidden_states)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -494,6 +510,9 @@ def _gpt2_model_forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -635,7 +654,19 @@ def has_flash_attn(self):
         return is_torch_version(">", "2.5.99")
 
     def attention_interface(
-        self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+        self,
+        query,
+        key_cache,
+        value_cache,
+        key,
+        value,
+        past_key_value,
+        attention_mask,
+        input_lens,
+        past_len,
+        seq_len_tensor,
+        query_len_tensor,
+        max_input_lens,
     ):
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
@@ -654,18 +685,13 @@ def attention_interface(
             self.use_sdpa = True
         elif self.has_flash_attn():
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = (
-                seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int()
-            )
-            max_input_lens = input_lens.max().item()
             query_max_len = max_input_lens if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache,
                 value_cache,
-                query_len_tensor,
+                seq_len_tensor if past_len == 0 else query_len_tensor,
                 seq_len_tensor,
                 query_max_len,
                 max_input_lens,
@@ -677,7 +703,6 @@ def attention_interface(
         elif past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
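
The -/+ pair inside the flash_attn_varlen_func call above is the behavior-preserving half of the change: previously query_len_tensor was built conditionally inside every layer and passed directly, whereas now both tensors arrive precomputed and the prefill/decode choice moves to the call site. A quick sanity sketch (toy values, plain PyTorch, no IPEX/PagedAttention call) showing the argument that reaches the kernel is the same either way:

import torch

input_lens = torch.tensor([5, 3, 4], dtype=torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()

for past_len in (0, 7):  # 0 -> prefill, > 0 -> decode
    # removed per-layer logic
    old_arg = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
    # new logic: tensors precomputed once in the model forward, selected here
    new_arg = seq_len_tensor if past_len == 0 else query_len_tensor
    assert torch.equal(old_arg, new_arg)

The varlen_attention prefill path gets the same treatment: it now reuses the precomputed seq_len_tensor instead of rebuilding it per layer.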
@@ -726,6 +751,9 @@ def forward(
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
             past_key_value = kwargs.pop("layer_past", None)
         input_lens = kwargs.pop("input_lens", None)
+        seq_len_tensor = kwargs.pop("seq_len_tensor", None)
+        query_len_tensor = kwargs.pop("query_len_tensor", None)
+        max_input_lens = kwargs.pop("max_input_lens", 0)
         past_len = 0
         if past_key_value is not None:
             past_len = past_key_value.get_seq_length()
@@ -737,7 +765,18 @@ def forward(
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
 
         attn_output = self.attention_interface(
-            query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+            query,
+            key_cache,
+            value_cache,
+            key,
+            value,
+            past_key_value,
+            attention_mask,
+            input_lens,
+            past_len,
+            seq_len_tensor,
+            query_len_tensor,
+            max_input_lens,
         )
 
         attn_output = self.postprocess_attention_output(attn_output)
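
End to end, the data flow this commit sets up is: the model forward computes the metadata once, passes it to each decoder layer as keyword arguments, and the attention module's forward pops it from kwargs (with None/0 defaults) before handing it to attention_interface. A stripped-down sketch of that threading pattern, with made-up module names and shapes (not the real optimum-intel classes):

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, hidden_states, **kwargs):
        # mirrors the kwargs.pop(...) pattern added in forward() above
        seq_len_tensor = kwargs.pop("seq_len_tensor", None)
        query_len_tensor = kwargs.pop("query_len_tensor", None)
        max_input_lens = kwargs.pop("max_input_lens", 0)
        # a real implementation would call attention_interface(...) with these values
        return hidden_states

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(self, hidden_states, **kwargs):
        return self.attn(hidden_states, **kwargs)

class ToyModel(nn.Module):
    def __init__(self, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer() for _ in range(num_layers))

    def forward(self, hidden_states, attention_mask):
        # computed once per forward, then reused by every layer
        input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
        query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=hidden_states.device).int()
        max_input_lens = input_lens.max().item()
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                input_lens=input_lens,
                max_input_lens=max_input_lens,
                seq_len_tensor=seq_len_tensor,
                query_len_tensor=query_len_tensor,
            )
        return hidden_states

out = ToyModel()(torch.randn(2, 4, 8), torch.ones(2, 4, dtype=torch.int64))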
