diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 3d90ad12fb..6f558018e0 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -20,6 +20,11 @@
     LlamaModel,
     LlamaRMSNorm,
 )
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2DecoderLayer,
+    Qwen2Model,
+    Qwen2RMSNorm,
+)
 from transformers.models.vit.modeling_vit import ViTIntermediate
 
 from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
@@ -36,7 +41,9 @@
     _IPEXGPT2Attention,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
+    _IPEXQwen2DecoderLayer,
     _llama_model_forward,
+    _qwen2_model_forward,
 )
 
 
@@ -116,6 +123,18 @@ def _patch_gpt2_model(model):
     return model
 
 
+def _patch_qwen2_model(model):
+    """
+    Patch qwen2 model:
+        1. Use IPEX rope and paged cache
+        2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
+    """
+    convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
+    convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    return model
+
+
 def _patch_bert_model(model):
     """
     Patch bert model:
@@ -149,6 +168,8 @@ def _patch_model(model):
         model = _patch_falcon_model(model)
     elif model.config.model_type == "gpt2":
         model = _patch_gpt2_model(model)
+    elif model.config.model_type == "qwen2":
+        model = _patch_qwen2_model(model)
     elif model.config.model_type == "bert":
         model = _patch_bert_model(model)
     elif model.config.model_type == "vit":
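For context, a minimal usage sketch of how this dispatch gets exercised: loading a qwen2 checkpoint through the public `IPEXModelForCausalLM` API routes through `_patch_model` and therefore `_patch_qwen2_model`. The checkpoint id and generation settings below are illustrative, not taken from this PR.

```python
# Illustrative only: any qwen2 causal-LM checkpoint is assumed to take the same path.
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "Qwen/Qwen2-0.5B-Instruct"  # assumption: not part of this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
# from_pretrained runs the export path, which applies _patch_qwen2_model for model_type == "qwen2"
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("This is a sample", return_tensors="pt")
output = model.generate(**inputs, do_sample=False, max_new_tokens=6)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```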
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 41dd5693df..2b440aa91a 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -603,6 +603,125 @@ def _gpt2_block_forward(
     return outputs  # hidden_states, present, (attentions, cross_attentions)
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499
+def _qwen2_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    batch_size, seq_length = inputs_embeds.shape[:2]
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(batch_size, 0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values_length == 0 and past_key_values is not None:
+        # first token (prefill): drop padding from hidden_states, varlen attention does not take an attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = position_embeddings[0]
+        sin = position_embeddings[1]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
+        attention_mask = causal_mask
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+
+    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            input_lens=input_lens,
+            **kwargs,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    output = BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+    return output if return_dict else output.to_tuple()
+
+
 class _IPEXAttention(nn.Module):
     def __init__(self, module, config) -> None:
         super().__init__()
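The prefill branch of `_qwen2_model_forward` above packs the batch by dropping padded positions (the varlen attention path takes per-sequence lengths instead of an attention mask) and scatters the normalized hidden states back into the padded layout after the decoder stack. A toy, self-contained sketch of that bookkeeping, independent of the IPEX kernels:

```python
import torch

batch_size, seq_len, hidden = 2, 4, 8
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])       # 0 marks padding
hidden_states = torch.randn(batch_size, seq_len, hidden)

# pack: keep only real tokens, mirroring `index = attention_mask.view(-1) != 0`
index = attention_mask.view(-1) != 0
packed = hidden_states.view(-1, hidden)[index]       # (num_real_tokens, hidden)
input_lens = attention_mask.sum(-1).to(torch.int32)  # per-sequence lengths for varlen attention

# stand-in for running the packed tokens through the decoder layers
processed = packed * 2.0

# unpack: scatter the results back into the padded layout, as done after self.norm(...)
restored = hidden_states.view(-1, hidden).clone()
restored[index] = processed
restored = restored.view(batch_size, seq_len, hidden)
print(restored.shape, input_lens.tolist())
```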
@@ -618,8 +737,10 @@ def __init__(self, module, config) -> None:
     def qkv_gemm(self, hidden_states):
         raise NotImplementedError("Need to implement in specific model class")
 
-    def rope(self, *args, **kwargs):
-        raise NotImplementedError("Need to implement in specific model class")
+    def rope(self, query, key, **kwargs):
+        position_embeddings = kwargs.pop("position_embeddings", None)
+        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
+        return query, key
 
     def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
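With the base class now calling IPEX's fused in-place `rotary_embedding`, the per-model `rope` overrides become redundant and are removed below. As a reference for what that rotation computes, here is a plain-PyTorch, out-of-place sketch assuming the rotate-half RoPE convention used by llama/qwen2 in transformers:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the head dimension in two and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(query, key, cos, sin):
    # query/key: (..., num_heads, head_dim); cos/sin broadcastable to (..., 1, head_dim)
    return query * cos + rotate_half(query) * sin, key * cos + rotate_half(key) * sin
```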
@@ -748,13 +869,13 @@ class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
         concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
-        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
+        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
         use_bias = bias_list != []
         self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
         self.concat_qkv.weight = nn.Parameter(concat_weight)
         if use_bias:
             concat_bias = torch.concat(bias_list, 0).contiguous()
-            self.concat_linear.bias = nn.Parameter(concat_bias)
+            self.concat_qkv.bias = nn.Parameter(concat_bias)
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
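Two fixes land in this hunk: tensor biases are filtered with `is not None` (truth-testing a multi-element tensor raises a RuntimeError), and the concatenated bias is assigned to `concat_qkv` instead of the non-existent `concat_linear`. A small sketch, with toy shapes rather than qwen2's, checking that a fused projection built this way matches the separate q/k/v projections:

```python
import torch
from torch import nn

hidden, q_dim, kv_dim = 16, 16, 8
q_proj, k_proj, v_proj = (nn.Linear(hidden, d) for d in (q_dim, kv_dim, kv_dim))

# fuse the three projections into one GEMM, as _IPEXLlamaAttention.__init__ does
concat_qkv = nn.Linear(hidden, q_dim + 2 * kv_dim)
concat_qkv.weight = nn.Parameter(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight]).contiguous())
concat_qkv.bias = nn.Parameter(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias]).contiguous())

x = torch.randn(3, hidden)
fused = concat_qkv(x)
q, k, v = fused[:, :q_dim], fused[:, q_dim : q_dim + kv_dim], fused[:, q_dim + kv_dim :]
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
```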
@@ -774,11 +895,6 @@ def qkv_gemm(self, hidden_states):
 
         return query, key, value
 
-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-
 
 class _IPEXFalconAttention(_IPEXAttention):
     def __init__(self, module, config):
@@ -801,11 +917,6 @@ def qkv_gemm(self, hidden_states):
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value
 
-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-
 
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
@@ -1006,6 +1117,12 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return outputs
 
 
+# The qwen2 decoder layer currently reuses the llama decoder layer implementation as-is.
+class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 3263e31db3..76746c1714 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -58,11 +58,11 @@
 logger = logging.getLogger(__name__)
 
 
-_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
+_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
 # TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2")
 
 
 def _is_patched_with_ipex(model, task, use_cache: bool = True):
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index 419e1bb42a..a91d08ee0e 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -233,8 +233,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "distilgpt2",
         "mpt",
         "opt",
+        "qwen2",
     )
-    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2")
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2", "qwen2")
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.0
 
diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py
index f376c6050a..6b94a4dc02 100644
--- a/tests/ipex/test_pipelines.py
+++ b/tests/ipex/test_pipelines.py
@@ -66,6 +66,7 @@ class PipelinesIntegrationTest(unittest.TestCase):
         "mistral",
         "mpt",
         "opt",
+        "qwen2",
     )
     QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = (
         "bert",
@@ -143,11 +144,10 @@ def test_text_generation_pipeline_inference(self, model_arch):
         ipex_generator = ipex_pipeline(
             "text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE
         )
-        inputs = "Describe a real-world application of AI."
-        with torch.inference_mode():
-            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
-        with torch.inference_mode():
-            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
+        inputs = "This is a sample"
+        max_new_tokens = 6
+        transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+        ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
         self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
         self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])
 
diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py
index 8cd93516da..72f407cc13 100644
--- a/tests/ipex/utils_tests.py
+++ b/tests/ipex/utils_tests.py
@@ -50,6 +50,7 @@
     "mt5": "stas/mt5-tiny-random",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
+    "qwen2": "Jiqing/tiny-random-Qwen2",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
     "roformer": "hf-internal-testing/tiny-random-roformer",
@@ -64,4 +65,5 @@
     "patched_falcon": "Intel/tiny-random-falcon_ipex_model",
     "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model",
     "patched_llama2": "Intel/tiny-random-llama2_ipex_model",
+    "patched_qwen2": "Jiqing/tiny-random-Qwen2_ipex_model",
 }