diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 33fd77cba3..11698a380d 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -293,7 +293,7 @@ def __exit__(self, exc_type, exc_value, traceback): # adopted from # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058 -def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): +def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): from transformers.modeling_attn_mask_utils import AttentionMaskConverter if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None: @@ -306,10 +306,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling # using minimum from dtype with larger bandwith (floa32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] + # difference with original modeling if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache @@ -321,7 +323,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + # difference with original modeling causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -358,6 +362,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po return causal_mask +# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036 +def _llama_gemma_update_causal_mask_latest( + self, + attention_mask, + input_tensor, + cache_position, + past_key_values, + output_attentions, +): + from transformers.cache_utils import StaticCache + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + # difference with original modeling + # using minimum from dtype with larger bandwith (floa32) may lead to overflow + # during execution on platforms with default lower precision (bfloat16, float16) + min_dtype = torch.finfo(torch.float16).min + + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + # difference with original modeling + causal_mask = ( + torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype + ) + + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0 +if is_transformers_version(">", "4.40.2"): + _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest +else: + _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy + + class GemmaModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 0745a1cd79..c8b29800fa 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -906,7 +906,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): output_path = os.path.join(output_dir, OV_XML_FILE_NAME) self.compression_controller.prepare_for_export() model_type = self.model.config.model_type.replace("_", "-") - onnx_config_class = TasksManager.get_exporter_config_constructor( + exporter_config_class = TasksManager.get_exporter_config_constructor( exporter="onnx", model=self.model, task=self.task, @@ -914,9 +914,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): ) if self.task == "text-generation": - onnx_config = onnx_config_class(self.model.config, use_past=self.model.config.use_cache) + onnx_config = exporter_config_class(self.model.config, use_past=self.model.config.use_cache) else: - onnx_config = onnx_config_class(self.model.config) + onnx_config = exporter_config_class(self.model.config) num_parameters = self.model.num_parameters() save_as_external_data = use_external_data_format(num_parameters) or self.ov_config.save_onnx_model