From a1fc5ffaf35b0616a06f7654588cf1376b7595b4 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Sun, 28 Jan 2024 18:22:15 +0400
Subject: [PATCH 1/9] Skip weight compression tests for nncf==2.8.0 and rework
 the logic for optimizing stateful PyTorch models

---
 optimum/exporters/openvino/convert.py | 24 +-------
 optimum/exporters/openvino/model_patcher.py | 21 +++++--
 optimum/intel/openvino/quantization.py | 65 +++++++++++++--------
 tests/openvino/test_quantization.py | 9 ++-
 4 files changed, 67 insertions(+), 52 deletions(-)

diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index a36c22520c..5e73b1f0c3 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -31,14 +31,7 @@
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
 from optimum.utils import is_diffusers_available
 
-from ...intel.utils.import_utils import (
-    _torch_version,
-    _transformers_version,
-    is_nncf_available,
-    is_optimum_version,
-    is_torch_version,
-    is_transformers_version,
-)
+from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
 from .model_patcher import patch_model_with_bettertransformer
 from .stateful import ensure_stateful_is_available, patch_stateful
 from .utils import (
@@ -329,19 +322,8 @@ def export_pytorch(
     logger.info(f"Using framework PyTorch: {torch.__version__}")
     output = Path(output)
 
-    if stateful:
-        if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
-            COLOR_RED = "\033[1;31m"
-            COLOR_RESET = "\033[0m"
-            logger.warning(
-                COLOR_RED
-                + "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. "
-                f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. "
-                "Consider upgrading PyTorch and Transformers, for example by running "
-                "`pip install --upgrade --upgrade-strategy eager optimum[openvino,nncf]`, and export the model again"
-                + COLOR_RESET
-            )
-
+    is_model_stateful = hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True
+    if stateful and not is_model_stateful:
         # Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect
         # both of them are applied to demonstrate the best performance.
         # TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation.
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 37106eacf8..e4d42cd4cd 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -14,16 +14,27 @@
 import logging as log
 
-from optimum.intel.utils.import_utils import is_torch_version
+from optimum.intel.utils.import_utils import (
+    is_torch_version,
+    is_transformers_version,
+    _torch_version,
+    _transformers_version,
+)
 
 
 def patch_model_with_bettertransformer(model):
-    if is_torch_version("<", "2.0"):
+    if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
+        COLOR_RED = "\033[1;31m"
+        COLOR_RESET = "\033[0m"
         log.warn(
-            "integration Scaled Dot Product Attention optimization supported only with torch > 2.0."
- "Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention" - "It is recommended to upgrade PyTorch version for using stateful model or use stateful=False" + COLOR_RED + + "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. " + f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. " + "Consider upgrading PyTorch and Transformers, for example by running " + "`pip install --upgrade --upgrade-strategy eager optimum[openvino,nncf]`, and export the model again" + + COLOR_RESET ) + # model already has required SDPA implementation if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa": return model diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index cf816193c9..8abf998209 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -24,7 +24,7 @@ import transformers from accelerate.data_loader import DataLoaderStateMixin from datasets import Dataset, load_dataset -from nncf import NNCFConfig, compress_weights +from nncf import NNCFConfig from nncf.torch import create_compressed_model, register_default_init_args, register_module from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk from nncf.torch.initialization import PTInitializingDataLoader @@ -34,11 +34,13 @@ from transformers import DataCollator, PreTrainedModel, default_data_collator from transformers.pytorch_utils import Conv1D +from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.tasks import TasksManager from optimum.quantization_base import OptimumQuantizer from ...exporters.openvino import export, export_pytorch_via_onnx -from ...exporters.openvino.stateful import ensure_export_task_support_stateful +from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer +from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available from ..utils.constant import _TASK_ALIASES from .configuration import OVConfig from .modeling_base import OVBaseModel @@ -348,9 +350,7 @@ def _quantize_ovcausallm( self.model.model, quantization_dataset, model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"), - fast_bias_correction=True - if not kwargs.get("fast_bias_correction") - else kwargs.get("fast_bias_correction"), + fast_bias_correction=True if not kwargs.get("fast_bias_correction") else kwargs.get("fast_bias_correction"), **kwargs, ) self.model.model = quantized_model @@ -392,13 +392,44 @@ def _quantize_torchmodel( if file_name is None and quantization_config.save_onnx_model else Path(ov_file_name).with_suffix(".onnx") ) + + task = self.task + model = self.model + self.model.config.save_pretrained(save_directory) + if task.startswith("text-generation"): + onnx_config = onnx_config_class( + model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache + ) + if model.config.use_cache: + task = "text-generation-with-past" + else: + onnx_config = onnx_config_class(model.config) + + stateful = ensure_stateful_is_available() and ensure_export_task_support_stateful(task) + if weights_only: - if getattr(self.model.config, "tie_word_embeddings", True): - # to fix problem with shared embedding weights in nncf compress_weights() - self.model.tie_weights() 
-            compressed_model = compress_weights(self.model)
-            self.model = compressed_model
+            from torch.utils._pytree import tree_map
+
+            if stateful:
+                # patch model before weight compression
+                model = patch_model_with_bettertransformer(model)
+
+            dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
+            device = self.model.device
+            dummy_inputs = tree_map(
+                lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
+            )
+            check_dummy_inputs_are_allowed(model, dummy_inputs)
+
+            nncf.compress_weights(self.model, dataset=nncf.Dataset([dummy_inputs]))
         else:
+            if stateful:
+                logger.warn(
+                    "Quantization algorithm does not support optimized stateful models. "
+                    "The original model without optimization will be quantized and exported."
+                )
+                stateful = False
+
             calibration_dataloader = self._get_calibration_dataloader(
                 calibration_dataset=calibration_dataset,
                 batch_size=batch_size,
@@ -415,18 +446,6 @@
             )
             compressed_model = controller.strip(do_copy=False)
 
-            task = self.task
-            model = self.model
-            self.model.config.save_pretrained(save_directory)
-            if task.startswith("text-generation"):
-                onnx_config = onnx_config_class(
-                    model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
-                )
-                if model.config.use_cache:
-                    task = "text-generation-with-past"
-            else:
-                onnx_config = onnx_config_class(model.config)
-
         model_path = save_directory / (onnx_file_name if quantization_config.save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
         export_fn = export if not quantization_config.save_onnx_model else export_pytorch_via_onnx
@@ -434,7 +453,7 @@
             opset = max(opset, MIN_ONNX_QDQ_OPSET)
         kwargs = {}
         if not quantization_config.save_onnx_model:
-            kwargs = {"stateful": ensure_export_task_support_stateful(task)}
+            kwargs = {"stateful": stateful}
         _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
         if is_onnx:
             # Load and save the compressed model
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index d5d01da605..108559824c 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -152,9 +152,7 @@ class OVWeightCompressionTest(unittest.TestCase):
     )
 
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 365),)
-    SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = (
-        (OVModelForCausalLM, "opt125m", 64, 477),
-    )
+    SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 477),)
 
     SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
         (OVModelForCausalLM, "gpt2"),
@@ -174,6 +172,11 @@ class OVWeightCompressionTest(unittest.TestCase):
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
     def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
+        import nncf
+
+        if nncf.__version__ == "2.8.0":
+            self.skipTest("https://github.com/openvinotoolkit/nncf/issues/2432")
+
         task = model_cls.export_feature
 
         with tempfile.TemporaryDirectory() as tmp_dir:

From 3593e516cbf217e9f72494fc4661c5c123f4847b Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Sun, 28 Jan 2024 18:47:38 +0400
Subject: [PATCH 2/9] black happy

---
 optimum/intel/openvino/quantization.py | 4 +++-
 tests/openvino/test_quantization.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 8abf998209..5006c3a36d 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -350,7 +350,9 @@ def _quantize_ovcausallm(
             self.model.model,
             quantization_dataset,
             model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
-            fast_bias_correction=True if not kwargs.get("fast_bias_correction") else kwargs.get("fast_bias_correction"),
+            fast_bias_correction=(
+                True if not kwargs.get("fast_bias_correction") else kwargs.get("fast_bias_correction")
+            ),
             **kwargs,
         )
         self.model.model = quantized_model
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index 108559824c..3c5f33aed8 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -152,7 +152,9 @@ class OVWeightCompressionTest(unittest.TestCase):
     )
 
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 365),)
-    SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 477),)
+    SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = (
+        (OVModelForCausalLM, "opt125m", 64, 477),
+    )
 
     SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
         (OVModelForCausalLM, "gpt2"),

From f63e987e72bfa88b1d99c979bc7e248d8f235ad4 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Sun, 28 Jan 2024 18:56:19 +0400
Subject: [PATCH 3/9] ruff happy

---
 optimum/exporters/openvino/model_patcher.py | 4 ++--
 optimum/intel/openvino/quantization.py | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index e4d42cd4cd..6086a68abf 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -15,10 +15,10 @@
 import logging as log
 
 from optimum.intel.utils.import_utils import (
-    is_torch_version,
-    is_transformers_version,
     _torch_version,
     _transformers_version,
+    is_torch_version,
+    is_transformers_version,
 )
 
 
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 5006c3a36d..c9b1170d4b 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -417,13 +417,13 @@ def _quantize_torchmodel(
                 model = patch_model_with_bettertransformer(model)
 
             dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
-            device = self.model.device
+            device = model.device
             dummy_inputs = tree_map(
                 lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
             )
             check_dummy_inputs_are_allowed(model, dummy_inputs)
 
-            nncf.compress_weights(self.model, dataset=nncf.Dataset([dummy_inputs]))
+            nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs]))
         else:
             if stateful:
                 logger.warn(
@@ -443,10 +443,10 @@
             quantization_config.add_input_info(model_inputs)
             nncf_config = NNCFConfig.from_dict(quantization_config.__dict__)
             nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
-            controller, compressed_model = create_compressed_model(
-                self.model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
+            controller, model = create_compressed_model(
+                model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
             )
-            compressed_model = controller.strip(do_copy=False)
+            model = controller.strip(do_copy=False)
 
         model_path = save_directory / (onnx_file_name if quantization_config.save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name

From 472bf443ac01aed0f02c91ac78c7dfe867380f13 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Mon, 5 Feb 2024 15:44:45 +0400
Subject: [PATCH 4/9] updated nncf version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 33fe656630..fc6eba8729 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@
         "transformers>=4.34.0",
     ],
     "openvino": ["openvino>=2023.2", "onnx", "onnxruntime", "transformers>=4.36.0", "optimum>=1.16.1"],
-    "nncf": ["nncf>=2.7.0"],
+    "nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git"],
     "ipex": ["intel-extension-for-pytorch", "onnx"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,

From f1681b034cdee4d3ff1e66b194b649b50cb65540 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Mon, 5 Feb 2024 16:27:11 +0400
Subject: [PATCH 5/9] replied to comments

---
 optimum/intel/openvino/quantization.py | 7 ++-----
 tests/openvino/test_quantization.py | 5 -----
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index c9b1170d4b..35bf909042 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -30,6 +30,7 @@
 from nncf.torch.initialization import PTInitializingDataLoader
 from openvino._offline_transformations import compress_quantize_weights_transformation
 from openvino.runtime import Core, Tensor
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader, RandomSampler
 from transformers import DataCollator, PreTrainedModel, default_data_collator
 from transformers.pytorch_utils import Conv1D
@@ -350,9 +351,7 @@ def _quantize_ovcausallm(
             self.model.model,
             quantization_dataset,
             model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
-            fast_bias_correction=(
-                True if not kwargs.get("fast_bias_correction") else kwargs.get("fast_bias_correction")
-            ),
+            fast_bias_correction=kwargs.get("fast_bias_correction", True),
             **kwargs,
         )
         self.model.model = quantized_model
@@ -410,8 +409,6 @@ def _quantize_torchmodel(
         stateful = ensure_stateful_is_available() and ensure_export_task_support_stateful(task)
 
         if weights_only:
-            from torch.utils._pytree import tree_map
-
             if stateful:
                 # patch model before weight compression
                 model = patch_model_with_bettertransformer(model)
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index 3c5f33aed8..d5d01da605 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -174,11 +174,6 @@ class OVWeightCompressionTest(unittest.TestCase):
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
     def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
-        import nncf
-
-        if nncf.__version__ == "2.8.0":
-            self.skipTest("https://github.com/openvinotoolkit/nncf/issues/2432")
-
         task = model_cls.export_feature
 
         with tempfile.TemporaryDirectory() as tmp_dir:

From 29de5393988a9d28d7fa845863d25e6f2f387a78 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Mon, 5 Feb 2024 16:52:37 +0400
Subject: [PATCH 6/9] replied comments

---
 optimum/exporters/openvino/model_patcher.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 6086a68abf..f953771a7a 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -23,6 +23,10 @@
 
 
 def patch_model_with_bettertransformer(model):
+    # check that the model has not yet been patched
+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        return model
+
     if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
         COLOR_RED = "\033[1;31m"
         COLOR_RESET = "\033[0m"

From 309c3bc0cad8884087c25f1eb8a3081fe931a405 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Mon, 5 Feb 2024 17:41:51 +0400
Subject: [PATCH 7/9] typo

---
 optimum/exporters/openvino/convert.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 5e73b1f0c3..a885eb698e 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -322,8 +322,7 @@ def export_pytorch(
     logger.info(f"Using framework PyTorch: {torch.__version__}")
     output = Path(output)
 
-    is_model_stateful = hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True
-    if stateful and not is_model_stateful:
+    if stateful:
         # Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect
         # both of them are applied to demonstrate the best performance.
         # TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation.

From c21db3e901897245f197b465f41112a42762c537 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Mon, 5 Feb 2024 20:54:56 +0400
Subject: [PATCH 8/9] cherry-pick fixes for tests from PR 538

---
 tests/openvino/test_quantization.py | 4 ++--
 tests/openvino/utils_tests.py | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index d5d01da605..61e4e38e4c 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -61,10 +61,10 @@
 
 
 class OVQuantizerTest(unittest.TestCase):
-    # TODO : add models
+    # TODO : add models, enable OVModelForCausalLM.
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
         (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 32, 35),
-        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 41, 23),
+        # (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 41, 23),
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index 30ed92ba46..11f79a989c 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -103,15 +103,15 @@
     "bert": (70,),
     "roberta": (68,),
     "albert": (84,),
-    "vit": (62,),
+    "vit": (64,),
    "blenderbot": (70,),
     "gpt2": (46,),
-    "wav2vec2": (30,),
+    "wav2vec2": (34,),
     "distilbert": (66,),
     "t5": (64, 104, 84),
-    "stable-diffusion": (148, 8, 8, 64),
-    "stable-diffusion-xl": (296, 8, 8, 66),
-    "stable-diffusion-xl-refiner": (296, 8, 8, 66),
+    "stable-diffusion": (242, 34, 42, 64),
+    "stable-diffusion-xl": (366, 34, 42, 66),
+    "stable-diffusion-xl-refiner": (366, 34, 42, 66),
 }
 
 _ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (64, 477)}

From 90b61d2a165f04e73b56a37d7ee7266fa5dbfb58 Mon Sep 17 00:00:00 2001
From: Alexander Suslov
Date: Thu, 8 Feb 2024 10:41:12 +0400
Subject: [PATCH 9/9] replied to comments

---
 optimum/intel/openvino/quantization.py | 3 ++-
 optimum/intel/utils/modeling_utils.py | 21 +++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 35bf909042..6b360e059a 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -43,6 +43,7 @@
 from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
 from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
 from ..utils.constant import _TASK_ALIASES
+from ..utils.modeling_utils import get_model_device
 from .configuration import OVConfig
 from .modeling_base import OVBaseModel
 from .modeling_decoder import OVBaseDecoderModel
@@ -414,7 +415,7 @@ def _quantize_torchmodel(
                 model = patch_model_with_bettertransformer(model)
 
             dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
-            device = model.device
+            device = get_model_device(model)
             dummy_inputs = tree_map(
                 lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
             )
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 1a3b6fbede..99ad42aafa 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -148,3 +148,24 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
     elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
         model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
     return model
+
+
+def get_model_device(model: torch.nn.Module) -> torch.device:
+    """
+    Determines the device on which a PyTorch model is currently residing.
+
+    Args:
+        model: The PyTorch model to query.
+
+    Returns:
+        torch.device: The device where the model's parameters are located.
+
+    Note:
+        Falls back to torch.device("cpu") if the model has no parameters.
+    """
+    try:
+        device = next(model.parameters()).device
+    except StopIteration:
+        # The model has no parameters at all, so it doesn't matter which device to choose
+        device = torch.device("cpu")
+    return device
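
Note on [PATCH 9/9]: get_model_device exists because nncf.compress_weights traces the model on real dummy inputs, and those tensors must live on the same device as the model's parameters. The following is a minimal, self-contained sketch of that pattern under illustrative assumptions: the torch.nn.Linear stand-in model and the hand-built dummy_inputs dict are not part of the patches (in the patched code the inputs come from onnx_config.generate_dummy_inputs(framework="pt")).

import torch
from torch.utils._pytree import tree_map


def get_model_device(model: torch.nn.Module) -> torch.device:
    # Same fallback as the patched helper: a parameter-free model defaults to CPU.
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")


# Illustrative stand-ins for a real exported model and its generated dummy inputs.
model = torch.nn.Linear(8, 8)
dummy_inputs = {"input_ids": torch.ones(1, 8), "attention_mask": torch.ones(1, 8)}

# tree_map walks the (possibly nested) input structure and moves every tensor
# leaf to the model's device, leaving non-tensor values untouched.
device = get_model_device(model)
dummy_inputs = tree_map(lambda v: v.to(device) if isinstance(v, torch.Tensor) else v, dummy_inputs)

The StopIteration fallback is the design point: nn.Module.parameters() returns a generator, so next() on a parameter-free model raises StopIteration rather than returning a device, and defaulting to CPU keeps weight compression working for such models.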