Skip to content

Commit 280703c

Browse files
eaidova authored and echarlaix committed on Feb 16, 2024
Relax dependency on accelerate and datasets in OVQuantizer (#547)
Relax dependency on accelerate and datasets in OVQuantizer (#547)

* Relax dependency on accelerate and datasets in OVQuantizer

* additional guard on datasets import

* Update optimum/intel/openvino/quantization.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Update quantization.py

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 55d419f commit 280703c

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed
 

‎optimum/intel/openvino/quantization.py

+22-11
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,12 @@
1616
import logging
1717
import os
1818
from pathlib import Path
19-
from typing import Any, Callable, Dict, Optional, Tuple, Union
19+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
2020

2121
import nncf
2222
import openvino
2323
import torch
2424
import transformers
25-
from accelerate.data_loader import DataLoaderStateMixin
26-
from datasets import Dataset, load_dataset
2725
from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
2826
from nncf.torch import create_compressed_model, register_default_init_args, register_module
2927
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
@@ -34,6 +32,7 @@
3432
from torch.utils.data import DataLoader, RandomSampler
3533
from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator
3634
from transformers.pytorch_utils import Conv1D
35+
from transformers.utils import is_accelerate_available
3736

3837
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
3938
from optimum.exporters.tasks import TasksManager
@@ -43,6 +42,7 @@
4342
from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
4443
from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
4544
from ..utils.constant import _TASK_ALIASES
45+
from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
4646
from ..utils.modeling_utils import get_model_device
4747
from .configuration import OVConfig, OVWeightQuantizationConfig
4848
from .modeling_base import OVBaseModel
@@ -54,6 +54,10 @@
5454
)
5555

5656

57+
if is_datasets_available():
58+
if TYPE_CHECKING:
59+
from datasets import Dataset
60+
5761
register_module(ignored_algorithms=[])(Conv1D)
5862

5963
core = Core()
@@ -67,8 +71,11 @@ def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
6771
@property
6872
def batch_size(self):
6973
batch_size = self._data_loader.batch_size
70-
if batch_size is None and isinstance(self._data_loader, DataLoaderStateMixin):
71-
batch_size = self._data_loader.total_batch_size
74+
if is_accelerate_available():
75+
from accelerate.data_loader import DataLoaderStateMixin
76+
77+
if batch_size is None and isinstance(self._data_loader, DataLoaderStateMixin):
78+
batch_size = self._data_loader.total_batch_size
7279
return batch_size
7380

7481

@@ -150,7 +157,7 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs):
150157

151158
def quantize(
152159
self,
153-
calibration_dataset: Dataset = None,
160+
calibration_dataset: "Dataset" = None,
154161
save_directory: Union[str, Path] = None,
155162
ov_config: OVConfig = None,
156163
file_name: Optional[str] = None,
@@ -252,7 +259,7 @@ def quantize(
252259

253260
def _quantize_ovbasemodel(
254261
self,
255-
calibration_dataset: Dataset,
262+
calibration_dataset: "Dataset",
256263
save_directory: Union[str, Path],
257264
batch_size: int = 1,
258265
data_collator: Optional[DataCollator] = None,
@@ -310,7 +317,7 @@ def _quantize_ovbasemodel(
310317

311318
def _quantize_torchmodel(
312319
self,
313-
calibration_dataset: Dataset,
320+
calibration_dataset: "Dataset",
314321
save_directory: Union[str, Path],
315322
ov_config: OVConfig = None,
316323
file_name: Optional[str] = None,
@@ -452,7 +459,7 @@ def get_calibration_dataset(
452459
preprocess_batch: bool = True,
453460
use_auth_token: bool = False,
454461
cache_dir: Optional[str] = None,
455-
) -> Dataset:
462+
) -> "Dataset":
456463
"""
457464
Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
458465
@@ -477,6 +484,10 @@ def get_calibration_dataset(
477484
Returns:
478485
The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
479486
"""
487+
if not is_datasets_available():
488+
raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset"))
489+
from datasets import load_dataset
490+
480491
calibration_dataset = load_dataset(
481492
dataset_name,
482493
name=dataset_config_name,
@@ -496,7 +507,7 @@ def get_calibration_dataset(
496507

497508
def _get_calibration_dataloader(
498509
self,
499-
calibration_dataset: Dataset,
510+
calibration_dataset: "Dataset",
500511
batch_size: int,
501512
remove_unused_columns: bool,
502513
data_collator: Optional[DataCollator] = None,
@@ -513,7 +524,7 @@ def _get_calibration_dataloader(
513524
)
514525
return OVDataLoader(calibration_dataloader)
515526

516-
def _remove_unused_columns(self, dataset: Dataset):
527+
def _remove_unused_columns(self, dataset: "Dataset"):
517528
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
518529
return dataset.remove_columns(ignored_columns)
519530

‎optimum/intel/utils/import_utils.py

+19
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,16 @@
146146
_timm_available = False
147147

148148

149+
_datasets_available = importlib.util.find_spec("datasets") is not None
150+
_datasets_version = "N/A"
151+
152+
if _datasets_available:
153+
try:
154+
_datasets_version = importlib_metadata.version("datasets")
155+
except importlib_metadata.PackageNotFoundError:
156+
_datasets_available = False
157+
158+
149159
def is_transformers_available():
150160
return _transformers_available
151161

@@ -182,6 +192,10 @@ def is_timm_available():
182192
return _timm_available
183193

184194

195+
def is_datasets_available():
196+
return _datasets_available
197+
198+
185199
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
186200
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
187201
"""
@@ -298,6 +312,11 @@ def is_timm_version(operation: str, version: str):
298312
`pip install neural-compressor`. Please note that you may need to restart your runtime after installation.
299313
"""
300314

315+
DATASETS_IMPORT_ERROR = """
316+
{0} requires the datasets library but it was not found in your environment. You can install it with pip:
317+
`pip install datasets`. Please note that you may need to restart your runtime after installation.
318+
"""
319+
301320
BACKENDS_MAPPING = OrderedDict(
302321
[
303322
("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),

0 commit comments

Comments (0)