Skip to content

Commit 400bb82

Browse files
authored
[fix] Allow ORTQuantizer over models with subfolder ONNX files (#2094)
* Allow ORTQuantizer over models with subfolder ONNX files
* Also catch ValueError, as that seems to be a common failure when calling AutoConfig.from_pretrained("does/not/exist")
* Use a test case that previously failed
1 parent c513437 commit 400bb82

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

optimum/onnxruntime/quantization.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(self, onnx_model_path: Path, config: Optional["PretrainedConfig"] =
100100
if self.config is None:
101101
try:
102102
self.config = AutoConfig.from_pretrained(self.onnx_model_path.parent)
103-
except OSError:
103+
except (OSError, ValueError):
104104
LOGGER.warning(
105105
f"Could not load the config for {self.onnx_model_path} automatically, this might make "
106106
"the quantized model harder to use because it will not be able to be loaded by an ORTModel without "
@@ -134,6 +134,7 @@ def from_pretrained(
134134
model_or_path = Path(model_or_path)
135135

136136
path = None
137+
config = None
137138
if isinstance(model_or_path, ORTModelForConditionalGeneration):
138139
raise NotImplementedError(ort_quantizer_error_message)
139140
elif isinstance(model_or_path, Path) and file_name is None:
@@ -147,13 +148,13 @@ def from_pretrained(
147148
file_name = onnx_files[0].name
148149

149150
if isinstance(model_or_path, ORTModel):
150-
if path is None:
151-
path = Path(model_or_path.model._model_path)
151+
path = Path(model_or_path.model._model_path)
152+
config = model_or_path.config
152153
elif os.path.isdir(model_or_path):
153154
path = Path(model_or_path) / file_name
154155
else:
155156
raise ValueError(f"Unable to load model from {model_or_path}.")
156-
return cls(path)
157+
return cls(path, config=config)
157158

158159
def fit(
159160
self,

tests/onnxruntime/test_quantization.py

+8
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
AutoQuantizationConfig,
3131
ORTConfig,
3232
ORTModelForCausalLM,
33+
ORTModelForFeatureExtraction,
3334
ORTModelForSeq2SeqLM,
3435
ORTModelForSequenceClassification,
3536
ORTQuantizer,
@@ -52,6 +53,13 @@ class ORTQuantizerTest(unittest.TestCase):
5253
"optimum/distilbert-base-uncased-finetuned-sst-2-english"
5354
)
5455
},
56+
"ort_model_with_onnx_model_in_subfolder": {
57+
"model_or_path": ORTModelForFeatureExtraction.from_pretrained(
58+
"sentence-transformers/all-MiniLM-L6-v2",
59+
subfolder="onnx",
60+
file_name="model.onnx",
61+
)
62+
},
5563
}
5664

5765
@parameterized.expand(LOAD_CONFIGURATION.items())

0 commit comments

Comments
 (0)