
Commit d2a5a6a
Remove datasets as required dependency (#2087)
* remove datasets required dependency
* install datasets when needed
* add datasets installed when needed
* style
* add require dataset
* divide datasets tests
* import datasets only when needed
1 parent a7a807c · commit d2a5a6a

17 files changed: +123 -35 lines
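Taken together, the diffs below apply one pattern in three flavors: a runtime availability check (`is_datasets_available`) that raises a descriptive `ImportError`, `TYPE_CHECKING`-guarded imports with quoted annotations, and a deferred in-function import behind `requires_backends`. A condensed sketch of the pattern, using the helper names introduced in this commit (the function below is illustrative and not part of the diff):

# Illustrative sketch only: the helper names come from this commit,
# the function itself does not exist in the repo.
from typing import TYPE_CHECKING

from optimum.utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available

if TYPE_CHECKING:
    # Resolved by type checkers only; never executed at runtime.
    from datasets import Dataset


def load_calibration_split(name: str) -> "Dataset":
    # Fail at call time with an actionable message, not at `import optimum`.
    if not is_datasets_available():
        raise ImportError(DATASETS_IMPORT_ERROR.format("load_calibration_split"))
    from datasets import load_dataset  # deferred: runs only once the guard passed

    return load_dataset(name, split="train")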

.github/workflows/dev_test_benckmark.yml (+1 -1)

@@ -23,7 +23,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install wheel
-          pip install .[tests,onnxruntime,benchmark]
+          pip install .[tests,onnxruntime,benchmark] datasets
           pip install -U git+https://github.com/huggingface/evaluate
           pip install -U git+https://github.com/huggingface/diffusers
           pip install -U git+https://github.com/huggingface/transformers

.github/workflows/test_benckmark.yml (+1 -1)

@@ -30,7 +30,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install wheel
-          pip install .[tests,onnxruntime,benchmark]
+          pip install .[tests,onnxruntime,benchmark] datasets
       - name: Test with unittest
         run: |
           python -m unittest discover --start-directory tests/benchmark --pattern 'test_*.py'

.github/workflows/test_utils.yml (+10 -1)

@@ -37,4 +37,13 @@ jobs:
       - name: Test with pytest
         working-directory: tests
         run: |
-          python -m pytest -s -vvvv utils
+          pytest utils -s -n auto -m "not datasets_test" --durations=0
+
+      - name: Install datasets
+        run: |
+          pip install datasets
+
+      - name: Tests needing datasets
+        working-directory: tests
+        run: |
+          pytest utils -s -n auto -m "datasets_test" --durations=0
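The workflow now runs the utils tests twice: first without `datasets` installed, deselecting anything marked `datasets_test`, then again after installing it, selecting only those tests. A hedged sketch of how a test opts into the second run (the marker name comes from the workflow; the test body is illustrative, and custom markers normally also need registering in `pyproject.toml` or `conftest.py` to avoid pytest warnings):

import pytest

from optimum.utils import is_datasets_available


@pytest.mark.datasets_test  # deselected in the first CI run via -m "not datasets_test"
def test_wikitext_is_loadable():
    # Runs only in the second pass, after `pip install datasets`.
    assert is_datasets_available()
    from datasets import load_dataset

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    assert dataset.num_rows > 0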

optimum/gptq/data.py (+15 -1)

@@ -18,7 +18,12 @@

 import numpy as np
 import torch
-from datasets import load_dataset
+
+from optimum.utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
+
+
+if is_datasets_available():
+    from datasets import load_dataset


 """
@@ -113,6 +118,9 @@ def pad_block(block, pads):


 def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
+    if not is_datasets_available():
+        raise ImportError(DATASETS_IMPORT_ERROR.format("get_wikitext2"))
+
     if split == "train":
         data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
     elif split == "validation":
@@ -132,6 +140,9 @@ def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "trai


 def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
+    if not is_datasets_available():
+        raise ImportError(DATASETS_IMPORT_ERROR.format("get_c4"))
+
     if split == "train":
         data = load_dataset("allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"})
     elif split == "validation":
@@ -157,6 +168,9 @@ def get_c4(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):


 def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
+    if not is_datasets_available():
+        raise ImportError(DATASETS_IMPORT_ERROR.format("get_c4_new"))
+
     if split == "train":
         data = load_dataset("allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"})
     elif split == "validation":
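With the guards above, the GPTQ data helpers behave exactly as before when `datasets` is installed and raise the formatted `DATASETS_IMPORT_ERROR` otherwise. A usage sketch (the tokenizer checkpoint is illustrative):

from transformers import AutoTokenizer

from optimum.gptq.data import get_wikitext2

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint

# Succeeds if `datasets` is installed; otherwise raises
# ImportError(DATASETS_IMPORT_ERROR.format("get_wikitext2")).
samples = get_wikitext2(tokenizer, seqlen=2048, nsamples=128, split="train")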

optimum/gptq/quantizer.py (+1 -1)

@@ -88,7 +88,7 @@ def __init__(
         dataset (`Union[List[str], str, Any]`, defaults to `None`):
             The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data
             (e.g. [{ "input_ids": [ 1, 100, 15, ... ],"attention_mask": [ 1, 1, 1, ... ]},...])
-            or just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
+            or just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new'].
         group_size (int, defaults to 128):
             The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
         damp_percent (`float`, defaults to `0.1`):

optimum/onnxruntime/configuration.py (+10 -5)

@@ -18,9 +18,8 @@
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

-from datasets import Dataset
 from packaging.version import Version, parse

 from onnxruntime import __version__ as ort_version
@@ -33,6 +32,10 @@
 from ..utils import logging


+if TYPE_CHECKING:
+    from datasets import Dataset
+
+
 logger = logging.get_logger(__name__)

 # This value is used to indicate ORT which axis it should use to quantize an operator "per-channel"
@@ -117,7 +120,9 @@ def create_calibrator(

 class AutoCalibrationConfig:
     @staticmethod
-    def minmax(dataset: Dataset, moving_average: bool = False, averaging_constant: float = 0.01) -> CalibrationConfig:
+    def minmax(
+        dataset: "Dataset", moving_average: bool = False, averaging_constant: float = 0.01
+    ) -> CalibrationConfig:
         """
         Args:
             dataset (`Dataset`):
@@ -151,7 +156,7 @@ def minmax(dataset: Dataset, moving_average: bool = False, averaging_constant: f

     @staticmethod
     def entropy(
-        dataset: Dataset,
+        dataset: "Dataset",
         num_bins: int = 128,
         num_quantized_bins: int = 128,
     ) -> CalibrationConfig:
@@ -188,7 +193,7 @@ def entropy(
         )

     @staticmethod
-    def percentiles(dataset: Dataset, num_bins: int = 2048, percentile: float = 99.999) -> CalibrationConfig:
+    def percentiles(dataset: "Dataset", num_bins: int = 2048, percentile: float = 99.999) -> CalibrationConfig:
         """
         Args:
             dataset (`Dataset`):
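`typing.TYPE_CHECKING` is `False` at runtime, so the `from datasets import Dataset` import is seen only by static type checkers; the annotations are quoted (`"Dataset"`) so the interpreter treats them as forward references and never resolves the name. A minimal sketch of the idiom in isolation:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by mypy/pyright only, never at runtime.
    from datasets import Dataset


def count_rows(dataset: "Dataset") -> int:
    # Calling this never requires `datasets` to be importable;
    # the quoted annotation is just a string at runtime.
    return len(dataset)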

optimum/onnxruntime/model.py (+6 -3)

@@ -14,17 +14,20 @@

 import logging
 import os
-from typing import Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

 import numpy as np
-from datasets import Dataset
 from transformers import EvalPrediction
 from transformers.trainer_pt_utils import nested_concat
 from transformers.trainer_utils import EvalLoopOutput

 from onnxruntime import InferenceSession


+if TYPE_CHECKING:
+    from datasets import Dataset
+
+
 logger = logging.getLogger(__name__)


@@ -59,7 +62,7 @@ def __init__(
         self.session = InferenceSession(str(model_path), providers=[execution_provider])
         self.onnx_input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}

-    def evaluation_loop(self, dataset: Dataset):
+    def evaluation_loop(self, dataset: "Dataset"):
         """
         Run evaluation and returns metrics and predictions.

optimum/onnxruntime/quantization.py (+11 -6)

@@ -21,14 +21,14 @@
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import onnx
-from datasets import Dataset, load_dataset
 from packaging.version import Version, parse
 from transformers import AutoConfig

 from onnxruntime import __version__ as ort_version
 from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantizationMode, QuantType
 from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
 from onnxruntime.quantization.qdq_quantizer import QDQQuantizer
+from optimum.utils.import_utils import requires_backends

 from ..quantization_base import OptimumQuantizer
 from ..utils.save_utils import maybe_save_preprocessors
@@ -40,6 +40,7 @@


 if TYPE_CHECKING:
+    from datasets import Dataset
     from transformers import PretrainedConfig

 LOGGER = logging.getLogger(__name__)
@@ -48,7 +49,7 @@
 class ORTCalibrationDataReader(CalibrationDataReader):
     __slots__ = ["batch_size", "dataset", "_dataset_iter"]

-    def __init__(self, dataset: Dataset, batch_size: int = 1):
+    def __init__(self, dataset: "Dataset", batch_size: int = 1):
         if dataset is None:
             raise ValueError("Provided dataset is None.")

@@ -158,7 +159,7 @@ def from_pretrained(

     def fit(
         self,
-        dataset: Dataset,
+        dataset: "Dataset",
         calibration_config: CalibrationConfig,
         onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx",
         operators_to_quantize: Optional[List[str]] = None,
@@ -212,7 +213,7 @@ def fit(

     def partial_fit(
         self,
-        dataset: Dataset,
+        dataset: "Dataset",
         calibration_config: CalibrationConfig,
         onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx",
         operators_to_quantize: Optional[List[str]] = None,
@@ -428,7 +429,7 @@ def get_calibration_dataset(
         seed: int = 2016,
         use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
-    ) -> Dataset:
+    ) -> "Dataset":
         """
         Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.

@@ -474,6 +475,10 @@ def get_calibration_dataset(
                 "provided."
             )

+        requires_backends(self, ["datasets"])
+
+        from datasets import load_dataset
+
         calib_dataset = load_dataset(
             dataset_name,
             name=dataset_config_name,
@@ -492,7 +497,7 @@ def get_calibration_dataset(

         return self.clean_calibration_dataset(processed_calib_dataset)

-    def clean_calibration_dataset(self, dataset: Dataset) -> Dataset:
+    def clean_calibration_dataset(self, dataset: "Dataset") -> "Dataset":
         model = onnx.load(self.onnx_model_path)
         model_inputs = {input.name for input in model.graph.input}
         ignored_columns = list(set(dataset.column_names) - model_inputs)
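`get_calibration_dataset` is the only method in this module that needs `datasets` at runtime, so the import moves inside it behind `requires_backends`, which consults the `("datasets", ...)` entry registered in `BACKENDS_MAPPING` (see `import_utils.py` below) and raises the formatted error when the package is missing. A usage sketch (the model path and dataset choice are illustrative):

from optimum.onnxruntime import ORTQuantizer

quantizer = ORTQuantizer.from_pretrained("path/to/onnx_model_dir")  # illustrative path

# Importing optimum.onnxruntime no longer pulls in `datasets`; only this call
# does, raising the DATASETS_IMPORT_ERROR message if the package is absent.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue", dataset_config_name="sst2"
)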

optimum/onnxruntime/runs/calibrator.py (+6 -4)

@@ -1,6 +1,4 @@
-from typing import Dict, List
-
-from datasets import Dataset
+from typing import TYPE_CHECKING, Dict, List

 from ...runs_base import Calibrator
 from .. import ORTQuantizer
@@ -9,10 +7,14 @@
 from ..preprocessors.passes import ExcludeGeLUNodes, ExcludeLayerNormNodes, ExcludeNodeAfter, ExcludeNodeFollowedBy


+if TYPE_CHECKING:
+    from datasets import Dataset
+
+
 class OnnxRuntimeCalibrator(Calibrator):
     def __init__(
         self,
-        calibration_dataset: Dataset,
+        calibration_dataset: "Dataset",
         quantizer: ORTQuantizer,
         model_path: str,
         qconfig: QuantizationConfig,

optimum/runs_base.py (+5 -3)

@@ -2,13 +2,12 @@
 import subprocess
 from contextlib import contextmanager
 from time import perf_counter_ns
-from typing import Set
+from typing import TYPE_CHECKING, Set

 import numpy as np
 import optuna
 import torch
 import transformers
-from datasets import Dataset
 from tqdm import trange

 from . import version as optimum_version
@@ -21,6 +20,9 @@
 from .utils.runs import RunConfig, cpu_info_command


+if TYPE_CHECKING:
+    from datasets import Dataset
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -34,7 +36,7 @@ def get_autoclass_name(task):

 class Calibrator:
     def __init__(
-        self, calibration_dataset: Dataset, quantizer, model_path, qconfig, calibration_params, node_exclusion
+        self, calibration_dataset: "Dataset", quantizer, model_path, qconfig, calibration_params, node_exclusion
     ):
         self.calibration_dataset = calibration_dataset
         self.quantizer = quantizer

optimum/utils/__init__.py (+1)

@@ -35,6 +35,7 @@
     check_if_transformers_greater,
     is_accelerate_available,
     is_auto_gptq_available,
+    is_datasets_available,
     is_diffusers_available,
     is_onnx_available,
     is_onnxruntime_available,

optimum/utils/import_utils.py (+12)

@@ -69,6 +69,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _auto_gptq_available = _is_package_available("auto_gptq")
 _timm_available = _is_package_available("timm")
 _sentence_transformers_available = _is_package_available("sentence_transformers")
+_datasets_available = _is_package_available("datasets")

 torch_version = None
 if is_torch_available():
@@ -131,6 +132,10 @@ def is_sentence_transformers_available():
     return _sentence_transformers_available


+def is_datasets_available():
+    return _datasets_available
+
+
 def is_auto_gptq_available():
     if _auto_gptq_available:
         version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
@@ -230,6 +235,12 @@ def require_numpy_strictly_lower(package_version: str, message: str):
 -U transformers`. Please note that you may need to restart your runtime after installation.
 """

+DATASETS_IMPORT_ERROR = """
+{0} requires the datasets library but it was not found in your environment. You can install it with pip:
+`pip install datasets`. Please note that you may need to restart your runtime after installation.
+"""
+
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
@@ -245,6 +256,7 @@ def require_numpy_strictly_lower(package_version: str, message: str):
             "transformers_434",
             (lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")),
         ),
+        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
     ]
 )
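The new `("datasets", ...)` entry makes the generic `requires_backends` helper work for this dependency as well: it checks `is_datasets_available()` and, on failure, raises `ImportError` with `DATASETS_IMPORT_ERROR` formatted with the caller's name. A sketch mirroring the call site in `quantization.py` above (the class is illustrative):

from optimum.utils.import_utils import requires_backends


class NeedsDatasets:
    def load(self):
        # Raises ImportError with DATASETS_IMPORT_ERROR formatted with this
        # class's name when `datasets` is missing; a no-op otherwise.
        requires_backends(self, ["datasets"])
        from datasets import load_dataset

        return load_dataset("wikitext", "wikitext-2-raw-v1", split="train")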
