[Feature] Support transformers==4.48 (#985)
* update requirements

* support internlm3, llama, mistral, mixtral, qwen2 and qwen2moe in transformers==4.48
HIT-cwh authored Jan 14, 2025
1 parent 2c06115 commit 4ee8215
Showing 8 changed files with 446 additions and 1,290 deletions.
3 changes: 1 addition & 2 deletions requirements/deepspeed.txt
@@ -1,3 +1,2 @@
# Minimum 0.12.3, see https://github.com/microsoft/DeepSpeed/pull/4587
deepspeed>=0.12.3
deepspeed==0.16.2
mpi4py-mpich
19 changes: 5 additions & 14 deletions requirements/runtime.txt
@@ -1,27 +1,18 @@
# Minimum 0.40.0.post4 to fix some 4-bit precision bugs
bitsandbytes>=0.40.0.post4
# Minimum 2.16.0 to fix some bugs, see https://github.com/huggingface/datasets/pull/6444
datasets>=2.16.0
bitsandbytes==0.45.0
datasets>=3.2.0
einops
# Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44
lagent>=0.1.2
# Minimum 0.10.3 to support distributed evaluation for MMBench
# see https://github.com/open-mmlab/mmengine/pull/1469
mmengine>=0.10.3
mmengine==0.10.6
openpyxl
# Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476
peft>=0.4.0
peft>=0.14.0
scikit-image
scipy
SentencePiece
tiktoken
torch
torchvision
# Minimum 4.36.0 to support `Cache` data structure used by KV Cache
# Registering a causal mask in `LlamaModel` is not friendly for very large
# `max_position_embeddings`. Refer to
# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
# transformers >= 4.43.0 uses _flash_attention_forward rather than self._flash_attention_forward
# to calculate the attn output, which leads to BC breaking
transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4
transformers==4.48.0
transformers_stream_generator
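
Because the dispatch patches in this commit are written against the transformers 4.48 attention API, a quick environment check can catch stale installs early; a minimal sketch, assuming the pin above is what you want to enforce:

```python
# Illustrative environment check (not part of this commit): the dispatch
# patches below target the transformers==4.48.0 APIs pinned above.
import transformers

assert transformers.__version__.startswith('4.48'), (
    f'expected transformers 4.48.x, got {transformers.__version__}')
```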
93 changes: 35 additions & 58 deletions xtuner/model/modules/dispatch/__init__.py
@@ -34,76 +34,75 @@
'possible to return the `attn_weights`.')

LOWEST_TRANSFORMERS_VERSION = dict(
InternLM3ForCausalLM=digit_version('4.48'),
InternLM2ForCausalLM=digit_version('4.36'),
InternLMForCausalLM=digit_version('4.36'),
LlamaForCausalLM=digit_version('4.36'),
LlamaForCausalLM=digit_version('4.48'),
Phi3ForCausalLM=digit_version('4.39'),
MistralForCausalLM=digit_version('4.36'),
MistralForCausalLM=digit_version('4.48'),
# Training mixtral with a lower version may lead to NCCL timeouts
# Refer to https://github.com/microsoft/DeepSpeed/issues/5066
MixtralForCausalLM=digit_version('4.40'),
MixtralForCausalLM=digit_version('4.48'),
CohereForCausalLM=digit_version('4.40'),
Qwen2ForCausalLM=digit_version('4.39'),
Qwen2MoeForCausalLM=digit_version('4.40'),
Qwen2ForCausalLM=digit_version('4.48'),
Qwen2MoeForCausalLM=digit_version('4.48'),
DeepseekV2ForCausalLM=digit_version('4.40'),
)
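
The table above records, per architecture, the lowest transformers release whose attention API each patched forward targets. A minimal sketch of how such a table can gate dispatching at runtime (the helper below is illustrative; only `digit_version` comes from mmengine):

```python
# Illustrative version gate (assumed helper, not part of this commit): only
# dispatch a patched forward when the installed transformers is new enough.
import transformers
from mmengine.utils import digit_version

TRANSFORMERS_VERSION = digit_version(transformers.__version__)

def version_allows_dispatch(model_cls_name: str) -> bool:
    lowest = LOWEST_TRANSFORMERS_VERSION.get(model_cls_name)
    return lowest is not None and TRANSFORMERS_VERSION >= lowest
```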

ATTN_DISPATCH_MAPPING = dict(
InternLM3Attention=LazyObject('xtuner.model.modules.dispatch.internlm3',
'internlm3_attn_forward'),
InternLM2FlashAttention2=LazyObject(
'xtuner.model.modules.dispatch.internlm2', 'internlm2_attn_forward'),
InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm',
'internlm_attn_forward'),
LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
'llama_attn_forward'),
LlamaAttention=LazyObject('xtuner.model.modules.dispatch.llama',
'llama_attn_forward'),
Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3',
'phi3_attn_forward'),
MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_attn_forward'),
MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_attn_forward'),
MistralAttention=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_attn_forward'),
MixtralAttention=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_attn_forward'),
CohereFlashAttention2=LazyObject('xtuner.model.modules.dispatch.cohere',
'cohere_attn_forward'),
Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_attn_forward'),
Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_attn_forward'),
Qwen2Attention=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_attn_forward'),
Qwen2MoeAttention=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_attn_forward'),
DeepseekV2FlashAttention2=LazyObject(
'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_attn_forward'),
)
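
Each value in these mappings is a LazyObject, so an architecture-specific dispatch module is imported only when a matching attention class is actually found on the model. A minimal sketch of how one entry is resolved and bound, mirroring the dispatch loops further down (the `attn_module` variable is assumed for illustration):

```python
# Illustrative resolution of one mapping entry (not part of this commit):
# LazyObject defers the import until .build() is called; the resolved function
# is then bound to the attention module as its forward method.
import types

entry = ATTN_DISPATCH_MAPPING['Qwen2Attention']
qwen2_attn_forward = entry.build()  # lazily imports xtuner.model.modules.dispatch.qwen2
# attn_module: a Qwen2Attention instance found on the model (assumed)
attn_module.forward = types.MethodType(qwen2_attn_forward, attn_module)
```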

ATTN_LEGACY_DISPATCH_MAPPING = dict(
LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
'llama_attn_forward_legacy'), )

VARLEN_ATTN_DISPATCH_MAPPING = dict(
InternLM3Attention=LazyObject('xtuner.model.modules.dispatch.internlm3',
'internlm3_attn_forward'),
InternLM2FlashAttention2=LazyObject(
'xtuner.model.modules.dispatch.internlm2',
'internlm2_varlen_attn_forward'),
InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm',
'internlm_varlen_attn_forward'),
LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
'llama_varlen_attn_forward'),
LlamaAttention=LazyObject('xtuner.model.modules.dispatch.llama',
'llama_attn_forward'),
Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3',
'phi3_varlen_attn_forward'),
MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_varlen_attn_forward'),
MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_varlen_attn_forward'),
MistralAttention=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_attn_forward'),
MixtralAttention=LazyObject('xtuner.model.modules.dispatch.mistral',
'mistral_attn_forward'),
CohereFlashAttention2=None,
Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_varlen_attn_forward'),
Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_varlen_attn_forward'),
Qwen2Attention=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_attn_forward'),
Qwen2MoeAttention=LazyObject('xtuner.model.modules.dispatch.qwen2',
'qwen2_attn_forward'),
DeepseekV2FlashAttention2=LazyObject(
'xtuner.model.modules.dispatch.deepseek_v2',
'deepseek_varlen_attn_forward'),
)

VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = dict(
LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
'llama_varlen_attn_forward_legacy'), )

RMS_DISPATCH_MAPPING = dict(
InternLM3RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
'rms_norm_forward'),
InternLM2RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
'rms_norm_forward'),
InternLMRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
@@ -126,12 +125,7 @@

ROTE_DISPATCH_MAPPING = dict(
InternLMRotaryEmbedding=LazyObject(
'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
'MistralRotaryEmbedding'),
MixtralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
'MistralRotaryEmbedding'),
)
'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'), )


def log_once(func):
@@ -158,15 +152,7 @@ def dispatch_attn_forward(model):
attn_forward = None
for module in model.modules():
name = type(module).__name__
if (IS_LOW_VERSION_TRANSFORMERS
and name in ATTN_LEGACY_DISPATCH_MAPPING):
if attn_forward is None:
attn_forward = ATTN_LEGACY_DISPATCH_MAPPING[name]
attn_forward = attn_forward.build()
print_log(f'Dispatch {name} legacy forward. {NO_ATTN_WEIGHTS_MSG}',
'current')
module.forward = types.MethodType(attn_forward, module)
elif name in ATTN_DISPATCH_MAPPING:
if name in ATTN_DISPATCH_MAPPING:
if attn_forward is None:
attn_forward = ATTN_DISPATCH_MAPPING[name]
attn_forward = attn_forward.build()
@@ -186,16 +172,7 @@ def dispatch_varlen_attn_forward(model):
varlen_attn_forward = None
for module in model.modules():
name = type(module).__name__
if (IS_LOW_VERSION_TRANSFORMERS
and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING):
if varlen_attn_forward is None:
varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name]
varlen_attn_forward = varlen_attn_forward.build()
print_log(
f'Dispatch legacy {name} varlen forward. '
f'{NO_ATTN_WEIGHTS_MSG}', 'current')
module.forward = types.MethodType(varlen_attn_forward, module)
elif name in VARLEN_ATTN_DISPATCH_MAPPING:
if name in VARLEN_ATTN_DISPATCH_MAPPING:
if varlen_attn_forward is None:
varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name]
varlen_attn_forward = varlen_attn_forward.build()
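
In practice these dispatchers are applied once to a freshly built HF model before training. A rough usage sketch, assuming xtuner's `dispatch_modules` entry point and an InternLM3 checkpoint name for illustration:

```python
# Illustrative usage (import path and model name are assumptions): patch a HF
# model's attention/RMSNorm forwards in place before training.
from transformers import AutoModelForCausalLM
from xtuner.model.modules import dispatch_modules  # assumed entry point

model = AutoModelForCausalLM.from_pretrained('internlm/internlm3-8b-instruct')
dispatch_modules(model, use_varlen_attn=False)
```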
132 changes: 132 additions & 0 deletions xtuner/model/modules/dispatch/internlm3.py
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
from mmengine import MessageHub
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb,
eager_attention_forward,
repeat_kv)
from transformers.processing_utils import Unpack

from xtuner.parallel.sequence import get_sequence_parallel_world_size
from xtuner.parallel.sequence.attention import (
post_process_for_sequence_parallel_attn,
pre_process_for_sequence_parallel_attn)


def internlm3_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
):
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(
1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed
# for the static cache
cache_kwargs = {
'sin': sin,
'cos': cos,
'cache_position': cache_position
}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs)

# different from LlamaAttention.forward
# repeat k/v heads if n_kv_heads < n_heads for sequence parallel
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

enable_sequence_parallel = (
dist.is_initialized() and get_sequence_parallel_world_size() > 1
and self.training)
if enable_sequence_parallel:
        # Reshape for `pre_process_for_sequence_parallel_attn`
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states, key_states, value_states = \
pre_process_for_sequence_parallel_attn(
query_states, key_states, value_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# different places end

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != 'eager':
if self.config._attn_implementation == 'sdpa' and kwargs.get(
'output_attentions', False):
warnings.warn(
'`torch.nn.functional.scaled_dot_product_attention` does not '
'support `output_attentions=True`. Falling back to eager '
'attention. This warning can be removed using the argument'
' `attn_implementation="eager"` when loading the model.')
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation]

message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
use_varlen_atten = (cumulative_len is not None)
if use_varlen_atten:
# When gradient_checkpointing is enabled, the flash_attn_kwargs
# parameter is not automatically passed to the model. In such
# cases, parameters like cu_seq_lens_q and max_length_q are
# computed based on position_ids. However, when sequence
# parallel is enabled, position_ids is split along the
# sequence length, leading to incorrect calculations of these
# parameters.
# To address this issue, it is necessary to manually provide
# the flash_attn_kwargs parameters.
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
kwargs['cu_seq_lens_q'] = cumulative_len
kwargs['cu_seq_lens_k'] = cumulative_len
kwargs['max_length_q'] = max_seqlen
kwargs['max_length_k'] = max_seqlen
kwargs.pop('position_ids', None)

    # Hacky: `sdpa_attention_forward` repeats k/v based on
    # module.num_key_value_groups, but repeat_kv has already been applied
    # above, so temporarily set the group count to 1 to avoid repeating twice.
num_key_value_groups = self.num_key_value_groups
self.num_key_value_groups = 1
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
self.num_key_value_groups = num_key_value_groups

# different from LlamaAttention.forward
if enable_sequence_parallel:
attn_output = post_process_for_sequence_parallel_attn(attn_output)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
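
The varlen branch above does not receive cu_seqlens through the normal kwargs; it reads them from a per-rank MessageHub. A rough sketch of the producer side, under the assumption that a data collator or training hook publishes these keys before the forward runs:

```python
# Illustrative producer side (assumed usage, not part of this commit): publish
# the per-rank varlen-attention arguments that internlm3_attn_forward reads
# back from the 'varlen_attn_args' MessageHub.
import torch
import torch.distributed as dist
from mmengine import MessageHub

message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
# cumulative_len marks packed-sample boundaries, e.g. three samples of
# lengths 5, 3 and 8 packed into a single sequence of length 16.
cumulative_len = torch.tensor([0, 5, 8, 16], dtype=torch.int32, device='cuda')
message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
message_hub.update_info(f'max_seqlen_rank_{rank}', 8)
```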