
Commit f8f9707

Support transformers 4.43 (#1971)
* fix bt bark test
* setup
* patch clip models for sd
* infer ort model dtype property from inputs dtypes
* patch all clip variants
* device setter
* bigger model for now
* fix device attribution
* onnx opset for owlvit and owlv2
* model dtype
* revert
* use model part dtype instead
* no need for dtype with diffusion pipelines
* revert
* fix clip text model with projection not outputting hidden states
* whisper generation
* fix whisper, support cache_position, and using transformers whisper generation loop
* style
* create cache position for merged decoder and fix test for non whisper speech to text
* typo
* conditioned cache position argument
* update whisper min transformers version
* compare whisper ort generation with transformers
* fix generation length for speech to text model type
* cache position in whisper only with dynamic axis decoder_sequence_length
* use minimal prepare_inputs_for_generation in ORTModelForSpeechSeq2Seq
* remove version restrictions on whisper
* comment
* fix
* simpler

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 2a6d857 commit f8f9707

12 files changed, +269 -495 lines changed


optimum/exporters/onnx/config.py (+1 -3)

@@ -289,12 +289,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         if self._behavior is not ConfigBehavior.ENCODER:
             if self.use_past_in_inputs:
                 common_inputs["decoder_input_ids"] = {0: "batch_size"}
+                self.add_past_key_values(common_inputs, direction="inputs")
             else:
                 common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

-            if self.use_past_in_inputs:
-                self.add_past_key_values(common_inputs, direction="inputs")
-
         if self._behavior is ConfigBehavior.DECODER:
             common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
optimum/exporters/onnx/model_configs.py (+40 -2)

@@ -53,6 +53,7 @@
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
     NormalizedVisionConfig,
+    check_if_transformers_greater,
     is_diffusers_available,
     logging,
 )

@@ -71,6 +72,7 @@
 )
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import (
+    CLIPModelPatcher,
     FalconModelPatcher,
     MistralModelPatcher,
     MusicgenModelPatcher,

@@ -913,10 +915,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

         return common_outputs

+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+

 class CLIPOnnxConfig(TextAndVisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
-    DEFAULT_ONNX_OPSET = 14

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:

@@ -935,6 +943,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "image_embeds": {0: "image_batch_size"},
         }

+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+

 class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
     @property

@@ -980,6 +995,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

         return common_outputs

+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+

 class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
     @property

@@ -997,12 +1019,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
     def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

+        # TODO: fix should be by casting inputs during inference and not export
         if framework == "pt":
             import torch

             dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
         return dummy_inputs

+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+

 class UNetOnnxConfig(VisionOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-3

@@ -1135,6 +1165,9 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-4
     MIN_TORCH_VERSION = version.parse("2.1")

+    # needs einsum operator support, available since opset 12
+    DEFAULT_ONNX_OPSET = 12
+
     def __init__(
         self,
         config: "PretrainedConfig",

@@ -1438,7 +1471,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         if self._behavior is not ConfigBehavior.DECODER:
             common_inputs["input_features"] = {0: "batch_size"}  # Remove unnecessary dynamic axis.

-        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
+        if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
+            if check_if_transformers_greater("4.43.0"):
+                # since https://github.com/huggingface/transformers/pull/31166
+                common_inputs["cache_position"] = {0: "decoder_sequence_length"}
+
+        if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
             common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
         return common_inputs
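
Note: with transformers >= 4.43 the exported Whisper decoder takes an extra cache_position input whose only dynamic axis is decoder_sequence_length. A minimal sketch of the values it carries during generation; the prompt and sampled token ids below are made up for illustration:

import torch

# First decoder call, no past key values: cache_position covers the whole prompt.
decoder_input_ids = torch.tensor([[50258, 50359, 50363]])   # hypothetical 3-token Whisper prompt
cache_position = torch.arange(decoder_input_ids.shape[1])   # tensor([0, 1, 2])

# Later calls with past key values: one new token, hence a single new position.
past_length = decoder_input_ids.shape[1]
next_token = torch.tensor([[1234]])                         # hypothetical sampled token
cache_position = torch.arange(past_length, past_length + next_token.shape[1])  # tensor([3])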

optimum/exporters/onnx/model_patcher.py (+17 -0)

@@ -1138,3 +1138,20 @@ def __init__(
             self._update_causal_mask_original = self._model.model._update_causal_mask
         else:
             self._update_causal_mask_original = self._model._update_causal_mask
+
+
+class CLIPModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if _transformers_version >= version.parse("4.43"):
+            from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
+
+            self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if _transformers_version >= version.parse("4.43"):
+            from transformers.models.clip.modeling_clip import CLIPSdpaAttention
+
+            CLIPSdpaAttention.forward = self.original_sdpa_forward
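
Note: the patcher temporarily swaps CLIPSdpaAttention.forward for the eager CLIPAttention.forward so tracing does not hit the SDPA path, and restores it on exit. A rough sketch of how an export config uses it; the checkpoint name is illustrative, and in practice the export entry point builds the config and patcher for you:

from optimum.exporters.onnx.model_configs import CLIPOnnxConfig
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
onnx_config = CLIPOnnxConfig(model.config)
patcher = onnx_config.patch_model_for_export(model)  # a CLIPModelPatcher

with patcher:
    # Inside the context, CLIPSdpaAttention.forward is the eager implementation,
    # which the ONNX exporter can trace; the original forward is restored on __exit__.
    pass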

optimum/exporters/utils.py (+6 -4)

@@ -96,7 +96,7 @@ def _get_submodels_for_export_diffusion(
         pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline)
     )
     is_stable_diffusion_xl = isinstance(
-        pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline)
+        pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
     )
     is_latent_consistency_model = isinstance(
         pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline)

@@ -117,10 +117,11 @@ def _get_submodels_for_export_diffusion(
     models_for_export = {}

     # Text encoder
-    if pipeline.text_encoder is not None:
+    text_encoder = getattr(pipeline, "text_encoder", None)
+    if text_encoder is not None:
         if is_stable_diffusion_xl:
-            pipeline.text_encoder.config.output_hidden_states = True
-        models_for_export["text_encoder"] = pipeline.text_encoder
+            text_encoder.config.output_hidden_states = True
+        models_for_export["text_encoder"] = text_encoder

     # U-NET
     # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0

@@ -151,6 +152,7 @@ def _get_submodels_for_export_diffusion(
     text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
     if text_encoder_2 is not None:
         text_encoder_2.config.output_hidden_states = True
+        text_encoder_2.text_model.config.output_hidden_states = True
         models_for_export["text_encoder_2"] = text_encoder_2

     return models_for_export

optimum/onnxruntime/base.py (+38 -12)

@@ -24,6 +24,7 @@

 from ..utils import NormalizedConfigManager
 from ..utils.logging import warn_once
+from .io_binding import TypeHelper
 from .modeling_ort import ORTModel
 from .utils import get_ordered_input_names, logging

@@ -62,6 +63,20 @@ def __init__(
     def device(self):
         return self.parent_model.device

+    @property
+    def dtype(self):
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass
@@ -220,6 +235,7 @@ def forward(
         encoder_attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
         use_cache_branch: None = None,
     ) -> Seq2SeqLMOutput:
         # Adding use_cache_branch in the signature here is just a hack for IO Binding

@@ -236,8 +252,8 @@ def forward(
         # no-ops if merged decoder is not used
         use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
         use_merged_cache = past_key_values is not None and self.parent_model.use_merged
-        use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged(
-            input_ids, past_key_values, use_torch=use_torch
+        use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged(
+            input_ids, past_key_values, cache_position, use_torch=use_torch
         )

         if self.parent_model.use_io_binding:

@@ -274,6 +290,9 @@ def forward(
             if use_cache_branch_tensor is not None:
                 model_inputs.append(use_cache_branch_tensor)

+            if "cache_position" in self.input_names:
+                model_inputs.append(cache_position)
+
             io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                 self.session,
                 *model_inputs,

@@ -346,6 +365,7 @@ def forward(
                 "decoder_attention_mask": decoder_attention_mask,
                 "encoder_attention_mask": encoder_attention_mask,
                 "use_cache_branch": use_cache_branch_tensor,
+                "cache_position": cache_position,
                 "labels": labels,
             }
             if past_key_values is not None:

@@ -405,20 +425,20 @@ def forward(

     def prepare_inputs_for_merged(
         self,
-        input_ids: Union[None, torch.LongTensor, np.ndarray],
-        past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
+        input_ids: Optional[Union[torch.LongTensor, np.ndarray]],
+        past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]],
+        cache_position: Optional[Union[torch.Tensor, np.ndarray]],
         use_torch: bool,
     ):
+        constructor = torch if use_torch is True else np
+
         if self.parent_model.use_merged:
-            constructor = torch if use_torch is True else np
             # Uses without/with branch of a merged decoder depending on whether real past key values are passed
-            use_cache_branch = constructor.full((1,), past_key_values is not None)
+            use_cache_branch_tensor = constructor.full((1,), past_key_values is not None)
+            if use_torch and use_cache_branch_tensor is not None:
+                use_cache_branch_tensor = use_cache_branch_tensor.to(self.device)
         else:
-            # Uses separate decoders
-            use_cache_branch = None
-
-        if use_torch and use_cache_branch is not None:
-            use_cache_branch = use_cache_branch.to(self.device)
+            use_cache_branch_tensor = None

         # Generate dummy past for the first forward if uses a merged decoder
         if self.parent_model.use_merged and past_key_values is None:

@@ -434,7 +454,13 @@ def prepare_inputs_for_merged(

             past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))

-        return use_cache_branch, past_key_values
+        # Generate dummy position cache for the first forward if uses a merged decoder
+        if self.parent_model.use_merged and cache_position is None:
+            cache_position = constructor.zeros((1,), dtype=constructor.int64)
+            if use_torch is True:
+                cache_position = cache_position.to(self.device)
+
+        return use_cache_branch_tensor, past_key_values, cache_position


 class ORTDecoder(ORTDecoderForSeq2Seq):
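
Note: a merged decoder graph always expects a cache_position input, so on the first (no-past) call a zero placeholder is built with whichever array library the other inputs use. A small sketch of that constructor switch:

import numpy as np
import torch

def dummy_cache_position(use_torch: bool):
    # Mirrors prepare_inputs_for_merged: build a length-1 int64 zero placeholder
    # with torch or numpy, matching the type of the other decoder inputs.
    constructor = torch if use_torch else np
    return constructor.zeros((1,), dtype=constructor.int64)

print(dummy_cache_position(True))   # tensor([0])
print(dummy_cache_position(False))  # [0]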

optimum/onnxruntime/modeling_diffusion.py (+7 -1)

@@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]):
         Returns:
             `ORTModel`: the model placed on the requested device.
         """
+
         device, provider_options = parse_device(device)
         provider = get_provider_for_device(device)
         validate_provider_availability(provider)  # raise error if the provider is not available
-        self.device = device
+
+        if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
+            return self
+
         self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
         self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
         self.unet.session.set_providers([provider], provider_options=[provider_options])

@@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]):
             self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])

         self.providers = self.vae_decoder.session.get_providers()
+        self._device = device
+
         return self

     @classmethod
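
Note: the pipeline now records its device only after the providers have actually been switched, and returns early when TensorRT is already in use. A hedged usage sketch; the checkpoint name is hypothetical:

from optimum.onnxruntime import ORTStableDiffusionPipeline

pipe = ORTStableDiffusionPipeline.from_pretrained("my-user/stable-diffusion-onnx")  # hypothetical ONNX export
pipe = pipe.to("cuda")  # switches every sub-model session to CUDAExecutionProvider
# If TensorrtExecutionProvider is already active, .to("cuda") is a no-op that returns the pipeline unchanged.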

optimum/onnxruntime/modeling_ort.py (+23 -5)

@@ -276,7 +276,24 @@ def __init__(

         self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

-    # TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value?
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the model.
+        """
+
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @property
     def device(self) -> torch.device:
         """
@@ -286,8 +303,8 @@ def device(self) -> torch.device:
         return self._device

     @device.setter
-    def device(self, value: torch.device):
-        self._device = value
+    def device(self, **kwargs):
+        raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.")

     @property
     def use_io_binding(self):

@@ -309,13 +326,13 @@ def to(self, device: Union[torch.device, str, int]):
         Returns:
             `ORTModel`: the model placed on the requested device.
         """
+
         device, provider_options = parse_device(device)

         if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
             return self

-        self.device = device
-        provider = get_provider_for_device(self.device)
+        provider = get_provider_for_device(device)
         validate_provider_availability(provider)  # raise error if the provider is not available

         # IOBinding is only supported for CPU and CUDA Execution Providers.

@@ -331,6 +348,7 @@ def to(self, device: Union[torch.device, str, int]):

         self.model.set_providers([provider], provider_options=[provider_options])
         self.providers = self.model.get_providers()
+        self._device = device

         return self
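
Note: device is now a read-only property; moving a model goes through .to(), which switches the execution provider first and records the device last. A hedged sketch; the checkpoint name is hypothetical:

import torch
from optimum.onnxruntime import ORTModelForSequenceClassification

model = ORTModelForSequenceClassification.from_pretrained("my-user/bert-base-onnx")  # hypothetical ONNX export
model = model.to("cuda")  # selects CUDAExecutionProvider (when available) and updates model.device
print(model.device)

try:
    model.device = torch.device("cpu")  # direct assignment now raises
except AttributeError as err:
    print(err)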