optimum/exporters/onnx/__main__.py (+42 -5)
```diff
@@ -167,6 +167,7 @@ def main_export(
     task: str = "auto",
     opset: Optional[int] = None,
     device: str = "cpu",
+    dtype: Optional[str] = None,
     fp16: Optional[bool] = False,
     optimize: Optional[str] = None,
     monolith: bool = False,
```
```diff
@@ -216,6 +217,8 @@ def main_export(
             The device to use to do the export. Defaults to "cpu".
         fp16 (`Optional[bool]`, defaults to `"False"`):
             Use half precision during the export. PyTorch-only, requires `device="cuda"`.
+        dtype (`Optional[str]`, defaults to `None`):
+            The floating point precision to use for the export. Supported options: `"fp32"` (float32), `"fp16"` (float16), `"bf16"` (bfloat16). Defaults to `"fp32"`.
         optimize (`Optional[str]`, defaults to `None`):
             Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to
             ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT.
```
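With the new docstring entry in place, a minimal sketch of the intended call (assuming a CUDA device, since the checks below reject fp16 export on CPU):

```python
# Minimal sketch: export GPT-2 to ONNX in float16 via the new `dtype` argument.
# Requires a CUDA device; fp16 export on CPU is rejected by the checks below.
from optimum.exporters.onnx import main_export

main_export(
    "gpt2",
    output="gpt2_onnx/",
    device="cuda",
    dtype="fp16",  # replaces the deprecated `fp16=True`
)
```

From the command line, the equivalent should be the `--dtype fp16` flag mentioned in the deprecation warning below.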
```diff
@@ -283,16 +286,31 @@ def main_export(
     >>> main_export("gpt2", output="gpt2_onnx/")
     ```
     """
+
+    if fp16:
+        if dtype is not None:
+            raise ValueError(
+                f'Both the arguments `fp16` ({fp16}) and `dtype` ({dtype}) were specified in the ONNX export, which is not supported. Please specify only `dtype`. Possible options: "fp32" (default), "fp16", "bf16".'
+            )
+
+        logger.warning(
+            'The argument `fp16` is deprecated in the ONNX export. Please use the argument `dtype="fp16"` instead, or `--dtype fp16` from the command-line.'
+        )
+
+        dtype = "fp16"
+    elif dtype is None:
+        dtype = "fp32"  # Defaults to float32.
+
     if optimize == "O4" and device != "cuda":
         raise ValueError(
             "Requested O4 optimization, but this optimization requires to do the export on GPU."
             " Please pass the argument `--device cuda`."
         )
 
-    if (framework == "tf" and fp16 is True) or not is_torch_available():
+    if (framework == "tf" and fp16) or not is_torch_available():
         raise ValueError("The --fp16 option is supported only for PyTorch.")
 
-    if fp16 and device == "cpu":
+    if dtype == "fp16" and device == "cpu":
         raise ValueError(
             "FP16 export is supported only when exporting on GPU. Please pass the option `--device cuda`."
         )
```
f"Exporting the model {model.__class__.__name__} in bfloat16 float dtype. After the export, ONNX Runtime InferenceSession with CPU/CUDA execution provider likely does not implement all operators for the bfloat16 data type, and the loading is likely to fail."