
Commit 45c1c09

BF16 support in the ONNX export (#1654)
Export works; ONNX Runtime (ORT) does not.
1 parent 843d3f4 commit 45c1c09

File tree: 5 files changed, +54 -8 lines changed

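As a usage sketch (not part of the commit): the new precision option can be exercised either through the `--dtype bf16` flag added to `optimum-cli export onnx` below, or through the `dtype` argument of `main_export`. The model id and output directory here are placeholders borrowed from the docstring example in the diff.

# Minimal sketch, assuming the `main_export` entry point touched by this commit.
from optimum.exporters.onnx import main_export

main_export(
    "gpt2",               # placeholder model id (same as the docstring example)
    output="gpt2_onnx/",  # placeholder output directory
    dtype="bf16",         # new in this commit; one of "fp32" (default), "fp16", "bf16"
)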

optimum/commands/export/onnx.py (+8)

@@ -62,6 +62,13 @@ def parse_args_onnx(parser):
         action="store_true",
         help="Use half precision during the export. PyTorch-only, requires `--device cuda`.",
     )
+    optional_group.add_argument(
+        "--dtype",
+        type=str,
+        default=None,
+        choices=["fp32", "fp16", "bf16"],
+        help="The floating point precision to use for the export. Supported options: fp32 (float32), fp16 (float16), bf16 (bfloat16).",
+    )
     optional_group.add_argument(
         "--optimize",
         type=str,

@@ -253,6 +260,7 @@ def run(self):
             opset=self.args.opset,
             device=self.args.device,
             fp16=self.args.fp16,
+            dtype=self.args.dtype,
             optimize=self.args.optimize,
             monolith=self.args.monolith,
             no_post_process=self.args.no_post_process,

optimum/exporters/onnx/__main__.py (+42 -5)

@@ -167,6 +167,7 @@ def main_export(
     task: str = "auto",
     opset: Optional[int] = None,
     device: str = "cpu",
+    dtype: Optional[str] = None,
     fp16: Optional[bool] = False,
     optimize: Optional[str] = None,
     monolith: bool = False,

@@ -216,6 +217,8 @@ def main_export(
             The device to use to do the export. Defaults to "cpu".
         fp16 (`Optional[bool]`, defaults to `"False"`):
             Use half precision during the export. PyTorch-only, requires `device="cuda"`.
+        dtype (`Optional[str]`, defaults to `None`):
+            The floating point precision to use for the export. Supported options: `"fp32"` (float32), `"fp16"` (float16), `"bf16"` (bfloat16). Defaults to `"fp32"`.
         optimize (`Optional[str]`, defaults to `None`):
             Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to
             ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT.

@@ -283,16 +286,31 @@ def main_export(
     >>> main_export("gpt2", output="gpt2_onnx/")
     ```
     """
+
+    if fp16:
+        if dtype is not None:
+            raise ValueError(
+                f'Both the arguments `fp16` ({fp16}) and `dtype` ({dtype}) were specified in the ONNX export, which is not supported. Please specify only `dtype`. Possible options: "fp32" (default), "fp16", "bf16".'
+            )
+
+        logger.warning(
+            'The argument `fp16` is deprecated in the ONNX export. Please use the argument `dtype="fp16"` instead, or `--dtype fp16` from the command-line.'
+        )
+
+        dtype = "fp16"
+    elif dtype is None:
+        dtype = "fp32"  # Defaults to float32.
+
     if optimize == "O4" and device != "cuda":
         raise ValueError(
             "Requested O4 optimization, but this optimization requires to do the export on GPU."
             " Please pass the argument `--device cuda`."
         )

-    if (framework == "tf" and fp16 is True) or not is_torch_available():
+    if (framework == "tf" and fp16) or not is_torch_available():
         raise ValueError("The --fp16 option is supported only for PyTorch.")

-    if fp16 and device == "cpu":
+    if dtype == "fp16" and device == "cpu":
         raise ValueError(
             "FP16 export is supported only when exporting on GPU. Please pass the option `--device cuda`."
         )

@@ -311,7 +329,13 @@ def main_export(
     library_name = TasksManager.infer_library_from_model(
         model_name_or_path, subfolder=subfolder, library_name=library_name
     )
-    torch_dtype = None if fp16 is False else torch.float16
+
+    torch_dtype = None
+    if framework == "pt":
+        if dtype == "fp16":
+            torch_dtype = torch.float16
+        elif dtype == "bf16":
+            torch_dtype = torch.bfloat16

     if task.endswith("-with-past") and monolith:
         task_non_past = task.replace("-with-past", "")

@@ -479,8 +503,16 @@ def onnx_export(
 ):
     library_name = TasksManager._infer_library_from_model(model)
     framework = "pt" if is_torch_available() and isinstance(model, torch.nn.Module) else "tf"
+
     dtype = get_parameter_dtype(model) if framework == "pt" else model.dtype
-    float_dtype = "fp16" if "float16" in str(dtype) else "fp32"
+
+    if "bfloat16" in str(dtype):
+        float_dtype = "bf16"
+    elif "float16" in str(dtype):
+        float_dtype = "fp16"
+    else:
+        float_dtype = "fp32"
+
     model_type = "stable-diffusion" if library_name == "diffusers" else model.config.model_type.replace("_", "-")
     custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE
     task = TasksManager.map_from_synonym(task)

@@ -615,14 +647,19 @@ def onnx_export(

     model.save_config(output)

+    if float_dtype == "bf16":
+        logger.warning(
+            f"Exporting the model {model.__class__.__name__} in bfloat16 float dtype. After the export, ONNX Runtime InferenceSession with CPU/CUDA execution provider likely does not implement all operators for the bfloat16 data type, and the loading is likely to fail."
+        )
+
     _, onnx_outputs = export_models(
         models_and_onnx_configs=models_and_onnx_configs,
         opset=opset,
         output_dir=output,
         output_names=onnx_files_subpaths,
         input_shapes=input_shapes,
         device=device,
-        dtype="fp16" if float_dtype == "fp16" else None,
+        dtype=float_dtype,
         no_dynamic_axes=no_dynamic_axes,
         model_kwargs=model_kwargs,
     )
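
As the warning added above and the commit note ("export works, ORT does not") indicate, the exported graph is valid ONNX, but ONNX Runtime's CPU/CUDA execution providers may lack bfloat16 kernels for some operators, so session creation can fail. A hypothetical sanity check, with a placeholder path for the exported file:

# Hypothetical check: can ONNX Runtime create a session from the bfloat16 export?
import onnxruntime

try:
    onnxruntime.InferenceSession("gpt2_onnx/model.onnx", providers=["CPUExecutionProvider"])
    print("bf16 model loaded by ONNX Runtime")
except Exception as exc:  # expected to fail for many bf16 graphs, per the warning above
    print(f"ONNX Runtime could not load the bf16 model: {exc}")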

optimum/onnxruntime/modeling_decoder.py (+3 -1)

@@ -151,7 +151,9 @@ def __init__(

         self.use_fp16 = False
         for inp in model.get_inputs():
-            if (inp.name == "past_key_values" or inp.name in self.key_value_input_names) and inp.type == "tensor(float16)":
+            if (
+                inp.name == "past_key_values" or inp.name in self.key_value_input_names
+            ) and inp.type == "tensor(float16)":
                 self.use_fp16 = True
                 break

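For context on the check above: ONNX Runtime reports input element types as strings such as "tensor(float)", "tensor(float16)", or "tensor(bfloat16)", which is what the `inp.type == "tensor(float16)"` comparison relies on. A small illustrative sketch, with a placeholder model path:

# Illustrative sketch: inspect the element types of an exported model's inputs.
import onnxruntime

session = onnxruntime.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print(inp.name, inp.type)  # past key/value inputs report "tensor(float16)" for an fp16 export
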
optimum/onnxruntime/modeling_ort.py (-1)

@@ -14,7 +14,6 @@
 """ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers."""

 import logging
-import math
 import re
 import shutil
 from pathlib import Path

tests/onnxruntime/test_modeling.py (+1 -1)

@@ -3378,7 +3378,7 @@ def test_compare_to_io_binding(self, model_arch):
         self.assertIsInstance(io_outputs.logits, torch.Tensor)

         # compare tensor outputs
-        self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), io_outputs.logits, atol=1e-1))
+        self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), io_outputs.logits, atol=1e-1))

         gc.collect()