
Commit dd4fe68

jiqing-feng authored on Feb 11, 2025
enable qwen2 model (#1107)
* use real varlen attn
* optimize gpt2 by using linear instead of conv1D
* fix usage without pkv
* use sdpa for no cache forward
* fix format
* fix sdpa
* revert shape for sdpa
* fix sdpa precision, still have error
* fix sdpa shape
* upgrade minimum torch version to 2.5
* rm pdb
* fix non patch path
* use varlen if flash attn not available
* revert ipex version change
* fix flash attn check
* prefill attn
* fix cache
* qwen2 model forward
* refactor attention
* use flash attn for decode
* fix dtype
* enable qwen2 model
* enable qwen2 test
* set default block size
* decoding use single query
* fix position_id init for qwen2
* add patched qwen2 test
* fix format
* fix pipeline test
* set block size as an env parameter
* set different default value for block size based on device
* change new prompt

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 3befef7 commit dd4fe68

File tree

6 files changed: +163 -22 lines


optimum/exporters/ipex/model_patcher.py

+21
@@ -20,6 +20,11 @@
     LlamaModel,
     LlamaRMSNorm,
 )
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2DecoderLayer,
+    Qwen2Model,
+    Qwen2RMSNorm,
+)
 from transformers.models.vit.modeling_vit import ViTIntermediate
 
 from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
@@ -36,7 +41,9 @@
     _IPEXGPT2Attention,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
+    _IPEXQwen2DecoderLayer,
     _llama_model_forward,
+    _qwen2_model_forward,
 )
 
 
@@ -116,6 +123,18 @@ def _patch_gpt2_model(model):
     return model
 
 
+def _patch_qwen2_model(model):
+    """
+    Patch qwen2 model:
+        1. Use IPEX rope and paged cache
+        2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
+    """
+    convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
+    convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    return model
+
+
 def _patch_bert_model(model):
     """
     Patch bert model:
@@ -149,6 +168,8 @@ def _patch_model(model):
         model = _patch_falcon_model(model)
     elif model.config.model_type == "gpt2":
         model = _patch_gpt2_model(model)
+    elif model.config.model_type == "qwen2":
+        model = _patch_qwen2_model(model)
     elif model.config.model_type == "bert":
         model = _patch_bert_model(model)
     elif model.config.model_type == "vit":
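Note on usage: once a qwen2 checkpoint is loaded through the IPEX model classes, the dispatch in _patch_model above should route it to _patch_qwen2_model. A minimal sketch of that path; the checkpoint name and dtype are illustrative assumptions, not taken from this commit:

# Minimal sketch: loading a qwen2 checkpoint so the IPEX patching path is exercised.
# The model id and dtype below are illustrative assumptions, not part of this commit.
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "Qwen/Qwen2-0.5B-Instruct"  # assumed qwen2-architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("This is a sample", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=6, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))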

optimum/exporters/ipex/modeling_utils.py

+131 -14
@@ -603,6 +603,125 @@ def _gpt2_block_forward(
     return outputs  # hidden_states, present, (attentions, cross_attentions)
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499
+def _qwen2_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    batch_size, seq_length = inputs_embeds.shape[:2]
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values_length == 0 and past_key_values is not None:
+        # first token, remove the padding from hidden_states, varlen do not accept attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = position_embeddings[0]
+        sin = position_embeddings[1]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
+        attention_mask = causal_mask
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+
+    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            input_lens=input_lens,
+            **kwargs,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    output = BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+    return output if return_dict else output.to_tuple()
+
+
 class _IPEXAttention(nn.Module):
     def __init__(self, module, config) -> None:
         super().__init__()
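The prefill branch of _qwen2_model_forward packs the batch by dropping padded positions, since the varlen attention path does not accept an attention mask, and scatters the processed tokens back into the padded layout at the end. A standalone sketch of that pack/unpack pattern (shapes and values are illustrative; no IPEX kernels involved):

# Standalone sketch of the pad-removal / scatter-back pattern used above.
import torch

batch_size, seq_length, hidden_size = 2, 4, 8
hidden_states = torch.randn(batch_size, seq_length, hidden_size)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])  # second sequence is left-padded

# Pack: keep only real tokens; record per-sequence lengths for the varlen kernel.
index = attention_mask.view(-1) != 0
packed = hidden_states.view(-1, hidden_size)[index]            # (num_real_tokens, hidden_size)
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # lengths per sequence

processed = packed  # the decoder layers would transform `packed` here

# Unpack: scatter the processed tokens back into the padded (batch, seq, hidden) layout.
restored = hidden_states.view(-1, hidden_size).clone()
restored[index] = processed
restored = restored.view(batch_size, seq_length, hidden_size)
assert restored.shape == (batch_size, seq_length, hidden_size)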
@@ -618,8 +737,10 @@ def __init__(self, module, config) -> None:
     def qkv_gemm(self, hidden_states):
         raise NotImplementedError("Need to implement in specific model class")
 
-    def rope(self, *args, **kwargs):
-        raise NotImplementedError("Need to implement in specific model class")
+    def rope(self, query, key, **kwargs):
+        position_embeddings = kwargs.pop("position_embeddings", None)
+        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
+        return query, key
 
     def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
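The hunk above moves the default rope implementation into the _IPEXAttention base class, which makes the identical llama and falcon overrides removed further below redundant. The patched code delegates to the fused IPEX rotary_embedding op; as a reference only, a plain-PyTorch sketch of the usual half-rotation form of rotary embeddings (whether the fused op uses this or the interleaved layout is not shown in this diff, so treat the layout as an assumption):

# Reference-only sketch of rotary position embeddings in plain PyTorch.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(query, key, cos, sin):
    # query/key: (num_tokens, num_heads, head_dim); cos/sin: (num_tokens, 1, head_dim)
    q = query * cos + rotate_half(query) * sin
    k = key * cos + rotate_half(key) * sin
    return q, k

num_tokens, num_heads, head_dim = 5, 2, 8
q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
cos = torch.randn(num_tokens, 1, head_dim)
sin = torch.randn(num_tokens, 1, head_dim)
q_rot, k_rot = apply_rope_reference(q, k, cos, sin)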
@@ -748,13 +869,13 @@ class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
         concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
-        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
+        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
         use_bias = bias_list != []
         self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
         self.concat_qkv.weight = nn.Parameter(concat_weight)
         if use_bias:
             concat_bias = torch.concat(bias_list, 0).contiguous()
-            self.concat_linear.bias = nn.Parameter(concat_bias)
+            self.concat_qkv.bias = nn.Parameter(concat_bias)
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
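Two fixes land in the hunk above, both needed now that qwen2 projections carry biases: `if bias` is replaced by `if bias is not None` (truth-testing a multi-element tensor raises at runtime), and the fused bias is assigned to `concat_qkv` instead of the non-existent `concat_linear`. A self-contained sketch of the same weight-concatenation idea (layer sizes are arbitrary, chosen only for illustration):

# Illustrative sketch of fusing separate q/k/v projections into one Linear.
import torch
from torch import nn

hidden_size, q_out, kv_out = 64, 64, 16
q_proj = nn.Linear(hidden_size, q_out, bias=True)
k_proj = nn.Linear(hidden_size, kv_out, bias=True)
v_proj = nn.Linear(hidden_size, kv_out, bias=True)

concat_weight = torch.concat([q_proj.weight, k_proj.weight, v_proj.weight]).contiguous()
bias_list = [b for b in [q_proj.bias, k_proj.bias, v_proj.bias] if b is not None]
concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=bool(bias_list))
concat_qkv.weight = nn.Parameter(concat_weight)
if bias_list:
    concat_qkv.bias = nn.Parameter(torch.concat(bias_list, 0).contiguous())

# One GEMM, then slice back into query/key/value.
x = torch.randn(3, hidden_size)
qkv = concat_qkv(x)
query, key, value = qkv.split([q_out, kv_out, kv_out], dim=-1)
assert torch.allclose(query, q_proj(x), atol=1e-5)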
@@ -774,11 +895,6 @@ def qkv_gemm(self, hidden_states):
 
         return query, key, value
 
-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-
 
 class _IPEXFalconAttention(_IPEXAttention):
     def __init__(self, module, config):
@@ -801,11 +917,6 @@ def qkv_gemm(self, hidden_states):
         value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value
 
-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-
 
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
@@ -1006,6 +1117,12 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return outputs
 
 
+# Currently can just apply llama decoder layer.
+class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):

optimum/intel/ipex/modeling_base.py

+2 -2
@@ -58,11 +58,11 @@
 logger = logging.getLogger(__name__)
 
 
-_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
+_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
 # TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2")
 
 
 def _is_patched_with_ipex(model, task, use_cache: bool = True):
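These tuples gate which model types get the patched IPEX path and which are still excluded from torch.compile. The body of _is_patched_with_ipex is not part of this diff; purely as a hypothetical illustration of how such a tuple can gate patching:

# Hypothetical illustration only; not the actual _is_patched_with_ipex logic.
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")

def supports_ipex_patching(model_type: str) -> bool:
    return model_type in _IPEX_SUPPORT_MODEL_TYPES

assert supports_ipex_patching("qwen2")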

tests/ipex/test_modeling.py

+2 -1
@@ -240,8 +240,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "distilgpt2",
         "mpt",
         "opt",
+        "qwen2",
     )
-    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2")
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2", "qwen2")
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.0

tests/ipex/test_pipelines.py

+5 -5
@@ -66,6 +66,7 @@ class PipelinesIntegrationTest(unittest.TestCase):
         "mistral",
         "mpt",
         "opt",
+        "qwen2",
     )
     QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = (
         "bert",
@@ -143,11 +144,10 @@ def test_text_generation_pipeline_inference(self, model_arch):
         ipex_generator = ipex_pipeline(
             "text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE
         )
-        inputs = "Describe a real-world application of AI."
-        with torch.inference_mode():
-            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
-        with torch.inference_mode():
-            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
+        inputs = "This is a sample"
+        max_new_tokens = 6
+        transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+        ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
         self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
         self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])
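The updated test runs the same short greedy prompt through the plain transformers pipeline and the IPEX-accelerated pipeline and compares the outputs. A minimal standalone sketch of the IPEX path, following the call signature used in the test; the model id is an illustrative assumption:

# Minimal sketch of the IPEX-accelerated text-generation pipeline exercised by this test.
import torch
from optimum.intel.pipelines import pipeline as ipex_pipeline

generator = ipex_pipeline(
    "text-generation",
    "Qwen/Qwen2-0.5B-Instruct",  # assumed qwen2 checkpoint
    accelerator="ipex",
    torch_dtype=torch.bfloat16,
)
print(generator("This is a sample", do_sample=False, max_new_tokens=6)[0]["generated_text"])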

tests/ipex/utils_tests.py

+2
@@ -50,6 +50,7 @@
     "mt5": "stas/mt5-tiny-random",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
+    "qwen2": "Jiqing/tiny-random-Qwen2",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
     "roformer": "hf-internal-testing/tiny-random-roformer",
@@ -64,4 +65,5 @@
     "patched_falcon": "Intel/tiny-random-falcon_ipex_model",
     "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model",
     "patched_llama2": "Intel/tiny-random-llama2_ipex_model",
+    "patched_qwen2": "Jiqing/tiny-random-Qwen2_ipex_model",
 }
