diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index e0c7e0b3cc..c7984c7367 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -229,7 +229,7 @@ def _llama_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -357,7 +357,7 @@ def _falcon_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -499,7 +499,7 @@ def _gpt2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -635,7 +635,7 @@ def _qwen2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -754,11 +754,11 @@ def attention_interface(
     if past_key_value is None:
         n_rep = query.shape[1] // key.shape[1]
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-            key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+            query.reshape(input_lens.shape[0], input_lens.max(), -1, query.shape[-1]).transpose(1, 2),
+            key.reshape(input_lens.shape[0], input_lens.max(), -1, key.shape[-1])
             .transpose(1, 2)
             .repeat_interleave(n_rep, 1),
-            value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+            value.reshape(input_lens.shape[0], input_lens.max(), -1, value.shape[-1])
             .transpose(1, 2)
             .repeat_interleave(n_rep, 1),
             attn_mask=attention_mask,
@@ -885,13 +885,11 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+        if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+            if self.module_device.type == "cpu":
                 self.mha_linear_add = LinearAdd(module.o_proj)
-
-        elif self.module_device.type == "xpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+            elif self.module_device.type == "xpu":
+                self.mha_linear_add = XPULinearAdd(module.o_proj)
 
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
@@ -935,7 +933,7 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
             self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
             self.c_attn_linear.bias = self.c_attn.bias
@@ -979,7 +977,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
                 if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -1012,7 +1010,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense_h_to_4h)
@@ -1052,7 +1050,7 @@ def __init__(self, module, device, config) -> None:
         self.config = config
         self.module_device = device
 
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
             self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
             self.c_fc_linear.bias = self.c_fc.bias
@@ -1061,11 +1059,8 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear.bias = self.c_proj.bias
             if self.module_device.type == "cpu":
                 self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-            if self.module_device.type == "cpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = LinearAdd(self.c_proj_linear)
-
             elif self.module_device.type == "xpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = XPULinearAdd(self.c_proj_linear)
@@ -1237,7 +1232,7 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
             elif self.module_device.type == "xpu":
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index c27a87bdba..02f158b8b7 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -146,7 +146,8 @@ def __init__(
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
         self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
-        self.compiled = False
+        self.compile = self.can_compile()
+        model.config.compile = self.compile
 
         self.input_names = set(inspect.signature(model.forward).parameters)
 
@@ -158,9 +159,10 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
-        self.maybe_apply_torch_compile()
+        if self.compile:
+            self.apply_torch_compile()
 
-        if warmup and not self.compiled:
+        if warmup and not self.compile:
             self._init_warmup()
 
     @classmethod
@@ -231,16 +233,20 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)
 
-    def maybe_apply_torch_compile(self):
+    def can_compile(self):
         if (
             self.model.device.type != "cpu"
             or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
             or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
             or getattr(self.config, "quantization_config", None)
         ):
-            return
+            return False
         if self.use_cache and not self._supports_static_cache:
-            return
+            return False
+
+        return True
+
+    def apply_torch_compile(self):
         from torch._inductor import config as inductor_config
 
         # System level optimization
@@ -248,7 +254,6 @@ def maybe_apply_torch_compile(self):
         os.environ["TORCHINDUCTOR_FREEZING"] = "1"
         logger.info("Enable torch.compile optimization")
         self.model.forward = torch.compile(self.model.forward)
-        self.compiled = True
 
     def _init_warmup(self):
         inputs = prepare_jit_inputs(self.model, self.export_feature, False)
@@ -328,7 +333,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
-        if warmup and not self.compiled:
+        if warmup and not self.compile:
             self._init_warmup()
 
     @torch.no_grad()
@@ -348,7 +353,7 @@ def _prepare_generation_config(
         kwargs["use_cache"] = self.use_cache
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
-        if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
+        if self.compile and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
             # Use static cache for torch compile
             generation_config.cache_implementation = "static"
         if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
@@ -459,7 +464,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
 
-        if warmup and not self.compiled:
+        if warmup and not self.compile:
             self._init_warmup()
 
     @torch.no_grad()
@@ -476,7 +481,7 @@ def _prepare_generation_config(
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         # Use static cache for torch.compile
-        if self.compiled:
+        if self.compile:
             generation_config.cache_implementation = "static"
 
         return generation_config, model_kwargs
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index 9309e49872..526da70b2e 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -374,7 +374,7 @@ def test_ipex_beam_search(self, test_name, model_arch, use_cache):
             model_id, use_cache=use_cache, torch_dtype=dtype, device_map=DEVICE
         )
         # It will be removed when torch 2.6 released
-        if model_arch == "opt" and not use_cache and model.compiled and is_torch_version("<", "2.6.0"):
+        if model_arch == "opt" and not use_cache and model.compile and is_torch_version("<", "2.6.0"):
            return
         if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
             self.assertTrue(model.add_patch)
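
Not part of the patch: a minimal sketch of why the `.item()` calls are dropped in the modeling_utils.py hunks above. `Tensor.item()` materializes a Python scalar, which forces a host sync and, by default, a graph break under `torch.compile`, whereas a 0-dim tensor stays traceable. The helper names below are illustrative only and assume torch >= 2.x.

```python
# Illustrative sketch only (not code from optimum-intel).
import torch


def max_len_with_item(input_lens: torch.Tensor) -> int:
    # Old pattern: `.item()` yields a Python int, breaking the compiled graph.
    return input_lens.max().item()


def max_len_as_tensor(input_lens: torch.Tensor) -> torch.Tensor:
    # New pattern: the max stays a 0-dim tensor, so Dynamo can trace it end to end.
    return input_lens.max()


if __name__ == "__main__":
    lens = torch.tensor([3, 5, 2], dtype=torch.int32)
    compiled = torch.compile(max_len_as_tensor)
    print(compiled(lens))  # tensor(5, dtype=torch.int32)
```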