@@ -18,14 +18,14 @@
 import math
 import types
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from optimum.exporters.onnx.base import OnnxConfig
 
 import torch
 import torch.nn.functional as F
 from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
 from optimum.intel.utils.import_utils import (
     _openvino_version,
@@ -423,9 +423,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
             offset = 0
         mask_shape = attention_mask.shape
         mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-        causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
-            mask_slice
-        )
+        causal_mask[
+            : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+        ] = mask_slice
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2060,9 +2060,9 @@ def _dbrx_update_causal_mask_legacy(
             offset = 0
         mask_shape = attention_mask.shape
         mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-        causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
-            mask_slice
-        )
+        causal_mask[
+            : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+        ] = mask_slice
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -3386,10 +3386,9 @@ class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
         config: OnnxConfig,
-        model: PreTrainedModel | TFPreTrainedModel,
-        model_kwargs: Dict[str, Any] | None = None,
+        model: Union[PreTrainedModel, TFPreTrainedModel],
+        model_kwargs: Dict[str, Any] = None,
     ):
-
         model.__orig_forward = model.forward
 
         def forward_wrap(
@@ -3426,8 +3425,8 @@ class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
     def __init__(
         self,
         config: OnnxConfig,
-        model: PreTrainedModel | TFPreTrainedModel,
-        model_kwargs: Dict[str, Any] | None = None,
+        model: Union[PreTrainedModel, TFPreTrainedModel],
+        model_kwargs: Dict[str, Any] = None,
     ):
         model.__orig_forward = model.forward
 
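The two Qwen2-VL patcher hunks swap PEP 604 unions (X | Y) for typing.Union. Assuming the motivation is runtime compatibility with Python < 3.10 (the diff itself does not state it), here is a minimal illustration of the failure mode and the portable spelling; init_sketch and its stand-in types are hypothetical, not from the patched file:

from typing import Any, Dict, Optional, Union


# On Python 3.8/3.9, an annotation using PEP 604 syntax raises at definition
# time, because `PreTrainedModel | TFPreTrainedModel` invokes type.__or__,
# which only exists from 3.10 onward:
#
#     def __init__(self, model: PreTrainedModel | TFPreTrainedModel): ...
#     # TypeError: unsupported operand type(s) for |: ...
#
# The typing-module spelling evaluates fine on every supported version:
def init_sketch(
    model: Union[str, bytes],  # stand-in for Union[PreTrainedModel, TFPreTrainedModel]
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Fall back to an empty kwargs dict when none is supplied.
    return dict(model_kwargs or {})

Note that the new signatures write `model_kwargs: Dict[str, Any] = None` rather than `Optional[Dict[str, Any]] = None`; that is looser than a strict type checker would like, but harmless at runtime since annotations are not enforced.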