Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove accelerate and onnxruntime from required dependencies #590

Merged
merged 11 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_openvino.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
python -m pip install --upgrade pip
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[openvino,openvino-tokenizers,nncf,tests,diffusers]
pip install .[openvino,openvino-tokenizers,tests,diffusers] onnxruntime
- name: Test with Pytest
run: |
pytest tests/openvino/ --ignore test_modeling_basic
35 changes: 23 additions & 12 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule

from .utils import (
is_accelerate_available,
is_diffusers_available,
is_ipex_available,
is_neural_compressor_available,
Expand Down Expand Up @@ -57,13 +58,19 @@
if not (is_openvino_available() and is_nncf_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_openvino_and_nncf_objects"] = [
"OVQuantizer",
"OVTrainer",
"OVTrainingArguments",
]
_import_structure["utils.dummy_openvino_and_nncf_objects"] = ["OVQuantizer", "OVTrainingArguments"]
else:
_import_structure["openvino"].extend(["OVQuantizer", "OVTrainingArguments"])


try:
if not (is_openvino_available() and is_nncf_available() and is_accelerate_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_openvino_and_nncf_objects"] = ["OVTrainer"]
else:
_import_structure["openvino"].extend(["OVQuantizer", "OVTrainer", "OVTrainingArguments"])
_import_structure["openvino"].extend(["OVTrainer"])


try:
if not (is_openvino_available() and is_diffusers_available()):
Expand Down Expand Up @@ -177,13 +184,17 @@
if not (is_openvino_available() and is_nncf_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_openvino_and_nncf_objects import (
OVQuantizer,
OVTrainer,
OVTrainingArguments,
)
from .utils.dummy_openvino_and_nncf_objects import OVQuantizer, OVTrainingArguments
else:
from .openvino import OVQuantizer, OVTrainingArguments

try:
if not (is_openvino_available() and is_nncf_available() and is_accelerate_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_openvino_and_nncf_objects import OVTrainer
else:
from .openvino import OVQuantizer, OVTrainer, OVTrainingArguments
from .openvino import OVTrainer

try:
if not (is_openvino_available() and is_diffusers_available()):
Expand Down
6 changes: 4 additions & 2 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging

from ..utils.import_utils import is_diffusers_available, is_nncf_available
from ..utils.import_utils import is_accelerate_available, is_diffusers_available, is_nncf_available
from .utils import (
OV_DECODER_NAME,
OV_DECODER_WITH_PAST_NAME,
Expand All @@ -37,9 +37,11 @@
patch_torch_operators()

from .quantization import OVQuantizer
from .trainer import OVTrainer
from .training_args import OVTrainingArguments

if is_accelerate_available():
from .trainer import OVTrainer


from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling import (
Expand Down
14 changes: 10 additions & 4 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import nncf
import openvino
Expand Down Expand Up @@ -56,8 +56,7 @@


if is_datasets_available():
if TYPE_CHECKING:
from datasets import Dataset
from datasets import Dataset

register_module(ignored_algorithms=[])(Conv1D)

Expand Down Expand Up @@ -147,6 +146,7 @@ def __init__(self, model: transformers.PreTrainedModel, task: Optional[str] = No
)
self.task = task or feature
self.seed = seed
# TODO : deprecate input_names
self.input_names = None
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
Expand Down Expand Up @@ -526,9 +526,15 @@ def _get_calibration_dataloader(
data_collator: Optional[DataCollator] = None,
) -> OVDataLoader:
data_collator = data_collator if data_collator is not None else default_data_collator

if not is_datasets_available() or not isinstance(calibration_dataset, Dataset):
logger.warning(
"`remove_unused_columns` set to `False` as calibration_dataset is not an instance of `datasets.Dataset`"
)
remove_unused_columns = False

if remove_unused_columns:
calibration_dataset = self._remove_unused_columns(calibration_dataset)
self.input_names = calibration_dataset.column_names
generator = torch.Generator()
generator.manual_seed(self.seed)
sampler = RandomSampler(calibration_dataset, generator=generator)
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_neural_compressor_version,
_torch_version,
compare_versions,
is_accelerate_available,
is_diffusers_available,
is_ipex_available,
is_neural_compressor_available,
Expand Down
6 changes: 3 additions & 3 deletions optimum/intel/utils/dummy_openvino_and_nncf_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ def from_pretrained(cls, *args, **kwargs):


class OVTrainer(metaclass=DummyObject):
_backends = ["openvino", "nncf"]
_backends = ["openvino", "nncf", "accelerate"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["openvino", "nncf"])
requires_backends(self, ["openvino", "nncf", "accelerate"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["openvino", "nncf"])
requires_backends(cls, ["openvino", "nncf", "accelerate"])


class OVQuantizer(metaclass=DummyObject):
Expand Down
20 changes: 20 additions & 0 deletions optimum/intel/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,16 @@
_datasets_available = False


_accelerate_available = importlib.util.find_spec("accelerate") is not None
_accelerate_version = "N/A"

if _accelerate_available:
try:
_accelerate_version = importlib_metadata.version("accelerate")
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False


def is_transformers_available():
return _transformers_available

Expand Down Expand Up @@ -196,6 +206,10 @@ def is_datasets_available():
return _datasets_available


def is_accelerate_available():
return _accelerate_available


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
"""
Expand Down Expand Up @@ -317,13 +331,19 @@ def is_timm_version(operation: str, version: str):
`pip install datasets`. Please note that you may need to restart your runtime after installation.
"""

ACCELERATE_IMPORT_ERROR = """
{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
`pip install accelerate`. Please note that you may need to restart your runtime after installation.
"""

BACKENDS_MAPPING = OrderedDict(
[
("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
("ipex", (is_ipex_available, IPEX_IMPORT_ERROR)),
("nncf", (is_nncf_available, NNCF_IMPORT_ERROR)),
("openvino", (is_openvino_available, OPENVINO_IMPORT_ERROR)),
("neural_compressor", (is_neural_compressor_available, NEURAL_COMPRESSOR_IMPORT_ERROR)),
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
]
)

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
"datasets>=1.4.0",
"sentencepiece",
"scipy",
"accelerate", # transformers 4.29 require accelerate for PyTorch
]

TESTS_REQUIRE = [
"accelerate",
"pytest",
"parameterized",
"Pillow",
Expand All @@ -39,8 +39,8 @@
QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]

EXTRAS_REQUIRE = {
"neural-compressor": ["neural-compressor>=2.2.0", "onnx", "onnxruntime<1.15.0"],
"openvino": ["openvino>=2023.3", "onnx", "onnxruntime", "nncf>=2.8.1"],
"neural-compressor": ["neural-compressor>=2.2.0", "onnx", "onnxruntime<1.15.0", "accelerate"],
"openvino": ["openvino>=2023.3", "onnx", "nncf>=2.8.1"],
"openvino-tokenizers": ["openvino-tokenizers[transformers]"],
"nncf": ["nncf>=2.8.1"],
"ipex": ["intel-extension-for-pytorch", "onnx"],
Expand Down
23 changes: 8 additions & 15 deletions tests/openvino/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import floats_tensor
from openvino.runtime.ie_api import CompiledModel
from packaging.version import Version, parse
from parameterized import parameterized
from utils_tests import MODEL_NAMES, SEED

Expand All @@ -46,13 +45,8 @@
OVModelVaeDecoder,
OVModelVaeEncoder,
)
from optimum.onnxruntime import (
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline,
)
from optimum.utils.import_utils import _diffusers_version
from optimum.intel.utils.import_utils import is_diffusers_version
from optimum.utils.import_utils import is_onnxruntime_available


F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"}
Expand Down Expand Up @@ -167,7 +161,6 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
class OVStableDiffusionImg2ImgPipelineTest(OVStableDiffusionPipelineBaseTest):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionImg2ImgPipeline
ORT_MODEL_CLASS = ORTStableDiffusionImg2ImgPipeline
TASK = "image-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
Expand Down Expand Up @@ -298,11 +291,13 @@ def test_height_width_properties(self, model_arch: str):
class OVStableDiffusionInpaintPipelineTest(OVStableDiffusionPipelineBaseTest):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionInpaintPipeline
ORT_MODEL_CLASS = ORTStableDiffusionInpaintPipeline
TASK = "inpaint"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(not is_onnxruntime_available(), "this test requires onnxruntime")
def test_compare_diffusers_pipeline(self, model_arch: str):
from optimum.onnxruntime import ORTStableDiffusionInpaintPipeline

model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
batch_size, num_images, height, width = 1, 1, 64, 64
Expand All @@ -329,7 +324,7 @@ def test_compare_diffusers_pipeline(self, model_arch: str):
outputs = pipeline(**inputs, latents=latents).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

ort_pipeline = self.ORT_MODEL_CLASS.from_pretrained(model_id, export=True)
ort_pipeline = ORTStableDiffusionInpaintPipeline.from_pretrained(model_id, export=True)
ort_outputs = ort_pipeline(**inputs, latents=latents).images
self.assertTrue(np.allclose(outputs, ort_outputs, atol=1e-1))

Expand Down Expand Up @@ -358,7 +353,6 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
class OVtableDiffusionXLPipelineTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("stable-diffusion-xl",)
MODEL_CLASS = OVStableDiffusionXLPipeline
ORT_MODEL_CLASS = ORTStableDiffusionXLPipeline
PT_MODEL_CLASS = StableDiffusionXLPipeline
TASK = "text-to-image"

Expand Down Expand Up @@ -444,7 +438,6 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
class OVStableDiffusionXLImg2ImgPipelineTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("stable-diffusion-xl", "stable-diffusion-xl-refiner")
MODEL_CLASS = OVStableDiffusionXLImg2ImgPipeline
ORT_MODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline
PT_MODEL_CLASS = StableDiffusionXLImg2ImgPipeline
TASK = "image-to-image"

Expand Down Expand Up @@ -489,7 +482,7 @@ class OVLatentConsistencyModelPipelineTest(unittest.TestCase):
TASK = "text-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
@unittest.skipIf(is_diffusers_version("<=", "0.21.4"), "not supported with this diffusers version")
def test_compare_to_diffusers(self, model_arch: str):
ov_pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_pipeline.text_encoder, OVModelTextEncoder)
Expand Down Expand Up @@ -532,7 +525,7 @@ def test_compare_to_diffusers(self, model_arch: str):
self.assertEqual(pipeline.device.type, ov_pipeline.device)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
@unittest.skipIf(is_diffusers_version("<=", "0.21.4"), "not supported with this diffusers version")
def test_num_images_per_prompt_static_model(self, model_arch: str):
model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
Expand Down
Loading