21
21
from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple , Union
22
22
23
23
import onnx
24
- from datasets import Dataset , load_dataset
25
24
from packaging .version import Version , parse
26
25
from transformers import AutoConfig
27
26
28
27
from onnxruntime import __version__ as ort_version
29
28
from onnxruntime .quantization import CalibrationDataReader , QuantFormat , QuantizationMode , QuantType
30
29
from onnxruntime .quantization .onnx_quantizer import ONNXQuantizer
31
30
from onnxruntime .quantization .qdq_quantizer import QDQQuantizer
31
+ from optimum .utils .import_utils import requires_backends
32
32
33
33
from ..quantization_base import OptimumQuantizer
34
34
from ..utils .save_utils import maybe_save_preprocessors
40
40
41
41
42
42
if TYPE_CHECKING :
43
+ from datasets import Dataset
43
44
from transformers import PretrainedConfig
44
45
45
46
LOGGER = logging .getLogger (__name__ )
48
49
class ORTCalibrationDataReader (CalibrationDataReader ):
49
50
__slots__ = ["batch_size" , "dataset" , "_dataset_iter" ]
50
51
51
- def __init__ (self , dataset : Dataset , batch_size : int = 1 ):
52
+ def __init__ (self , dataset : " Dataset" , batch_size : int = 1 ):
52
53
if dataset is None :
53
54
raise ValueError ("Provided dataset is None." )
54
55
@@ -158,7 +159,7 @@ def from_pretrained(
158
159
159
160
def fit (
160
161
self ,
161
- dataset : Dataset ,
162
+ dataset : " Dataset" ,
162
163
calibration_config : CalibrationConfig ,
163
164
onnx_augmented_model_name : Union [str , Path ] = "augmented_model.onnx" ,
164
165
operators_to_quantize : Optional [List [str ]] = None ,
@@ -212,7 +213,7 @@ def fit(
212
213
213
214
def partial_fit (
214
215
self ,
215
- dataset : Dataset ,
216
+ dataset : " Dataset" ,
216
217
calibration_config : CalibrationConfig ,
217
218
onnx_augmented_model_name : Union [str , Path ] = "augmented_model.onnx" ,
218
219
operators_to_quantize : Optional [List [str ]] = None ,
@@ -428,7 +429,7 @@ def get_calibration_dataset(
428
429
seed : int = 2016 ,
429
430
use_auth_token : Optional [Union [bool , str ]] = None ,
430
431
token : Optional [Union [bool , str ]] = None ,
431
- ) -> Dataset :
432
+ ) -> " Dataset" :
432
433
"""
433
434
Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
434
435
@@ -474,6 +475,10 @@ def get_calibration_dataset(
474
475
"provided."
475
476
)
476
477
478
+ requires_backends (self , ["datasets" ])
479
+
480
+ from datasets import load_dataset
481
+
477
482
calib_dataset = load_dataset (
478
483
dataset_name ,
479
484
name = dataset_config_name ,
@@ -492,7 +497,7 @@ def get_calibration_dataset(
492
497
493
498
return self .clean_calibration_dataset (processed_calib_dataset )
494
499
495
- def clean_calibration_dataset (self , dataset : Dataset ) -> Dataset :
500
+ def clean_calibration_dataset (self , dataset : " Dataset" ) -> " Dataset" :
496
501
model = onnx .load (self .onnx_model_path )
497
502
model_inputs = {input .name for input in model .graph .input }
498
503
ignored_columns = list (set (dataset .column_names ) - model_inputs )
0 commit comments