22
22
import openvino
23
23
import torch
24
24
import transformers
25
- from accelerate .data_loader import DataLoaderStateMixin
26
- from datasets import Dataset , load_dataset
27
25
from nncf import NNCFConfig , compress_weights
28
26
from nncf .torch import create_compressed_model , register_default_init_args , register_module
29
27
from nncf .torch .dynamic_graph .io_handling import wrap_nncf_model_inputs_with_objwalk
33
31
from torch .utils .data import DataLoader , RandomSampler
34
32
from transformers import DataCollator , PreTrainedModel , default_data_collator
35
33
from transformers .pytorch_utils import Conv1D
34
+ from transformers .utils import is_accelerate_available
36
35
37
36
from optimum .exporters .tasks import TasksManager
38
37
from optimum .quantization_base import OptimumQuantizer
39
38
40
39
from ...exporters .openvino import export , export_pytorch_via_onnx
41
40
from ...exporters .openvino .stateful import ensure_export_task_support_stateful
42
41
from ..utils .constant import _TASK_ALIASES
42
+ from ..utils .import_utils import DATASETS_IMPORT_ERROR , is_datasets_available
43
43
from .configuration import OVConfig
44
44
from .modeling_base import OVBaseModel
45
45
from .modeling_decoder import OVBaseDecoderModel
51
51
)
52
52
53
53
54
+ if is_datasets_available ():
55
+ from datasets import Dataset
56
+
54
57
COMPRESSION_OPTIONS = {
55
58
"int8" : {"mode" : nncf .CompressWeightsMode .INT8 },
56
59
"int4_sym_g128" : {"mode" : nncf .CompressWeightsMode .INT4_SYM , "group_size" : 128 },
@@ -72,8 +75,11 @@ def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
72
75
@property
def batch_size(self):
    """Batch size of the underlying dataloader.

    Plain ``torch`` dataloaders expose it directly via ``batch_size``; a
    dataloader prepared by ``accelerate`` reports ``None`` there, in which
    case we fall back to accelerate's ``total_batch_size`` — but only when
    accelerate is actually installed.
    """
    size = self._data_loader.batch_size
    if is_accelerate_available():
        from accelerate.data_loader import DataLoaderStateMixin

        if size is None and isinstance(self._data_loader, DataLoaderStateMixin):
            size = self._data_loader.total_batch_size
    return size
78
84
79
85
@@ -155,7 +161,7 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs):
155
161
156
162
def quantize (
157
163
self ,
158
- calibration_dataset : Dataset = None ,
164
+ calibration_dataset : " Dataset" = None ,
159
165
save_directory : Union [str , Path ] = None ,
160
166
quantization_config : OVConfig = None ,
161
167
file_name : Optional [str ] = None ,
@@ -268,7 +274,7 @@ def _get_compression_options(self, config: OVConfig):
268
274
269
275
def _quantize_ovbasemodel (
270
276
self ,
271
- calibration_dataset : Dataset ,
277
+ calibration_dataset : " Dataset" ,
272
278
save_directory : Union [str , Path ],
273
279
batch_size : int = 1 ,
274
280
data_collator : Optional [DataCollator ] = None ,
@@ -304,7 +310,7 @@ def _quantize_ovbasemodel(
304
310
305
311
def _quantize_ovcausallm (
306
312
self ,
307
- calibration_dataset : Dataset ,
313
+ calibration_dataset : " Dataset" ,
308
314
save_directory : Union [str , Path ],
309
315
batch_size : int = 1 ,
310
316
data_collator : Optional [DataCollator ] = None ,
@@ -358,7 +364,7 @@ def _quantize_ovcausallm(
358
364
359
365
def _quantize_torchmodel (
360
366
self ,
361
- calibration_dataset : Dataset ,
367
+ calibration_dataset : " Dataset" ,
362
368
save_directory : Union [str , Path ],
363
369
quantization_config : OVConfig = None ,
364
370
file_name : Optional [str ] = None ,
@@ -482,7 +488,7 @@ def get_calibration_dataset(
482
488
preprocess_batch : bool = True ,
483
489
use_auth_token : bool = False ,
484
490
cache_dir : Optional [str ] = None ,
485
- ) -> Dataset :
491
+ ) -> " Dataset" :
486
492
"""
487
493
Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
488
494
@@ -507,6 +513,10 @@ def get_calibration_dataset(
507
513
Returns:
508
514
The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
509
515
"""
516
+ if not is_datasets_available ():
517
+ raise ValueError (DATASETS_IMPORT_ERROR .format ("OVQuantizer.get_calibration_dataset" ))
518
+ from datasets import load_dataset
519
+
510
520
calibration_dataset = load_dataset (
511
521
dataset_name ,
512
522
name = dataset_config_name ,
@@ -526,7 +536,7 @@ def get_calibration_dataset(
526
536
527
537
def _get_calibration_dataloader (
528
538
self ,
529
- calibration_dataset : Dataset ,
539
+ calibration_dataset : " Dataset" ,
530
540
batch_size : int ,
531
541
remove_unused_columns : bool ,
532
542
data_collator : Optional [DataCollator ] = None ,
@@ -543,6 +553,6 @@ def _get_calibration_dataloader(
543
553
)
544
554
return OVDataLoader (calibration_dataloader )
545
555
546
- def _remove_unused_columns (self , dataset : Dataset ):
556
+ def _remove_unused_columns (self , dataset : " Dataset" ):
547
557
ignored_columns = list (set (dataset .column_names ) - set (self ._signature_columns ))
548
558
return dataset .remove_columns (ignored_columns )
0 commit comments