Enable quant model support #1074

Merged: 65 commits, Feb 19, 2025

Changes from 1 commit

Commits (65)
3888824
enable IPEXModelForSeq2SeqLM
jiqing-feng Dec 9, 2024
f9fa807
set static cache
jiqing-feng Dec 9, 2024
202df43
add tests for IPEXModelForSeq2SeqLM
jiqing-feng Dec 9, 2024
4488073
add docs
jiqing-feng Dec 9, 2024
16fecf8
fix readme
jiqing-feng Dec 9, 2024
de501f4
Merge branch 'main' into text2text
jiqing-feng Dec 10, 2024
4225bf0
refactor compile
jiqing-feng Dec 11, 2024
2ac7ecf
fix check
jiqing-feng Dec 11, 2024
24b988c
fix ruff check
jiqing-feng Dec 11, 2024
5c4f9a1
Merge branch 'huggingface:main' into text2text
jiqing-feng Dec 16, 2024
46b93a4
enable quantized model
jiqing-feng Dec 16, 2024
82d39ce
add bnb test
jiqing-feng Dec 16, 2024
7dc08da
add bnb tests in yaml
jiqing-feng Dec 16, 2024
30027ff
fix tests
jiqing-feng Dec 16, 2024
314db04
disable bnb tests
jiqing-feng Dec 16, 2024
87656ca
fix gpt2
jiqing-feng Dec 16, 2024
9a7e931
Merge branch 'main' into quant
jiqing-feng Dec 18, 2024
b0cec9c
set actual device
jiqing-feng Dec 18, 2024
94cf35d
assign device when convert class
jiqing-feng Dec 18, 2024
9af46d1
fix class init
jiqing-feng Dec 18, 2024
18b2a6a
fix ipex attn init
jiqing-feng Dec 18, 2024
9f6db33
rm set device on config
jiqing-feng Dec 18, 2024
6d8a969
fix format
jiqing-feng Dec 18, 2024
dd811f9
fix mlp class init
jiqing-feng Dec 18, 2024
d91eefb
Merge branch 'huggingface:main' into quant
jiqing-feng Jan 14, 2025
f094cad
Merge branch 'main' into quant
jiqing-feng Jan 21, 2025
dab4a78
add use_cache param when init generation config
jiqing-feng Jan 21, 2025
6bf3b8b
fix gpt2 quant model
jiqing-feng Jan 21, 2025
356d51d
fix falcon linear fusion
jiqing-feng Jan 22, 2025
d1eee87
fix falcon
jiqing-feng Jan 22, 2025
3aece6a
Merge branch 'huggingface:main' into quant
jiqing-feng Feb 7, 2025
57e3c27
enable awq model test
jiqing-feng Feb 7, 2025
8f6ba5c
fix install
jiqing-feng Feb 7, 2025
8870714
fix install
jiqing-feng Feb 7, 2025
5828fc0
fix install
jiqing-feng Feb 7, 2025
c616d57
fix install
jiqing-feng Feb 7, 2025
e1715b8
fix install
jiqing-feng Feb 7, 2025
e88faf2
fix install
jiqing-feng Feb 7, 2025
882f2b2
fix install
jiqing-feng Feb 7, 2025
d8208c7
fix install
jiqing-feng Feb 7, 2025
80b9ccb
fix install
jiqing-feng Feb 7, 2025
f05fb2f
fix install
jiqing-feng Feb 7, 2025
bd8e870
fix install
jiqing-feng Feb 7, 2025
e471be3
fix install
jiqing-feng Feb 7, 2025
96f4622
fix install
jiqing-feng Feb 7, 2025
32bf0a1
fix install
jiqing-feng Feb 7, 2025
ad3467b
fix install
jiqing-feng Feb 7, 2025
4a21d26
fix install
jiqing-feng Feb 7, 2025
fb8002c
fix install
jiqing-feng Feb 7, 2025
2ad2371
fix install
jiqing-feng Feb 7, 2025
3c2ddef
enable bnb test
jiqing-feng Feb 7, 2025
757ea8c
remove useless device
jiqing-feng Feb 7, 2025
0a6ab0f
update python to 3.10 on test_ipex
jiqing-feng Feb 11, 2025
8c4884b
Apply suggestions from code review
IlyasMoutawwakil Feb 11, 2025
7fa23a5
install autoawq
jiqing-feng Feb 11, 2025
5386bbe
install wheel
jiqing-feng Feb 11, 2025
c5f5d16
fix install autoawq
jiqing-feng Feb 11, 2025
41513f0
rm autoawq
jiqing-feng Feb 11, 2025
8e1caa2
rebase
jiqing-feng Feb 12, 2025
f73c08d
fix concat qkv
jiqing-feng Feb 12, 2025
f51777b
fix format
jiqing-feng Feb 12, 2025
f64b251
fix qwen patch
jiqing-feng Feb 12, 2025
778bf15
fix bias
jiqing-feng Feb 12, 2025
6ba1895
rm autoawq test
jiqing-feng Feb 12, 2025
88dba29
fix style
jiqing-feng Feb 12, 2025
enable quantized model
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Dec 16, 2024
commit 46b93a4c695b695021578b5d8f13eb9a3edd562b
134 changes: 89 additions & 45 deletions optimum/exporters/ipex/modeling_utils.py
@@ -30,6 +30,7 @@
 logger = logging.getLogger(__name__)

 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "musa"]


 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
@@ -133,6 +134,32 @@ def forward(self, x, y, z):
         return x


+# Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183
+def _remove_hooks_for_ipex(module, recurse):
+    if hasattr(module, "_hf_hook"):
+        module._hf_hook.detach_hook(module)
+        delattr(module, "_hf_hook")
+
+    if hasattr(module, "_old_forward"):
+        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
+        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
+        if "GraphModuleImpl" in str(type(module)):
+            module.__class__.forward = module.__class__.forward.__get__(module)
+        else:
+            module.forward = module.__class__.forward.__get__(module)
+        delattr(module, "_old_forward")
+
+    # Remove accelerate added warning hooks from dispatch_model
+    for attr in _accelerate_added_attributes:
+        module.__dict__.pop(attr, None)
+
+    if recurse:
+        for child in module.children():
+            _remove_hooks_for_ipex(child, recurse)
+
+    return module
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
 def _ipex_rms_layer_norm_forward(self, hidden_states):
     return rms_norm(hidden_states, self.weight, self.variance_epsilon)
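
For illustration (not part of the diff): _remove_hooks_for_ipex mirrors accelerate's remove_hook_from_module, plus the extra cleanup of the to/cuda/xpu/... wrappers that dispatch_model adds on the top-level model. A minimal sketch of the attributes it strips, using accelerate's public hook API on a toy module:

import torch.nn as nn
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module

layer = nn.Linear(4, 4)
add_hook_to_module(layer, ModelHook())        # attaches layer._hf_hook and saves layer._old_forward
assert hasattr(layer, "_hf_hook") and hasattr(layer, "_old_forward")

remove_hook_from_module(layer, recurse=True)  # accelerate's generic cleanup, restores the original forward
assert not hasattr(layer, "_hf_hook") and not hasattr(layer, "_old_forward")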
@@ -656,30 +683,36 @@ def forward(
 class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
-        concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
-        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
-        use_bias = bias_list != []
-        self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
-        self.concat_qkv.weight = nn.Parameter(concat_weight)
-        if use_bias:
-            concat_bias = torch.concat(bias_list, 0).contiguous()
-            self.concat_linear.bias = nn.Parameter(concat_bias)
-        self.q_slice = self.q_proj.weight.shape[0]
-        self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
-        self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = LinearAdd(module.o_proj)
-
-        elif self.module_device.type == "xpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+        if getattr(config, "quantization_config", None) is None:
+            concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
+            bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
+            use_bias = bias_list != []
+            self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
+            self.concat_qkv.weight = nn.Parameter(concat_weight)
+            if use_bias:
+                concat_bias = torch.concat(bias_list, 0).contiguous()
+                self.concat_linear.bias = nn.Parameter(concat_bias)
+            self.q_slice = self.q_proj.weight.shape[0]
+            self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
+            self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
+            if self.module_device.type == "cpu":
+                if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                    self.mha_linear_add = LinearAdd(module.o_proj)
+
+            elif self.module_device.type == "xpu":
+                if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                    self.mha_linear_add = XPULinearAdd(module.o_proj)

     def qkv_gemm(self, hidden_states):
-        qkv_out = self.concat_qkv(hidden_states)
-        query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
-        key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
-        value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
+        if hasattr(self, "concat_qkv"):
+            qkv_out = self.concat_qkv(hidden_states)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
+            value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
+        else:
+            query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
+            key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
+            value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)

         return query, key, value

@@ -745,16 +778,17 @@ def __init__(self, module, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = next(module.parameters()).device
-        if self.module_device.type == "cpu":
-            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mlp_linear_add = LinearAdd(module.down_proj)
-            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-        elif self.module_device.type == "xpu":
-            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mlp_linear_add = XPULinearAdd(module.down_proj)
-            self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
+        if getattr(config, "quantization_config", None) is None:
+            if self.module_device.type == "cpu":
+                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                    self.mlp_linear_add = LinearAdd(module.down_proj)
+                self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            elif self.module_device.type == "xpu":
+                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                    self.mlp_linear_add = XPULinearAdd(module.down_proj)
+                self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)

     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
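
For illustration (not from the PR): the getattr(config, "quantization_config", None) gate used above is the standard transformers attribute that quantized checkpoints (bitsandbytes, AWQ, GPTQ, ...) carry in their config, so fused IPEX ops are only built for full-precision weights. A hedged sketch of what the branch sees; the repo id is a placeholder:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("some-org/llama-awq-4bit")  # placeholder quantized checkpoint
quant = getattr(config, "quantization_config", None)
if quant is None:
    pass  # full precision: safe to build concat_qkv and the fused LinearAdd/Linear2SiluMul ops
else:
    print(quant)  # quantized: keep the original projections; the dict usually carries a "quant_method" key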
@@ -776,17 +810,18 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
         self.module_device = next(module.parameters()).device
-        if self.module_device.type == "cpu":
-            self.linear_gelu = LinearGelu(module.dense_h_to_4h)
-        elif self.module_device.type == "xpu":
-            self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
-        if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
+        if getattr(config, "quantization_config", None) is None:
+            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
-                self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
+                self.linear_gelu = LinearGelu(module.dense_h_to_4h)
             elif self.module_device.type == "xpu":
-                self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
+                self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
+            if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
+                if self.module_device.type == "cpu":
+                    self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
+                elif self.module_device.type == "xpu":
+                    self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)

     def forward(
         self,
@@ -812,6 +847,8 @@ def __init__(self, module, config):
         _setattr_from_module(self, module)
         self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
         self.mlp = _IPEXLlamaMLP(module.mlp, config)
+        if getattr(config, "quantization_config", None):
+            _remove_hooks_for_ipex(self, True)

     def forward(self, hidden_states: torch.Tensor, **kwargs):
         # Please see the original model's forward to check the parameter
@@ -845,6 +882,8 @@ def __init__(self, module, config):
         _setattr_from_module(self, module)
         self.self_attention = _IPEXFalconAttention(module.self_attention, config)
         self.mlp = _IPEXFalconMLP(module.mlp, config)
+        if getattr(config, "quantization_config", None):
+            _remove_hooks_for_ipex(self, True)

     def forward(self, hidden_states: torch.Tensor, **kwargs):
         # Please see the original model's forward to check the parameter
@@ -871,11 +910,16 @@ def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = next(module.parameters()).device
-        if self.module_device.type == "cpu":
-            self.linear_gelu = LinearGelu(module.dense)
-        elif self.module_device.type == "xpu":
-            self.linear_gelu = XPULinearGelu(module.dense)
+        if getattr(config, "quantization_config", None) is None:
+            if self.module_device.type == "cpu":
+                self.linear_gelu = LinearGelu(module.dense)
+            elif self.module_device.type == "xpu":
+                self.linear_gelu = XPULinearGelu(module.dense)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_gelu(hidden_states)
+        if hasattr(self, "linear_gelu"):
+            hidden_states = self.linear_gelu(hidden_states)
+        else:
+            hidden_states = self.dense(hidden_states)
+            hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
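
Usage context (a hedged sketch, not code from this PR's tests; the model id is a placeholder for any bnb- or AWQ-quantized checkpoint): with the gating above, such a checkpoint is expected to load through the IPEX classes and generate without the fused-op path:

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "some-org/llama-2-7b-awq"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("Quantized IPEX models can now", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))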
1 change: 1 addition & 0 deletions optimum/intel/ipex/modeling_base.py
@@ -189,6 +189,7 @@ def maybe_apply_torch_compile(self):
             not self.model.device.type != "cpu"
             or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
            or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
+            or getattr(self.config, "quantization_config", None)
         ):
             return
         if self.use_cache and not self._supports_static_cache:
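
Net effect of the added condition (sketch with a stand-in config object, not the library's own code): any truthy quantization_config makes maybe_apply_torch_compile return before wrapping the model in torch.compile.

class _FakeConfig:  # stand-in, hypothetical
    model_type = "llama"
    quantization_config = {"quant_method": "awq"}

skip_compile = bool(getattr(_FakeConfig(), "quantization_config", None))
print(skip_compile)  # True -> compile is skipped for quantized models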