 # limitations under the License.
 
 import logging as log
+import math
 import types
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
@@ -327,9 +328,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -611,3 +612,132 @@ def __init__(
         # model has first inference buffers initialization
         if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
+
+
+def _mpt_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    position_bias: torch.Tensor,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+):
+    batch_size, seq_length = hidden_states.shape[:2]
+
+    mixed_qkv = self.Wqkv(hidden_states)
+    query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
+    query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        past_key_value = (key_states, value_states)
+    else:
+        past_key_value = (key_states, value_states)
+
+    attention_mask_sdpa = torch.ones(attention_mask.shape, dtype=query_states.dtype)
+    attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min)
+    context_states = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask_sdpa,
+        dropout_p=self.attn_dropout_p,
+        scale=self.softmax_scale,
+    )
+    context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
+    attn_output = self.out_proj(context_states)
+
+    return attn_output, None, past_key_value
+
+
+class MPTModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if is_torch_version(">=", "2.1.0"):
+            for block in self._model.transformer.blocks:
+                block.attn._orig_forward = block.attn.forward
+                block.attn.forward = types.MethodType(_mpt_attention_forward, block.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.transformer.blocks:
+            if hasattr(block.attn, "_orig_forward"):
+                block.attn.forward = block.attn._orig_forward
+
+
+def _internlm_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+    from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv_states = self.wqkv(hidden_states)
+
+    qkv_states = qkv_states.reshape(
+        qkv_states.shape[0], qkv_states.shape[1], -1, 2 + self.num_key_value_groups, self.head_dim
+    )
+    query_states = qkv_states[..., : self.num_key_value_groups, :]
+    query_states = query_states.reshape(query_states.shape[0], query_states.shape[1], -1, query_states.shape[-1])
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.wo(attn_output)
+
+    attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+class InternLMPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if is_torch_version(">=", "2.1.0"):
+            for block in self._model.model.layers:
+                block.attention._orig_forward = block.attention.forward
+                block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            if hasattr(block.attention, "_orig_forward"):
+                block.attention.forward = block.attention._orig_forward
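
Not part of the diff: below is a minimal, self-contained sketch of the patching pattern both new patchers follow, i.e. swapping a submodule's `forward` for an export-friendly implementation via `types.MethodType` in `__enter__` and restoring the original in `__exit__`. The `ToyBlock`, `ToyPatcher`, and `_patched_forward` names are illustrative assumptions, not optimum-intel API.

```python
import types

import torch


def _patched_forward(self, x):
    # Stand-in for an export-friendly forward such as _mpt_attention_forward:
    # it is bound to the module instance, so `self` resolves to the patched block.
    return torch.nn.functional.relu(self.linear(x))


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class ToyPatcher:
    """Context manager mirroring the MPTModelPatcher / InternLMPatcher pattern."""

    def __init__(self, model):
        self._model = model

    def __enter__(self):
        # Keep a handle to the original forward so it can be restored on exit.
        self._model._orig_forward = self._model.forward
        self._model.forward = types.MethodType(_patched_forward, self._model)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if hasattr(self._model, "_orig_forward"):
            self._model.forward = self._model._orig_forward


block = ToyBlock()
x = torch.randn(1, 4)
with ToyPatcher(block):
    y_patched = block(x)  # the instance-level (patched) forward is called here
y_restored = block(x)  # the original forward is back after the context exits
```

Because the swap is undone in `__exit__`, the patch stays scoped to the export pass and the model behaves normally afterwards, which is why both patchers also guard the restore with `hasattr(..., "_orig_forward")`.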