Skip to content

Commit 3b4f5ac

Browse files
authored
Avoid overriding model_type in TasksManager (#1647)
* avoid modifying model_type
* cleanup
* fix test
* fix test
* fix library detection local model
* fix merge
* make library_name non-optional
* fix warning
* trigger ci
* fix library detection
1 parent 2a789d6 commit 3b4f5ac

File tree

9 files changed

+210
-92
lines changed

9 files changed

+210
-92
lines changed

optimum/commands/export/onnx.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def parse_args_onnx(parser):
140140
optional_group.add_argument(
141141
"--library-name",
142142
type=str,
143-
choices=["transformers", "diffusers", "timm"],
143+
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
144144
default=None,
145145
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
146146
)

optimum/exporters/onnx/__main__.py

+25-10
Original file line number | Diff line number | Diff line change
@@ -68,17 +68,16 @@ def _get_submodels_and_onnx_configs(
6868
custom_onnx_configs: Dict,
6969
custom_architecture: bool,
7070
_variant: str,
71+
library_name: str,
7172
int_dtype: str = "int64",
7273
float_dtype: str = "fp32",
7374
fn_get_submodels: Optional[Callable] = None,
7475
preprocessors: Optional[List[Any]] = None,
7576
legacy: bool = False,
76-
library_name: str = "transformers",
7777
model_kwargs: Optional[Dict] = None,
7878
):
79-
is_stable_diffusion = "stable-diffusion" in task
8079
if not custom_architecture:
81-
if is_stable_diffusion:
80+
if library_name == "diffusers":
8281
onnx_config = None
8382
models_and_onnx_configs = get_stable_diffusion_models_for_export(
8483
model, int_dtype=int_dtype, float_dtype=float_dtype
@@ -129,7 +128,7 @@ def _get_submodels_and_onnx_configs(
129128
if fn_get_submodels is not None:
130129
submodels_for_export = fn_get_submodels(model)
131130
else:
132-
if is_stable_diffusion:
131+
if library_name == "diffusers":
133132
submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
134133
elif (
135134
model.config.is_encoder_decoder
@@ -373,12 +372,16 @@ def main_export(
373372

374373
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
375374
custom_architecture = True
376-
elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
375+
elif task not in TasksManager.get_supported_tasks_for_model_type(
376+
model_type, "onnx", library_name=library_name
377+
):
377378
if original_task == "auto":
378379
autodetected_message = " (auto-detected)"
379380
else:
380381
autodetected_message = ""
381-
model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
382+
model_tasks = TasksManager.get_supported_tasks_for_model_type(
383+
model_type, exporter="onnx", library_name=library_name
384+
)
382385
raise ValueError(
383386
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
384387
)
@@ -422,7 +425,13 @@ def main_export(
422425
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
423426
)
424427

425-
model_type = "stable-diffusion" if "stable-diffusion" in task else model.config.model_type.replace("_", "-")
428+
if "stable-diffusion" in task:
429+
model_type = "stable-diffusion"
430+
elif hasattr(model.config, "export_model_type"):
431+
model_type = model.config.export_model_type.replace("_", "-")
432+
else:
433+
model_type = model.config.model_type.replace("_", "-")
434+
426435
if (
427436
not custom_architecture
428437
and library_name != "diffusers"
@@ -513,14 +522,20 @@ def onnx_export(
513522
else:
514523
float_dtype = "fp32"
515524

516-
model_type = "stable-diffusion" if library_name == "diffusers" else model.config.model_type.replace("_", "-")
525+
if "stable-diffusion" in task:
526+
model_type = "stable-diffusion"
527+
elif hasattr(model.config, "export_model_type"):
528+
model_type = model.config.export_model_type.replace("_", "-")
529+
else:
530+
model_type = model.config.model_type.replace("_", "-")
531+
517532
custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE
518533
task = TasksManager.map_from_synonym(task)
519534

520535
# TODO: support onnx_config.py in the model repo
521536
if custom_architecture and custom_onnx_configs is None:
522537
raise ValueError(
523-
f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export."
538+
f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
524539
)
525540

526541
if task is None:
@@ -690,7 +705,7 @@ def onnx_export(
690705
if library_name == "diffusers":
691706
# TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export.<locals>.<lambda>'
692707
use_subprocess = False
693-
elif model.config.model_type in UNPICKABLE_ARCHS:
708+
elif model_type in UNPICKABLE_ARCHS:
694709
# Pickling is bugged for nn.utils.weight_norm: https://github.com/pytorch/pytorch/issues/102983
695710
# TODO: fix "Cowardly refusing to serialize non-leaf tensor" error for wav2vec2-conformer
696711
use_subprocess = False

optimum/exporters/onnx/config.py

+8-2
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,10 @@ def __init__(
344344

345345
# Set up the encoder ONNX config.
346346
encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
347-
exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type
347+
exporter="onnx",
348+
task="feature-extraction",
349+
model_type=config.encoder.model_type,
350+
library_name="transformers",
348351
)
349352
self._encoder_onnx_config = encoder_onnx_config_constructor(
350353
config.encoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
@@ -353,7 +356,10 @@ def __init__(
353356

354357
# Set up the decoder ONNX config.
355358
decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
356-
exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type
359+
exporter="onnx",
360+
task="feature-extraction",
361+
model_type=config.decoder.model_type,
362+
library_name="transformers",
357363
)
358364
kwargs = {}
359365
if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast):

optimum/exporters/onnx/utils.py

+5
Original file line number | Diff line number | Diff line change
@@ -323,6 +323,7 @@ def get_stable_diffusion_models_for_export(
323323
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
324324
model=pipeline.text_encoder,
325325
exporter="onnx",
326+
library_name="diffusers",
326327
task="feature-extraction",
327328
)
328329
text_encoder_onnx_config = text_encoder_config_constructor(
@@ -334,6 +335,7 @@ def get_stable_diffusion_models_for_export(
334335
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
335336
model=pipeline.unet,
336337
exporter="onnx",
338+
library_name="diffusers",
337339
task="semantic-segmentation",
338340
model_type="unet",
339341
)
@@ -345,6 +347,7 @@ def get_stable_diffusion_models_for_export(
345347
vae_config_constructor = TasksManager.get_exporter_config_constructor(
346348
model=vae_encoder,
347349
exporter="onnx",
350+
library_name="diffusers",
348351
task="semantic-segmentation",
349352
model_type="vae-encoder",
350353
)
@@ -356,6 +359,7 @@ def get_stable_diffusion_models_for_export(
356359
vae_config_constructor = TasksManager.get_exporter_config_constructor(
357360
model=vae_decoder,
358361
exporter="onnx",
362+
library_name="diffusers",
359363
task="semantic-segmentation",
360364
model_type="vae-decoder",
361365
)
@@ -366,6 +370,7 @@ def get_stable_diffusion_models_for_export(
366370
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
367371
model=pipeline.text_encoder_2,
368372
exporter="onnx",
373+
library_name="diffusers",
369374
task="feature-extraction",
370375
model_type="clip-text-with-projection",
371376
)

0 commit comments

Comments
 (0)