@@ -79,6 +79,29 @@ class ORTTrainingArguments(TrainingArguments):
79
79
},
80
80
)
81
81
82
# Whether ORTModule should dump its exported ONNX graphs to disk.
# NOTE: the help text previously used backslash line-continuations inside the
# string literal, which baked long runs of indentation whitespace into the
# user-visible help message; implicit string concatenation avoids that.
save_onnx: Optional[bool] = field(
    default=False,
    metadata={
        "help": (
            "Configure ORTModule to save onnx models. Defaults to False. The output directory of the onnx "
            "models by default is set to args.output_dir. To change the output directory, the environment "
            "variable ORTMODULE_SAVE_ONNX_PATH can be set to the destination directory path."
        )
    },
)
91
+
92
# File-name prefix for the ONNX graphs dumped by ORTModule; required when
# save_onnx is enabled (enforced in __post_init__).
onnx_prefix: Optional[str] = field(
    default=None,
    metadata={
        "help": "Prefix for the saved ORTModule file names. Must be provided if save_onnx is True."
    },
)
96
+
97
# Log level forwarded to ORTModule via the ORTMODULE_LOG_LEVEL environment
# variable (see __post_init__). NOTE: the help text previously used a
# backslash line-continuation inside the string literal, which baked the next
# line's indentation whitespace into the user-visible help message; implicit
# string concatenation avoids that.
onnx_log_level: Optional[str] = field(
    default="WARNING",
    metadata={
        "help": (
            "Configure ORTModule log level. Defaults to WARNING. onnx_log_level can also be set to one of "
            "VERBOSE, INFO, WARNING, ERROR, FATAL."
        )
    },
)
104
+
82
105
# This method will not need to be overriden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
83
106
def __post_init__ (self ):
84
107
# expand paths, if not os.makedirs("~/bar") will make directory
@@ -244,6 +267,13 @@ def __post_init__(self):
244
267
if version .parse (version .parse (torch .__version__ ).base_version ) == version .parse ("2.0.0" ) and self .fp16 :
245
268
raise ValueError ("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0" )
246
269
270
+ if self .save_onnx :
271
+ if not self .onnx_prefix :
272
+ raise ValueError ("onnx_prefix must be provided if save_onnx is True" )
273
+ if not os .getenv ("ORTMODULE_SAVE_ONNX_PATH" , None ):
274
+ os .environ ["ORTMODULE_SAVE_ONNX_PATH" ] = self .output_dir
275
+ os .environ ["ORTMODULE_LOG_LEVEL" ] = self .onnx_log_level
276
+
247
277
if (
248
278
is_torch_available ()
249
279
and (self .device .type != "cuda" )
0 commit comments