[Feature] Support transformers==4.48 (#985)

* Update requirements
* Support internlm3, llama, mistral, mixtral, qwen2 and qwen2moe in transformers==4.48
Showing 8 changed files with 446 additions and 1,290 deletions.
```diff
@@ -1,3 +1,2 @@
-# Minimum 0.12.3, see https://github.com/microsoft/DeepSpeed/pull/4587
-deepspeed>=0.12.3
+deepspeed==0.16.2
 mpi4py-mpich
```
```diff
@@ -1,27 +1,18 @@
-# Minimum 0.40.0.post4 to fix some 4-bit precision bugs
-bitsandbytes>=0.40.0.post4
-# Minimum 2.16.0 to fix some bugs, see https://github.com/huggingface/datasets/pull/6444
-datasets>=2.16.0
+bitsandbytes==0.45.0
+datasets>=3.2.0
 einops
 # Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44
 lagent>=0.1.2
-# Minimum 0.10.3 to support distributed evaluation for MMBench
-# see https://github.com/open-mmlab/mmengine/pull/1469
-mmengine>=0.10.3
+mmengine==0.10.6
 openpyxl
-# Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476
-peft>=0.4.0
+peft>=0.14.0
 scikit-image
 scipy
 SentencePiece
 tiktoken
 torch
 torchvision
-# Minimum 4.36.0 to support `Cache` data structure used by KV Cache
-# Registering a causal mask in `LlamaModel` is not friendly for very large
-# `max_position_embeddings`. Refer to
-# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
-# transformers >= 4.43.0 use _flash_attention_forward but not self._flash_attention_forward
-# to calculate attn output which lead to bc braeking
-transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4
+transformers==4.48.0
 transformers_stream_generator
```
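Several packages are now pinned to exact versions, so a quick environment check can catch mismatches before training. The sketch below is illustrative and not part of this commit; it only assumes the pins shown above.

```python
# Illustrative check, not part of this commit: verify that installed packages
# match the versions pinned in the updated requirements files.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = {
    'deepspeed': '0.16.2',
    'bitsandbytes': '0.45.0',
    'mmengine': '0.10.6',
    'transformers': '4.48.0',
}

for name, expected in EXPECTED.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f'{name}: not installed (expected {expected})')
        continue
    if installed != expected:
        print(f'{name}: expected {expected}, found {installed}')
```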
@@ -0,0 +1,132 @@
```python
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
from mmengine import MessageHub
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb,
                                                      eager_attention_forward,
                                                      repeat_kv)
from transformers.processing_utils import Unpack

from xtuner.parallel.sequence import get_sequence_parallel_world_size
from xtuner.parallel.sequence.attention import (
    post_process_for_sequence_parallel_attn,
    pre_process_for_sequence_parallel_attn)


def internlm3_attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
):
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(
        1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
        1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed
        # for the static cache
        cache_kwargs = {
            'sin': sin,
            'cos': cos,
            'cache_position': cache_position
        }
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # different from LlamaAttention.forward
    # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    enable_sequence_parallel = (
        dist.is_initialized() and get_sequence_parallel_world_size() > 1
        and self.training)
    if enable_sequence_parallel:
        # Reshape for `pre_process_for_sequence_parallel_attn`
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        query_states, key_states, value_states = \
            pre_process_for_sequence_parallel_attn(
                query_states, key_states, value_states)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
    # different places end

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != 'eager':
        if self.config._attn_implementation == 'sdpa' and kwargs.get(
                'output_attentions', False):
            warnings.warn(
                '`torch.nn.functional.scaled_dot_product_attention` does not '
                'support `output_attentions=True`. Falling back to eager '
                'attention. This warning can be removed using the argument'
                ' `attn_implementation="eager"` when loading the model.')
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation]

    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    use_varlen_atten = (cumulative_len is not None)
    if use_varlen_atten:
        # When gradient_checkpointing is enabled, the flash_attn_kwargs
        # parameter is not automatically passed to the model. In such
        # cases, parameters like cu_seq_lens_q and max_length_q are
        # computed based on position_ids. However, when sequence
        # parallel is enabled, position_ids is split along the
        # sequence length, leading to incorrect calculations of these
        # parameters.
        # To address this issue, it is necessary to manually provide
        # the flash_attn_kwargs parameters.
        max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
        kwargs['cu_seq_lens_q'] = cumulative_len
        kwargs['cu_seq_lens_k'] = cumulative_len
        kwargs['max_length_q'] = max_seqlen
        kwargs['max_length_k'] = max_seqlen
        kwargs.pop('position_ids', None)

    # Hacky: `sdpa_attention_forward` repeats k/v heads based on
    # `module.num_key_value_groups`, but repeat_kv has already been applied
    # above, so temporarily set it to 1 to avoid repeating twice.
    num_key_value_groups = self.num_key_value_groups
    self.num_key_value_groups = 1
    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )
    self.num_key_value_groups = num_key_value_groups

    # different from LlamaAttention.forward
    if enable_sequence_parallel:
        attn_output = post_process_for_sequence_parallel_attn(attn_output)

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
```
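The diff only contains the patched forward itself; how it gets attached to a model is not shown here. A rough monkey-patching sketch follows, under the assumption that the InternLM3 attention modules are instances of a class named `InternLM3Attention`; both that name and the `dispatch_internlm3_attn_forward` helper are illustrative, not taken from this commit.

```python
# Hypothetical wiring sketch, not part of this commit: swap in the patched
# forward on every matching attention module so the sequence-parallel and
# varlen-aware path above is used during training.
import types


def dispatch_internlm3_attn_forward(model):
    # `InternLM3Attention` is an assumed class name used for illustration;
    # the real dispatcher in xtuner may key on a different name or mechanism.
    for module in model.modules():
        if type(module).__name__ == 'InternLM3Attention':
            module.forward = types.MethodType(internlm3_attn_forward, module)
    return model
```

For the varlen path, `cumulative_len` typically follows the flash-attention convention of cumulative sequence lengths; for packed sequences of lengths 3, 5 and 4 that would be `tensor([0, 3, 8, 12])` with `max_seqlen` equal to 5.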