Skip to content

Commit 5ae26d0

Browse files
committed
add variant for model loading in from_transformers
1 parent 6b98d62 commit 5ae26d0

10 files changed

+66
-64
lines changed

optimum/exporters/openvino/convert.py

+1-47
Original file line numberDiff line numberDiff line change
@@ -1013,11 +1013,10 @@ def _get_submodels_and_export_configs(
10131013
def get_diffusion_models_for_export_ext(
10141014
pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
10151015
):
1016-
<<<<<<< HEAD
10171016
is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
10181017
is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
10191018
is_flux = pipeline.__class__.__name__.startswith("Flux")
1020-
is_sana = pipeline.__class__.__name__.startswith("Sana")
1019+
is_sana = pipeline.__class__.__name__.startswith("Sana")
10211020
is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
10221021
is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")
10231022

@@ -1036,51 +1035,6 @@ def get_diffusion_models_for_export_ext(
10361035
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
10371036
elif is_flux:
10381037
models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
1039-
=======
1040-
if is_diffusers_version(">=", "0.29.0"):
1041-
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
1042-
1043-
sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
1044-
if is_diffusers_version(">=", "0.30.0"):
1045-
from diffusers import StableDiffusion3InpaintPipeline
1046-
1047-
sd3_pipes.append(StableDiffusion3InpaintPipeline)
1048-
1049-
is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
1050-
else:
1051-
is_sd3 = False
1052-
1053-
if is_diffusers_version(">=", "0.30.0"):
1054-
from diffusers import FluxPipeline
1055-
1056-
flux_pipes = [FluxPipeline]
1057-
1058-
if is_diffusers_version(">=", "0.31.0"):
1059-
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
1060-
1061-
flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])
1062-
1063-
if is_diffusers_version(">=", "0.32.0"):
1064-
from diffusers import FluxFillPipeline
1065-
1066-
flux_pipes.append(FluxFillPipeline)
1067-
1068-
is_flux = isinstance(pipeline, tuple(flux_pipes))
1069-
else:
1070-
is_flux = False
1071-
1072-
if is_diffusers_version(">=", "0.32.0"):
1073-
from diffusers import SanaPipeline
1074-
1075-
is_sana = isinstance(pipeline, SanaPipeline)
1076-
else:
1077-
is_sana = False
1078-
1079-
if not any([is_sana, is_flux, is_sd3]):
1080-
return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
1081-
if is_sd3:
1082-
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
1083-
>>>>>>> add pipeline
10841038
elif is_sana:
10851039
models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
10861040
else:

optimum/exporters/openvino/model_configs.py

+6
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
Qwen2VLVisionEmbMergerPatcher,
107107
QwenModelPatcher,
108108
RotaryEmbPatcher,
109+
SanaTextEncoderModelPatcher,
109110
StatefulSeq2SeqDecoderPatcher,
110111
UpdateCausalMaskModelPatcher,
111112
XverseModelPatcher,
@@ -1903,6 +1904,11 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
19031904
"attention_mask": {0: "batch_size", 1: "sequence_length"},
19041905
}
19051906

1907+
def patch_model_for_export(
    self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
    """
    Build the patcher applied to the model while it is exported.

    Args:
        model: The model that is about to be exported.
        model_kwargs: Optional extra keyword arguments forwarded to the patcher.

    Returns:
        A `SanaTextEncoderModelPatcher` constructed from this config, `model` and `model_kwargs`.
    """
    patcher = SanaTextEncoderModelPatcher(self, model, model_kwargs)
    return patcher
1911+
19061912

19071913
class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
19081914
SUPPORTED_INPUT_NAMES = (

optimum/exporters/openvino/model_patcher.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121

2222
import torch
2323
import torch.nn.functional as F
24+
from transformers import PreTrainedModel, TFPreTrainedModel
2425
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
2526
from transformers.utils import is_tf_available
2627

28+
from optimum.exporters.onnx.base import OnnxConfig
2729
from optimum.exporters.onnx.model_patcher import (
2830
DecoderModelPatcher,
2931
ModelPatcher,
@@ -114,18 +116,20 @@ def patch_model_with_bettertransformer(model):
114116
return model
115117

116118

117-
def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
119+
def patch_update_causal_mask(
    model, transformers_version, inner_model_name="model", patch_fn=None, patch_extrnal_model=False
):
    """
    Replace `_update_causal_mask` on a model (or one of its sub-models) with a patched version.

    Nothing happens unless the installed transformers version is `>= transformers_version`
    and the target object actually defines `_update_causal_mask`.

    Args:
        model: The model to patch.
        transformers_version: Minimum transformers version from which the patch is applied.
        inner_model_name: Attribute of `model` holding the object to patch. Not read when
            `patch_extrnal_model` is True.
        patch_fn: Function bound as the new `_update_causal_mask`. Falls back to
            `_llama_gemma_update_causal_mask` when falsy.
        patch_extrnal_model: When True, `model` itself is patched instead of `model.<inner_model_name>`.
    """
    if not is_transformers_version(">=", transformers_version):
        return

    target = model if patch_extrnal_model else getattr(model, inner_model_name, None)
    if target is None or not hasattr(target, "_update_causal_mask"):
        return

    # Stash the original bound method so that it can be restored after export.
    target._orig_update_causal_mask = target._update_causal_mask
    replacement = patch_fn or _llama_gemma_update_causal_mask
    target._update_causal_mask = types.MethodType(replacement, target)
125129

126130

127-
def unpatch_update_causal_mask(model, inner_model_name="model"):
128-
inner_model = getattr(model, inner_model_name, None)
131+
def unpatch_update_causal_mask(model, inner_model_name="model", patch_extrnal_model=False):
    """
    Restore the `_update_causal_mask` method previously saved as `_orig_update_causal_mask`.

    Args:
        model: The model that may have been patched.
        inner_model_name: Attribute of `model` holding the patched object. Not read when
            `patch_extrnal_model` is True.
        patch_extrnal_model: When True, `model` itself is treated as the patched object.
    """
    inner_model = model if patch_extrnal_model else getattr(model, inner_model_name, None)
    # The saved method is stored under "_orig_update_causal_mask". The previous check used
    # "._orig_update_causal_mask" (leading dot), a name that is never set, so the original
    # method was never restored.
    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
        inner_model._update_causal_mask = inner_model._orig_update_causal_mask
131135

@@ -3791,3 +3795,29 @@ def patched_forward(*args, **kwargs):
37913795
model.forward = patched_forward
37923796

37933797
super().__init__(config, model, model_kwargs)
3798+
3799+
3800+
class SanaTextEncoderModelPatcher(ModelPatcher):
    """
    Export-time patcher for the Sana text encoder.

    On enter it patches the model's causal-mask update method and switches the attention
    implementation to "sdpa"; on exit it restores what it saved.

    NOTE(review): block nesting was reconstructed from a flattened diff view with no
    indentation - confirm against the real file.
    """

    def __enter__(self):
        super().__enter__()
        # The model itself (not an inner attribute) carries `_update_causal_mask`.
        patch_update_causal_mask(self._model, "4.39.0", None, patch_extrnal_model=True)

        model_config = self._model.config
        if model_config._attn_implementation == "sdpa":
            return

        # Remember the configured implementation so that __exit__ can restore it.
        model_config._orig_attn_implementation = model_config._attn_implementation
        model_config._attn_implementation = "sdpa"
        if is_transformers_version("<", "4.47.0"):
            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES

            sdpa_forward = GEMMA2_ATTENTION_CLASSES["sdpa"].forward
            for decoder_layer in self._model.layers:
                attention = decoder_layer.self_attn
                attention._orig_forward = attention.forward
                attention.forward = types.MethodType(sdpa_forward, attention)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        unpatch_update_causal_mask(self._model, None, True)

        model_config = self._model.config
        if not hasattr(model_config, "_orig_attn_implementation"):
            return

        model_config._attn_implementation = model_config._orig_attn_implementation
        for decoder_layer in self._model.layers:
            attention = decoder_layer.self_attn
            # Only layers whose forward was swapped in __enter__ carry `_orig_forward`.
            if hasattr(attention, "_orig_forward"):
                attention.forward = attention._orig_forward

optimum/exporters/openvino/utils.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,15 @@ def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs):
257257
model_part_name = "unet"
258258
if model_part_name:
259259
directory = path / model_part_name
260-
safetensors_files = [
261-
filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1
262-
]
260+
261+
pattern = "*.safetensors"
262+
if "variant" in loading_kwargs:
263+
variant = loading_kwargs["variant"]
264+
pattern = f"*.{variant}.safetensors"
265+
safetensors_files = list(directory.glob(pattern))
266+
else:
267+
# filter out variant files
268+
safetensors_files = [filename for filename in directory.glob(pattern) if len(filename.suffixes) == 1]
263269
safetensors_file = None
264270
if len(safetensors_files) > 0:
265271
safetensors_file = safetensors_files.pop(0)

optimum/intel/openvino/modeling_base.py

+3
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,8 @@ def _from_transformers(
594594
else:
595595
ov_config = OVConfig(dtype="fp32")
596596

597+
variant = kwargs.pop("variant", None)
598+
597599
main_export(
598600
model_name_or_path=model_id,
599601
output=save_dir_path,
@@ -607,6 +609,7 @@ def _from_transformers(
607609
trust_remote_code=trust_remote_code,
608610
ov_config=ov_config,
609611
library_name=cls._library_name,
612+
model_variant=variant,
610613
)
611614

612615
return cls._from_pretrained(

optimum/intel/openvino/modeling_base_seq2seq.py

+2
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ def _from_transformers(
408408
else:
409409
ov_config = OVConfig(dtype="fp32")
410410
stateful = kwargs.get("stateful", True)
411+
variant = kwargs.pop("variant", None)
411412

412413
main_export(
413414
model_name_or_path=model_id,
@@ -422,6 +423,7 @@ def _from_transformers(
422423
trust_remote_code=trust_remote_code,
423424
ov_config=ov_config,
424425
stateful=stateful,
426+
model_variant=variant,
425427
)
426428

427429
return cls._from_pretrained(

optimum/intel/openvino/modeling_decoder.py

+3
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ def _from_transformers(
310310
if torch_dtype is not None:
311311
model_loading_kwargs["torch_dtype"] = torch_dtype
312312

313+
variant = kwargs.pop("variant", None)
314+
313315
main_export(
314316
model_name_or_path=model_id,
315317
output=save_dir_path,
@@ -325,6 +327,7 @@ def _from_transformers(
325327
stateful=stateful,
326328
model_loading_kwargs=model_loading_kwargs,
327329
library_name=cls._library_name,
330+
model_variant=variant,
328331
)
329332

330333
if config.model_type == "phi3" and config.max_position_embeddings != getattr(

optimum/intel/openvino/modeling_diffusion.py

+2
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def _from_transformers(
575575

576576
model_save_dir = TemporaryDirectory()
577577
model_save_path = Path(model_save_dir.name)
578+
variant = kwargs.pop("variant", None)
578579

579580
main_export(
580581
model_name_or_path=model_id,
@@ -589,6 +590,7 @@ def _from_transformers(
589590
force_download=force_download,
590591
ov_config=ov_config,
591592
library_name=cls._library_name,
593+
model_variant=variant,
592594
)
593595

594596
return cls._from_pretrained(

optimum/intel/openvino/modeling_visual_language.py

+2
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ def _from_transformers(
615615
ov_config = OVConfig(dtype="fp32" if load_in_8bit is False else "auto")
616616

617617
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
618+
variant = kwargs.pop("variant", None)
618619

619620
main_export(
620621
model_name_or_path=model_id,
@@ -629,6 +630,7 @@ def _from_transformers(
629630
trust_remote_code=trust_remote_code,
630631
ov_config=ov_config,
631632
stateful=stateful,
633+
model_variant=variant,
632634
)
633635
config = AutoConfig.from_pretrained(save_dir_path, trust_remote_code=trust_remote_code)
634636
return cls._from_pretrained(

tests/openvino/test_diffusion.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
149149
for output_type in ["latent", "np", "pt"]:
150150
inputs["output_type"] = output_type
151151
if model_arch == "sana":
152-
if output_type == "latent":
153-
continue
152+
# resolution binning resizes the output to a standard resolution and back, and the interpolation involved can magnify small floating-point deviations
154153
inputs["use_resolution_binning"] = False
155-
atol = 4e-2
156-
else:
157-
atol = 6e-3
154+
atol = 1e-4
158155

159156
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
160157
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
@@ -166,12 +163,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
166163
for output_type in ["latent", "np", "pt"]:
167164
inputs["output_type"] = output_type
168165
if model_arch == "sana":
169-
if output_type == "latent":
170-
continue
166+
# resolution binning resizes the output to a standard resolution and back, and the interpolation involved can magnify small floating-point deviations
171167
inputs["use_resolution_binning"] = False
172-
atol = 4e-2
173-
else:
174-
atol = 6e-3
168+
atol = 6e-3
175169

176170
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
177171
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

0 commit comments

Comments
 (0)