Skip to content

Commit c05ab93

Browse files
Add option to disable ONNX constant folding (#1682)
* optionally disable onnx constant folding
* Update optimum/exporters/onnx/__main__.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

---------

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
1 parent e0e12ed commit c05ab93

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

optimum/commands/export/onnx.py

+6
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,11 @@ def parse_args_onnx(parser):
160160
optional_group.add_argument(
161161
"--no-dynamic-axes", action="store_true", help="Disable dynamic axes during ONNX export"
162162
)
163+
optional_group.add_argument(
164+
"--no-constant-folding",
165+
action="store_true",
166+
help="PyTorch-only argument. Disables PyTorch ONNX export constant folding.",
167+
)
163168

164169
input_group = parser.add_argument_group(
165170
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
@@ -276,5 +281,6 @@ def run(self):
276281
legacy=self.args.legacy,
277282
no_dynamic_axes=self.args.no_dynamic_axes,
278283
model_kwargs=self.args.model_kwargs,
284+
do_constant_folding=not self.args.no_constant_folding,
279285
**input_shapes,
280286
)

optimum/exporters/onnx/__main__.py

+7
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,7 @@ def main_export(
191191
library_name: Optional[str] = None,
192192
legacy: bool = False,
193193
no_dynamic_axes: bool = False,
194+
do_constant_folding: bool = True,
194195
**kwargs_shapes,
195196
):
196197
"""
@@ -275,6 +276,8 @@ def main_export(
275276
Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
276277
no_dynamic_axes (bool, defaults to `False`):
277278
If True, disables the use of dynamic axes during ONNX export.
279+
do_constant_folding (bool, defaults to `True`):
280+
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
278281
**kwargs_shapes (`Dict`):
279282
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
280283
@@ -485,6 +488,7 @@ def main_export(
485488
no_dynamic_axes=no_dynamic_axes,
486489
task=task,
487490
use_subprocess=use_subprocess,
491+
do_constant_folding=do_constant_folding,
488492
**kwargs_shapes,
489493
)
490494

@@ -508,6 +512,7 @@ def onnx_export(
508512
no_dynamic_axes: bool = False,
509513
task: Optional[str] = None,
510514
use_subprocess: bool = False,
515+
do_constant_folding: bool = True,
511516
**kwargs_shapes,
512517
):
513518
library_name = TasksManager._infer_library_from_model(model)
@@ -676,6 +681,7 @@ def onnx_export(
676681
device=device,
677682
dtype=float_dtype,
678683
no_dynamic_axes=no_dynamic_axes,
684+
do_constant_folding=do_constant_folding,
679685
model_kwargs=model_kwargs,
680686
)
681687

@@ -775,6 +781,7 @@ def main():
775781
for_ort=args.for_ort,
776782
library_name=args.library_name,
777783
legacy=args.legacy,
784+
do_constant_folding=not args.no_constant_folding,
778785
**input_shapes,
779786
)
780787

optimum/exporters/onnx/convert.py

+12 -1
Original file line number | Diff line number | Diff line change
@@ -477,6 +477,7 @@ def export_pytorch(
477477
device: str = "cpu",
478478
input_shapes: Optional[Dict] = None,
479479
no_dynamic_axes: bool = False,
480+
do_constant_folding: bool = True,
480481
model_kwargs: Optional[Dict[str, Any]] = None,
481482
) -> Tuple[List[str], List[str]]:
482483
"""
@@ -498,6 +499,8 @@ def export_pytorch(
498499
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
499500
no_dynamic_axes (bool, defaults to `False`):
500501
If True, disables the use of dynamic axes during ONNX export.
502+
do_constant_folding (bool, defaults to `True`):
503+
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
501504
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
502505
Experimental usage: keyword arguments to pass to the model during
503506
the export. This argument should be used along the `custom_onnx_config` argument
@@ -566,7 +569,7 @@ def remap(value):
566569
input_names=input_names,
567570
output_names=output_names,
568571
dynamic_axes=dynamix_axes,
569-
do_constant_folding=True,
572+
do_constant_folding=do_constant_folding,
570573
opset_version=opset,
571574
)
572575

@@ -690,6 +693,7 @@ def export_models(
690693
disable_dynamic_axes_fix: Optional[bool] = False,
691694
dtype: Optional[str] = None,
692695
no_dynamic_axes: bool = False,
696+
do_constant_folding: bool = True,
693697
model_kwargs: Optional[Dict[str, Any]] = None,
694698
) -> Tuple[List[List[str]], List[List[str]]]:
695699
"""
@@ -718,6 +722,8 @@ def export_models(
718722
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
719723
no_dynamic_axes (bool, defaults to `False`):
720724
If True, disables the use of dynamic axes during ONNX export.
725+
do_constant_folding (bool, defaults to `True`):
726+
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
721727
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
722728
Experimental usage: keyword arguments to pass to the model during
723729
the export. This argument should be used along the `custom_onnx_config` argument
@@ -752,6 +758,7 @@ def export_models(
752758
disable_dynamic_axes_fix=disable_dynamic_axes_fix,
753759
dtype=dtype,
754760
no_dynamic_axes=no_dynamic_axes,
761+
do_constant_folding=do_constant_folding,
755762
model_kwargs=model_kwargs,
756763
)
757764
)
@@ -770,6 +777,7 @@ def export(
770777
disable_dynamic_axes_fix: Optional[bool] = False,
771778
dtype: Optional[str] = None,
772779
no_dynamic_axes: bool = False,
780+
do_constant_folding: bool = True,
773781
model_kwargs: Optional[Dict[str, Any]] = None,
774782
) -> Tuple[List[str], List[str]]:
775783
"""
@@ -795,6 +803,8 @@ def export(
795803
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
796804
no_dynamic_axes (bool, defaults to `False`):
797805
If True, disables the use of dynamic axes during ONNX export.
806+
do_constant_folding (bool, defaults to `True`):
807+
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
798808
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
799809
Experimental usage: keyword arguments to pass to the model during
800810
the export. This argument should be used along the `custom_onnx_config` argument
@@ -851,6 +861,7 @@ def export(
851861
device=device,
852862
input_shapes=input_shapes,
853863
no_dynamic_axes=no_dynamic_axes,
864+
do_constant_folding=do_constant_folding,
854865
model_kwargs=model_kwargs,
855866
)
856867

0 commit comments

Comments
 (0)