Skip to content

Commit fce6c58

Browse files
committed
minor updates
1 parent 8043c23 commit fce6c58

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

optimum/exporters/onnx/__main__.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,17 @@ def main_export(
409409
and getattr(model.config, "pad_token_id", None) is None
410410
)
411411

412-
if needs_pad_token_id and pad_token_id is None:
413-
tok = AutoTokenizer.from_pretrained(model_name_or_path)
414-
pad_token_id = getattr(tok, "pad_token_id", None)
415-
if pad_token_id is None:
416-
raise ValueError(
417-
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
418-
)
419-
model.config.pad_token_id = pad_token_id
412+
if needs_pad_token_id:
413+
if pad_token_id is not None:
414+
model.config.pad_token_id = pad_token_id
415+
else:
416+
tok = AutoTokenizer.from_pretrained(model_name_or_path)
417+
pad_token_id = getattr(tok, "pad_token_id", None)
418+
if pad_token_id is None:
419+
raise ValueError(
420+
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
421+
)
422+
model.config.pad_token_id = pad_token_id
420423

421424
model_type = "stable-diffusion" if "stable-diffusion" in task else model.config.model_type.replace("_", "-")
422425
if (

0 commit comments

Comments
 (0)