Skip to content

Commit 90b61d2

Browse files
committed
replied to comments
1 parent c21db3e commit 90b61d2

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

optimum/intel/openvino/quantization.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
4444
from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
4545
from ..utils.constant import _TASK_ALIASES
46+
from ..utils.modeling_utils import get_model_device
4647
from .configuration import OVConfig
4748
from .modeling_base import OVBaseModel
4849
from .modeling_decoder import OVBaseDecoderModel
@@ -414,7 +415,7 @@ def _quantize_torchmodel(
414415
model = patch_model_with_bettertransformer(model)
415416

416417
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
417-
device = model.device
418+
device = get_model_device(model)
418419
dummy_inputs = tree_map(
419420
lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
420421
)

optimum/intel/utils/modeling_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,24 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
148148
elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
149149
model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
150150
return model
151+
152+
153+
def get_model_device(model: torch.nn.Module) -> torch.device:
154+
"""
155+
Determines the device on which a PyTorch model is currently residing.
156+
157+
Args:
158+
model: The PyTorch model to query.
159+
160+
Returns:
161+
torch.device: The device where the model's parameters are located.
162+
163+
Raises:
164+
StopIteration: If the model has no parameters.
165+
"""
166+
try:
167+
device = next(model.parameters()).device
168+
except StopIteration:
169+
# The model had no parameters at all, doesn't matter which device to choose
170+
device = torch.device("cpu")
171+
return device

0 commit comments

Comments
 (0)