
Commit e3b970c

disable linear fusion when using compile

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

1 parent 7f70f2b

2 files changed: +31 -31 lines changed

optimum/exporters/ipex/modeling_utils.py (+15 -20)
@@ -229,7 +229,7 @@ def _llama_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -357,7 +357,7 @@ def _falcon_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -499,7 +499,7 @@ def _gpt2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -635,7 +635,7 @@ def _qwen2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -752,11 +752,11 @@ def attention_interface(
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                query.reshape(input_lens.shape[0], input_lens.max(), -1, query.shape[-1]).transpose(1, 2),
+                key.reshape(input_lens.shape[0], input_lens.max(), -1, key.shape[-1])
                 .transpose(1, 2)
                 .repeat_interleave(n_rep, 1),
-                value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                value.reshape(input_lens.shape[0], input_lens.max(), -1, value.shape[-1])
                 .transpose(1, 2)
                 .repeat_interleave(n_rep, 1),
                 attn_mask=attention_mask,
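
Note (editorial, not part of the commit): the .item() removals above matter because Tensor.item() materializes a Python scalar, which forces a device sync and, by default, a graph break when the forward is traced by torch.compile; keeping input_lens.max() as a 0-dim tensor stays inside the traced graph. A minimal illustrative sketch with hypothetical helper names:

# Illustrative sketch only (not from the commit): why .item() is avoided when
# the forward will later be wrapped in torch.compile.
import torch

def max_len_with_item(input_lens):
    # .item() copies the scalar back to Python, forcing a sync and,
    # by default, a graph break under torch._dynamo tracing
    return input_lens.max().item()

def max_len_as_tensor(input_lens):
    # keeping the result as a 0-dim tensor avoids the explicit sync
    return input_lens.max()

input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
print(max_len_with_item(input_lens), max_len_as_tensor(input_lens))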
@@ -883,13 +883,11 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+        if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+            if self.module_device.type == "cpu":
                 self.mha_linear_add = LinearAdd(module.o_proj)
-
         elif self.module_device.type == "xpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+            self.mha_linear_add = XPULinearAdd(module.o_proj)
 
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
@@ -932,7 +930,7 @@ def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
             self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
             self.c_attn_linear.bias = self.c_attn.bias
@@ -976,7 +974,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
                 if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -1009,7 +1007,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense_h_to_4h)
@@ -1049,7 +1047,7 @@ def __init__(self, module, device, config) -> None:
         self.config = config
         self.module_device = device
 
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
             self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
             self.c_fc_linear.bias = self.c_fc.bias
@@ -1058,11 +1056,8 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear.bias = self.c_proj.bias
             if self.module_device.type == "cpu":
                 self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-            if self.module_device.type == "cpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = LinearAdd(self.c_proj_linear)
-
             elif self.module_device.type == "xpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = XPULinearAdd(self.c_proj_linear)
@@ -1234,7 +1229,7 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
             elif self.module_device.type == "xpu":
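
Taken together, the modeling_utils.py hunks gate every fused-linear replacement behind the new config.compile flag, leaving the plain nn.Linear path in place when torch.compile is used so inductor can optimize it itself. A simplified, hypothetical sketch of that pattern (LinearAdd and ExampleMLP below are stand-ins, not the real optimum-intel classes):

# Hypothetical, simplified sketch of the gating pattern; the real classes live
# in optimum/exporters/ipex/modeling_utils.py.
import torch.nn as nn

class LinearAdd(nn.Module):
    # stand-in for the IPEX fused linear + residual-add op
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x, residual):
        return self.linear(x) + residual

class ExampleMLP(nn.Module):
    # fuse only when neither torch.compile nor quantization is in play,
    # mirroring the "not config.compile and ..." checks in the diff above
    def __init__(self, down_proj, config):
        super().__init__()
        self.down_proj = down_proj
        if not getattr(config, "compile", False) and getattr(config, "quantization_config", None) is None:
            self.linear_add = LinearAdd(down_proj)

    def forward(self, hidden_states, residual):
        if hasattr(self, "linear_add"):
            return self.linear_add(hidden_states, residual)
        # unfused path, left for torch.compile / inductor to handle
        return self.down_proj(hidden_states) + residual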

optimum/intel/ipex/modeling_base.py (+16 -11)
@@ -135,7 +135,8 @@ def __init__(
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
         self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
-        self.compiled = False
+        self.compile = self.can_compile()
+        model.config.compile = self.compile
 
         self.input_names = set(inspect.signature(model.forward).parameters)
 
@@ -147,9 +148,10 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
-        self.maybe_apply_torch_compile()
+        if self.compile:
+            self.apply_torch_compile()
 
-        if warmup and not self.compiled:
+        if warmup and not self.compile:
             self._init_warmup()
 
     @classmethod
@@ -220,24 +222,27 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)
 
-    def maybe_apply_torch_compile(self):
+    def can_compile(self):
         if (
             self.model.device.type != "cpu"
             or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
             or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
             or getattr(self.config, "quantization_config", None)
         ):
-            return
+            return False
         if self.use_cache and not self._supports_static_cache:
-            return
+            return False
+
+        return True
+
+    def apply_torch_compile(self):
         from torch._inductor import config as inductor_config
 
         # System level optimization
         inductor_config.cpp_wrapper = True
         os.environ["TORCHINDUCTOR_FREEZING"] = "1"
         logger.info("Enable torch.compile optimization")
         self.model.forward = torch.compile(self.model.forward)
-        self.compiled = True
 
     def _init_warmup(self):
         inputs = prepare_jit_inputs(self.model, self.export_feature, False)
@@ -317,7 +322,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
-        if warmup and not self.compiled:
+        if warmup and not self.compile:
             self._init_warmup()
 
     @torch.no_grad()
@@ -337,7 +342,7 @@ def _prepare_generation_config(
         kwargs["use_cache"] = self.use_cache
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
-        if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
+        if self.compile and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
             # Use static cache for torch compile
             generation_config.cache_implementation = "static"
         if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
@@ -448,7 +453,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
 
-        if warmup and not self.compiled:
+        if warmup and not self.compile:
             self._init_warmup()
 
     @torch.no_grad()
@@ -465,7 +470,7 @@
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         # Use static cache for torch.compile
-        if self.compiled:
+        if self.compile:
             generation_config.cache_implementation = "static"
 
         return generation_config, model_kwargs
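
With both files combined, the flow is: __init__ computes self.compile via can_compile(), records it on model.config so the exporter classes skip linear fusion, and then wraps model.forward with torch.compile; generation falls back to a static cache when compiled. A hypothetical usage sketch (the checkpoint name and generation arguments are assumptions, not from the commit):

# Hypothetical usage sketch: on a supported CPU setup the model is loaded,
# config.compile is set, fusion is skipped, and forward is torch.compile'd.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # assumed example checkpoint
model = IPEXModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))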
