Patch #1200

Status: Draft. Wants to merge 7 commits into base: main.
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
transformers-version: ["4.47.*"]
transformers-version: ["4.49.0"]
torch-version: ["2.6.0"]

runs-on: ubuntu-22.04
6 changes: 4 additions & 2 deletions optimum/exporters/ipex/cache_utils.py
@@ -1,6 +1,7 @@
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch as ipex
import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig
@@ -38,13 +39,14 @@ def __init__(
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: int,
device,
device=None,
dtype=None,
layer_device_map=None,
**kwargs,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
device = device or default_device
self.device = device
self._supports_flash_decoding = (
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
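Note: with the device argument now optional, the cache resolves its own placement. A minimal sketch of the new fallback logic as a standalone helper (the helper name is illustrative; the PR inlines this in the cache constructor):

import torch


def resolve_cache_device(device=None):
    # Illustrative helper mirroring the inlined fallback above: honor an explicit
    # device, otherwise prefer XPU when IPEX reports one, else fall back to CPU.
    try:
        import intel_extension_for_pytorch as ipex
        has_xpu = ipex._C._has_xpu()
    except ImportError:
        has_xpu = False
    default_device = torch.device("xpu") if has_xpu else torch.device("cpu")
    return device or default_device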
4 changes: 2 additions & 2 deletions optimum/exporters/ipex/model_patcher.py
@@ -46,8 +46,8 @@


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.47.0"
_TRANSFORMERS_MAX_VERSION = "4.47.99"
_TRANSFORMERS_MIN_VERSION = "4.49.0"
_TRANSFORMERS_MAX_VERSION = "4.49.0"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

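With min and max both set to 4.49.0, the IPEX export path is pinned to exactly that transformers release. A hedged sketch of how that exact pin can be checked locally, using only packaging (the repository itself relies on its own is_transformers_version helper):

import transformers
from packaging import version

_TRANSFORMERS_MIN_VERSION = "4.49.0"
_TRANSFORMERS_MAX_VERSION = "4.49.0"


def transformers_version_supported() -> bool:
    # With min == max, only transformers 4.49.0 passes this check.
    current = version.parse(transformers.__version__)
    return (
        version.parse(_TRANSFORMERS_MIN_VERSION)
        <= current
        <= version.parse(_TRANSFORMERS_MAX_VERSION)
    )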
60 changes: 29 additions & 31 deletions optimum/exporters/ipex/modeling_utils.py
@@ -229,7 +229,7 @@ def _llama_model_forward(
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()
max_input_lens = input_lens.max()

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -346,8 +346,8 @@ def _falcon_model_forward(

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
# attention_probs has shape batch_size x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = inputs_embeds

@@ -357,7 +357,7 @@
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()
max_input_lens = input_lens.max()

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -499,7 +499,7 @@ def _gpt2_model_forward(
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()
max_input_lens = input_lens.max()

if past_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -635,7 +635,7 @@ def _qwen2_model_forward(
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
max_input_lens = input_lens.max().item()
max_input_lens = input_lens.max()

if past_key_values_length == 0 and past_key_values is not None:
# first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -707,7 +707,9 @@ def __init__(self, module, device, config) -> None:
_setattr_from_module(self, module)
self.config = config
self.module_device = device
self.num_groups = self.num_heads // self.num_key_value_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_attention_heads = config.num_attention_heads
self.num_groups = self.num_attention_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
).repeat_interleave(self.num_groups)
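The attention wrapper now reads num_key_value_heads and num_attention_heads from the config and derives the query-to-KV head mapping from them. A small worked example of that repeat_interleave mapping for grouped-query attention:

import torch

num_attention_heads, num_key_value_heads = 8, 2
num_groups = num_attention_heads // num_key_value_heads  # 4 query heads share each KV head
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
# kv_head_mapping -> tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)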
@@ -752,11 +754,11 @@ def attention_interface(
if past_key_value is None:
n_rep = query.shape[1] // key.shape[1]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
query.reshape(input_lens.shape[0], input_lens.max(), -1, query.shape[-1]).transpose(1, 2),
key.reshape(input_lens.shape[0], input_lens.max(), -1, key.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
value.reshape(input_lens.shape[0], input_lens.max(), -1, value.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
attn_mask=attention_mask,
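The recurring replacement of input_lens.max().item() with input_lens.max() in the llama, falcon, gpt2, and qwen2 forwards and in attention_interface keeps the maximum sequence length as a 0-dim tensor rather than a Python scalar, avoiding a device-to-host read at this point (presumably to stay torch.compile friendly; the PR does not state the motivation). A tiny illustration of the difference:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]], dtype=torch.int32)
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

max_as_scalar = input_lens.max().item()  # old behaviour: Python int, forces a host sync
max_as_tensor = input_lens.max()         # new behaviour: 0-dim tensor, stays on device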
@@ -883,22 +885,20 @@ def __init__(self, module, device, config) -> None:
self.q_slice = self.q_proj.weight.shape[0]
self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
if self.module_device.type == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
if self.module_device.type == "cpu":
self.mha_linear_add = LinearAdd(module.o_proj)

elif self.module_device.type == "xpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = XPULinearAdd(module.o_proj)
self.mha_linear_add = XPULinearAdd(module.o_proj)

def qkv_gemm(self, hidden_states):
if hasattr(self, "concat_qkv"):
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
else:
query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
query = self.q_proj(hidden_states).view(-1, self.num_attention_heads, self.head_dim)
key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)

@@ -916,23 +916,24 @@ def __init__(self, module, device, config):
def qkv_gemm(self, hidden_states):
qkv_out = self.query_key_value(hidden_states)
if self.new_decoder_architecture:
qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
qkv_out = qkv_out.view(
qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
)
query = qkv_out[:, :, :-2, :].flatten(1, 2)
key = qkv_out[:, :, [-2], :].flatten(1, 2)
value = qkv_out[:, :, [-1], :].flatten(1, 2)
else:
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
return query, key, value


class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, device, config) -> None:
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, device, config)
_setattr_from_module(self, module)
if getattr(config, "quantization_config", None) is None:
if not config.compile and getattr(config, "quantization_config", None) is None:
self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
self.c_attn_linear.bias = self.c_attn.bias
Expand All @@ -952,9 +953,9 @@ def qkv_gemm(self, hidden_states):
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
query = query.view(-1, self.num_heads, self.head_dim)
key = key.view(-1, self.num_heads, self.head_dim)
value = value.view(-1, self.num_heads, self.head_dim)
query = query.view(-1, self.num_attention_heads, self.head_dim)
key = key.view(-1, self.num_attention_heads, self.head_dim)
value = value.view(-1, self.num_attention_heads, self.head_dim)
return query, key, value

def rope(self, query, key, *args, **kwargs):
@@ -976,7 +977,7 @@ def __init__(self, module, device, config) -> None:
_setattr_from_module(self, module)
self.config = config
self.module_device = device
if getattr(config, "quantization_config", None) is None:
if not config.compile and getattr(config, "quantization_config", None) is None:
if self.module_device.type == "cpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -1009,7 +1010,7 @@ def __init__(self, module, device, config) -> None:
_setattr_from_module(self, module)
self.config = config
self.module_device = device
if getattr(config, "quantization_config", None) is None:
if not config.compile and getattr(config, "quantization_config", None) is None:
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
@@ -1049,7 +1050,7 @@ def __init__(self, module, device, config) -> None:
self.config = config
self.module_device = device

if getattr(config, "quantization_config", None) is None:
if not config.compile and getattr(config, "quantization_config", None) is None:
self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
self.c_fc_linear.bias = self.c_fc.bias
Expand All @@ -1058,11 +1059,8 @@ def __init__(self, module, device, config) -> None:
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)

if self.module_device.type == "cpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj_linear not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)
@@ -1234,7 +1232,7 @@ def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.module_device = device
if getattr(config, "quantization_config", None) is None:
if not config.compile and getattr(config, "quantization_config", None) is None:
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense)
elif self.module_device.type == "xpu":
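Throughout this file the fused IPEX modules (LinearAdd, LinearGelu, LinearNewGelu, and friends) are now created only when config.compile is false, so a model headed for torch.compile keeps plain PyTorch modules. A generic sketch of the repeated guard, with torch-only stand-ins for the fused helpers:

import torch
from torch import nn


def should_use_fused_ops(config, module: nn.Module) -> bool:
    # Mirrors the guard added in the attention/MLP wrappers above: skip IPEX fused
    # ops when the model will be compiled, when it is quantized, or when tensor
    # parallelism wraps the linear layer in a LinearAllreduce.
    return (
        not getattr(config, "compile", False)
        and getattr(config, "quantization_config", None) is None
        and module.__class__.__name__ not in ["LinearAllreduce"]
    )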
29 changes: 17 additions & 12 deletions optimum/intel/ipex/modeling_base.py
@@ -135,7 +135,8 @@ def __init__(
self.use_cache = kwargs.get("use_cache", False)
self.model_save_dir = model_save_dir
self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
self.compiled = False
self.compile = self.can_compile()
model.config.compile = self.compile

self.input_names = set(inspect.signature(model.forward).parameters)

@@ -147,9 +148,10 @@ def __init__(
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)

self.maybe_apply_torch_compile()
if self.compile:
self.apply_torch_compile()

if warmup and not self.compiled:
if warmup and not self.compile:
self._init_warmup()

@classmethod
@@ -220,24 +222,27 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
return isinstance(self, GenerationMixin)

def maybe_apply_torch_compile(self):
def can_compile(self):
if (
self.model.device.type != "cpu"
or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
or getattr(self.config, "quantization_config", None)
):
return
return False
if self.use_cache and not self._supports_static_cache:
return
return False

return True

def apply_torch_compile(self):
from torch._inductor import config as inductor_config

# System level optimization
inductor_config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization")
self.model.forward = torch.compile(self.model.forward)
self.compiled = True
self.model.forward = torch.compile(self.model.forward, mode="max-autotune")

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
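The compile decision is now split in two: can_compile() answers yes/no up front (CPU only, supported model type, recent enough IPEX, no quantization, and a cache that supports static shapes), and apply_torch_compile() performs the actual compilation with inductor's C++ wrapper and max-autotune. A standalone sketch of the apply step, assuming only a model with a regular forward:

import os
import torch
from torch._inductor import config as inductor_config


def apply_torch_compile_sketch(model: torch.nn.Module) -> torch.nn.Module:
    # Same system-level inductor settings as the PR, then compile the forward.
    inductor_config.cpp_wrapper = True
    os.environ["TORCHINDUCTOR_FREEZING"] = "1"
    model.forward = torch.compile(model.forward, mode="max-autotune")
    return model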
@@ -317,7 +322,7 @@ def __init__(
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

if warmup and not self.compiled:
if warmup and not self.compile:
self._init_warmup()

@torch.no_grad()
Expand All @@ -337,7 +342,7 @@ def _prepare_generation_config(
kwargs["use_cache"] = self.use_cache
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
generation_method = generation_config.get_generation_mode().value
if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
if self.compile and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
# Use static cache for torch compile
generation_config.cache_implementation = "static"
if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
@@ -448,7 +453,7 @@ def __init__(
if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache

if warmup and not self.compiled:
if warmup and not self.compile:
self._init_warmup()

@torch.no_grad()
Expand All @@ -465,7 +470,7 @@ def _prepare_generation_config(
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
# Use static cache for torch.compile
if self.compiled:
if self.compile:
generation_config.cache_implementation = "static"

return generation_config, model_kwargs
2 changes: 1 addition & 1 deletion setup.py
@@ -67,7 +67,7 @@
"nncf": ["nncf>=2.14.0"],
"openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
"neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
"ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.46,<4.48", "accelerate"],
"ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.48,<4.50", "accelerate"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
2 changes: 1 addition & 1 deletion tests/ipex/test_modeling.py
@@ -374,7 +374,7 @@ def test_ipex_beam_search(self, test_name, model_arch, use_cache):
model_id, use_cache=use_cache, torch_dtype=dtype, device_map=DEVICE
)
# It will be removed when torch 2.6 released
if model_arch == "opt" and not use_cache and model.compiled and is_torch_version("<", "2.6.0"):
if model_arch == "opt" and not use_cache and model.compile and is_torch_version("<", "2.6.0"):
return
if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
self.assertTrue(model.add_patch)