@@ -16,14 +16,12 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import nncf
 import openvino
 import torch
 import transformers
-from accelerate.data_loader import DataLoaderStateMixin
-from datasets import Dataset, load_dataset
 from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
 from nncf.torch import create_compressed_model, register_default_init_args, register_module
 from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
@@ -34,6 +32,7 @@
 from torch.utils.data import DataLoader, RandomSampler
 from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator
 from transformers.pytorch_utils import Conv1D
+from transformers.utils import is_accelerate_available
 
 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.tasks import TasksManager
@@ -43,6 +42,7 @@
 from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
 from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
 from ..utils.constant import _TASK_ALIASES
+from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
 from ..utils.modeling_utils import get_model_device
 from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
@@ -54,6 +54,10 @@
 )
 
 
+if is_datasets_available():
+    if TYPE_CHECKING:
+        from datasets import Dataset
+
 register_module(ignored_algorithms=[])(Conv1D)
 
 core = Core()
@@ -67,8 +71,11 @@ def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
     @property
     def batch_size(self):
         batch_size = self._data_loader.batch_size
-        if batch_size is None and isinstance(self._data_loader, DataLoaderStateMixin):
-            batch_size = self._data_loader.total_batch_size
+        if is_accelerate_available():
+            from accelerate.data_loader import DataLoaderStateMixin
+
+            if batch_size is None and isinstance(self._data_loader, DataLoaderStateMixin):
+                batch_size = self._data_loader.total_batch_size
         return batch_size
 
 
@@ -150,7 +157,7 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs):
 
     def quantize(
         self,
-        calibration_dataset: Dataset = None,
+        calibration_dataset: "Dataset" = None,
         save_directory: Union[str, Path] = None,
         ov_config: OVConfig = None,
         file_name: Optional[str] = None,
@@ -252,7 +259,7 @@ def quantize(
 
     def _quantize_ovbasemodel(
         self,
-        calibration_dataset: Dataset,
+        calibration_dataset: "Dataset",
         save_directory: Union[str, Path],
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
@@ -310,7 +317,7 @@ def _quantize_ovbasemodel(
 
     def _quantize_torchmodel(
         self,
-        calibration_dataset: Dataset,
+        calibration_dataset: "Dataset",
         save_directory: Union[str, Path],
         ov_config: OVConfig = None,
         file_name: Optional[str] = None,
@@ -452,7 +459,7 @@ def get_calibration_dataset(
         preprocess_batch: bool = True,
         use_auth_token: bool = False,
         cache_dir: Optional[str] = None,
-    ) -> Dataset:
+    ) -> "Dataset":
         """
         Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
 
@@ -477,6 +484,10 @@ def get_calibration_dataset(
         Returns:
             The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
         """
+        if not is_datasets_available():
+            raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset"))
+        from datasets import load_dataset
+
         calibration_dataset = load_dataset(
             dataset_name,
             name=dataset_config_name,
@@ -496,7 +507,7 @@ def get_calibration_dataset(
 
     def _get_calibration_dataloader(
         self,
-        calibration_dataset: Dataset,
+        calibration_dataset: "Dataset",
         batch_size: int,
         remove_unused_columns: bool,
         data_collator: Optional[DataCollator] = None,
@@ -513,7 +524,7 @@ def _get_calibration_dataloader(
         )
         return OVDataLoader(calibration_dataloader)
 
-    def _remove_unused_columns(self, dataset: Dataset):
+    def _remove_unused_columns(self, dataset: "Dataset"):
         ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
         return dataset.remove_columns(ignored_columns)
 
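
Taken together, these hunks make `datasets` and `accelerate` optional dependencies: module-level imports become availability checks, type hints become string annotations (`"Dataset"`) backed by a `TYPE_CHECKING` import, and the real imports are deferred into the functions that need them, raising an actionable error when the package is missing. A minimal self-contained sketch of the same pattern, assuming `importlib.util.find_spec` as the availability probe; the helper name and error template below mirror the ones used in the diff but are illustrative, not the library's actual implementation:

import importlib.util
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only by static type checkers; at runtime the annotation
    # below stays a plain string, so `datasets` is never required here.
    from datasets import Dataset


def is_datasets_available() -> bool:
    # Availability probe: True if the optional `datasets` package is importable.
    return importlib.util.find_spec("datasets") is not None


# Hypothetical error template, modeled on DATASETS_IMPORT_ERROR in the diff.
DATASETS_IMPORT_ERROR = (
    "{0} requires the datasets library, which was not found in your environment. "
    "You can install it with: pip install datasets"
)


def get_calibration_dataset(dataset_name: str, split: str = "train") -> "Dataset":
    # Fail lazily with an actionable message instead of an ImportError
    # the moment the module is loaded.
    if not is_datasets_available():
        raise ValueError(DATASETS_IMPORT_ERROR.format("get_calibration_dataset"))
    from datasets import load_dataset  # deferred import, only runs when called

    return load_dataset(dataset_name, split=split)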