
Commit cbbda3e

Fix ort config instantiation (from_pretrained) and saving (save_pretrained) (#1865)
* fix ort config instantiation (from_dict) and saving (to_dict)
* added tests for quantization with ort config
* style
* handle empty quant dictionary
1 parent f300865 commit cbbda3e
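
For orientation, here is a minimal sketch (not part of the commit) of the round trip the title refers to. The directory name is a placeholder, and the comments paraphrase the changes to optimum/onnxruntime/configuration.py shown below.

from optimum.onnxruntime.configuration import AutoQuantizationConfig, ORTConfig

# Build an ORTConfig carrying a quantization config and write it to disk;
# save_pretrained() stores it as ort_config.json in the target directory.
quantization_config = AutoQuantizationConfig.avx2(is_static=False)
ort_config = ORTConfig(quantization=quantization_config)
ort_config.save_pretrained("ort_config_dir")  # placeholder path

# Reload it; with this fix the nested "quantization" entry is rebuilt as a
# QuantizationConfig instance rather than being kept as a plain dictionary.
reloaded_config = ORTConfig.from_pretrained("ort_config_dir")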

File tree

3 files changed, +80 −33 lines


.github/workflows/test_cli.yml

+18 −15
@@ -4,9 +4,9 @@ name: Optimum CLI / Python - Test
 
 on:
   push:
-    branches: [ main ]
+    branches: [main]
   pull_request:
-    branches: [ main ]
+    branches: [main]
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -22,17 +22,20 @@ jobs:
 
     runs-on: ${{ matrix.os }}
     steps:
-      - uses: actions/checkout@v2
-      - name: Setup Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install .[tests,exporters,exporters-tf]
-      - name: Test with unittest
-        working-directory: tests
-        run: |
-          python -m unittest discover -s cli -p 'test_*.py'
+      - name: Checkout code
+        uses: actions/checkout@v4
 
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install .[tests,exporters,exporters-tf]
+
+      - name: Test with pytest
+        run: |
+          pytest tests/cli -s -vvvv --durations=0

optimum/onnxruntime/configuration.py

+46 −3
@@ -18,7 +18,7 @@
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from datasets import Dataset
 from packaging.version import Version, parse
@@ -298,6 +298,15 @@ def __post_init__(self):
             )
         self.operators_to_quantize = operators_to_quantize
 
+        if isinstance(self.format, str):
+            self.format = QuantFormat[self.format]
+        if isinstance(self.mode, str):
+            self.mode = QuantizationMode[self.mode]
+        if isinstance(self.activations_dtype, str):
+            self.activations_dtype = QuantType[self.activations_dtype]
+        if isinstance(self.weights_dtype, str):
+            self.weights_dtype = QuantType[self.weights_dtype]
+
     @staticmethod
     def quantization_type_str(activations_dtype: QuantType, weights_dtype: QuantType) -> str:
         return (
@@ -984,8 +993,28 @@ def __init__(
         self.opset = opset
         self.use_external_data_format = use_external_data_format
         self.one_external_file = one_external_file
-        self.optimization = self.dataclass_to_dict(optimization)
-        self.quantization = self.dataclass_to_dict(quantization)
+
+        if isinstance(optimization, dict) and optimization:
+            self.optimization = OptimizationConfig(**optimization)
+        elif isinstance(optimization, OptimizationConfig):
+            self.optimization = optimization
+        elif not optimization:
+            self.optimization = None
+        else:
+            raise ValueError(
+                f"Optional argument `optimization` must be a dictionary or an instance of OptimizationConfig, got {type(optimization)}"
+            )
+        if isinstance(quantization, dict) and quantization:
+            self.quantization = QuantizationConfig(**quantization)
+        elif isinstance(quantization, QuantizationConfig):
+            self.quantization = quantization
+        elif not quantization:
+            self.quantization = None
+        else:
+            raise ValueError(
+                f"Optional argument `quantization` must be a dictionary or an instance of QuantizationConfig, got {type(quantization)}"
+            )
+
         self.optimum_version = kwargs.pop("optimum_version", None)
 
     @staticmethod
@@ -1002,3 +1031,17 @@ def dataclass_to_dict(config) -> dict:
                 v = [elem.name if isinstance(elem, Enum) else elem for elem in v]
             new_config[k] = v
         return new_config
+
+    def to_dict(self) -> Dict[str, Any]:
+        dict_config = {
+            "opset": self.opset,
+            "use_external_data_format": self.use_external_data_format,
+            "one_external_file": self.one_external_file,
+            "optimization": self.dataclass_to_dict(self.optimization),
+            "quantization": self.dataclass_to_dict(self.quantization),
+        }
+
+        if self.optimum_version:
+            dict_config["optimum_version"] = self.optimum_version
+
+        return dict_config
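
A rough usage sketch of the behaviour added above (illustrative only; it reuses AutoQuantizationConfig.avx2 from the test below and the dataclass_to_dict helper already defined in this file):

from optimum.onnxruntime.configuration import (
    AutoQuantizationConfig,
    ORTConfig,
    QuantizationConfig,
)

quantization = AutoQuantizationConfig.avx2(is_static=False)

# __init__ now accepts either the dataclass or its dictionary form; in the dict
# case the string enum names are mapped back in QuantizationConfig.__post_init__.
from_dataclass = ORTConfig(quantization=quantization)
from_dict = ORTConfig(quantization=ORTConfig.dataclass_to_dict(quantization))
assert isinstance(from_dict.quantization, QuantizationConfig)

# An empty dict (or None) simply disables the section ("handle empty quant dictionary").
assert ORTConfig(quantization={}).quantization is None

# to_dict() flattens everything back to JSON-serializable values (Enum members become names).
serializable = from_dataclass.to_dict()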

tests/cli/test_cli.py

+16 −15
@@ -21,10 +21,8 @@
 import unittest
 from pathlib import Path
 
-from onnxruntime import __version__ as ort_version
-from packaging.version import Version, parse
-
 import optimum.commands
+from optimum.onnxruntime.configuration import AutoQuantizationConfig, ORTConfig
 
 
 CLI_WIH_CUSTOM_COMMAND_PATH = Path(__file__).parent / "cli_with_custom_command.py"
@@ -83,30 +81,33 @@ def test_optimize_commands(self):
 
     def test_quantize_commands(self):
         with tempfile.TemporaryDirectory() as tempdir:
+            ort_config = ORTConfig(quantization=AutoQuantizationConfig.avx2(is_static=False))
+            ort_config.save_pretrained(tempdir)
+
             # First export a tiny encoder, decoder only and encoder-decoder
             export_commands = [
-                f"optimum-cli export onnx --model hf-internal-testing/tiny-random-BertModel {tempdir}/encoder",
+                f"optimum-cli export onnx --model hf-internal-testing/tiny-random-bert {tempdir}/encoder",
                 f"optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 {tempdir}/decoder",
-                # f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder",
+                f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder",
             ]
             quantize_commands = [
                 f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder --avx2 -o {tempdir}/quantized_encoder",
                 f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/decoder --avx2 -o {tempdir}/quantized_decoder",
-                # f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder",
+                f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder",
            ]
 
-            if parse(ort_version) != Version("1.16.0") and parse(ort_version) != Version("1.17.0"):
-                # Failing on onnxruntime==1.17.0, will be fixed on 1.17.1: https://github.com/microsoft/onnxruntime/pull/19421
-                export_commands.append(
-                    f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder"
-                )
-                quantize_commands.append(
-                    f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder"
-                )
+            quantize_with_config_commands = [
+                f"optimum-cli onnxruntime quantize --onnx_model hf-internal-testing/tiny-random-bert --c {tempdir}/ort_config.json -o {tempdir}/quantized_encoder_with_config",
+                f"optimum-cli onnxruntime quantize --onnx_model hf-internal-testing/tiny-random-gpt2 --c {tempdir}/ort_config.json -o {tempdir}/quantized_decoder_with_config",
+                f"optimum-cli onnxruntime quantize --onnx_model hf-internal-testing/tiny-random-t5 --c {tempdir}/ort_config.json -o {tempdir}/quantized_encoder_decoder_with_config",
+            ]
 
-            for export, quantize in zip(export_commands, quantize_commands):
+            for export, quantize, quantize_with_config in zip(
+                export_commands, quantize_commands, quantize_with_config_commands
+            ):
                 subprocess.run(export, shell=True, check=True)
                 subprocess.run(quantize, shell=True, check=True)
+                subprocess.run(quantize_with_config, shell=True, check=True)
 
     def _run_command_and_check_content(self, command: str, content: str) -> bool:
         proc = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
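
As a quick sanity check outside the test suite (illustrative; the file name comes from the --c argument in the commands above), the file written by save_pretrained is the ort_config.json that the quantize command consumes:

import json
import tempfile

from optimum.onnxruntime.configuration import AutoQuantizationConfig, ORTConfig

# Write the config the same way the test does, then inspect the resulting JSON
# that `optimum-cli onnxruntime quantize ... --c {tempdir}/ort_config.json` reads.
with tempfile.TemporaryDirectory() as tempdir:
    ORTConfig(quantization=AutoQuantizationConfig.avx2(is_static=False)).save_pretrained(tempdir)
    with open(f"{tempdir}/ort_config.json") as f:
        print(json.dumps(json.load(f), indent=2))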
