Skip to content

Commit dac8645

Browse files
authored
adds debug options to dump onnx graphs (#1789)
add debug options
1 parent 35a81dc commit dac8645

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

optimum/onnxruntime/trainer.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,14 @@ def _inner_training_loop(
459459

460460
# Wrap the model with `ORTModule`
461461
logger.info("Wrap ORTModule for ONNX Runtime training.")
462-
model = ORTModule(self.model)
462+
if self.args.save_onnx:
463+
from torch_ort import DebugOptions
464+
465+
model = ORTModule(
466+
self.model, DebugOptions(save_onnx=self.args.save_onnx, onnx_prefix=self.args.onnx_prefix)
467+
)
468+
else:
469+
model = ORTModule(self.model)
463470
self.model_wrapped = model
464471
self.model = model
465472

optimum/onnxruntime/training_args.py

+30
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,29 @@ class ORTTrainingArguments(TrainingArguments):
7979
},
8080
)
8181

82+
save_onnx: Optional[bool] = field(
83+
default=False,
84+
metadata={
85+
"help": "Configure ORTModule to save onnx models. Defaults to False. \
86+
The output directory of the onnx models by default is set to args.output_dir. \
87+
To change the output directory, the environment variable ORTMODULE_SAVE_ONNX_PATH can be \
88+
set to the destination directory path."
89+
},
90+
)
91+
92+
onnx_prefix: Optional[str] = field(
93+
default=None,
94+
metadata={"help": "Prefix for the saved ORTModule file names. Must be provided if save_onnx is True."},
95+
)
96+
97+
onnx_log_level: Optional[str] = field(
98+
default="WARNING",
99+
metadata={
100+
"help": "Configure ORTModule log level. Defaults to WARNING. \
101+
onnx_log_level can also be set to one of VERBOSE, INFO, WARNING, ERROR, FATAL."
102+
},
103+
)
104+
82105
# This method will not need to be overriden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
83106
def __post_init__(self):
84107
# expand paths, if not os.makedirs("~/bar") will make directory
@@ -244,6 +267,13 @@ def __post_init__(self):
244267
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
245268
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")
246269

270+
if self.save_onnx:
271+
if not self.onnx_prefix:
272+
raise ValueError("onnx_prefix must be provided if save_onnx is True")
273+
if not os.getenv("ORTMODULE_SAVE_ONNX_PATH", None):
274+
os.environ["ORTMODULE_SAVE_ONNX_PATH"] = self.output_dir
275+
os.environ["ORTMODULE_LOG_LEVEL"] = self.onnx_log_level
276+
247277
if (
248278
is_torch_available()
249279
and (self.device.type != "cuda")

0 commit comments

Comments
 (0)