Skip to content

Commit be967d4

Browse files
committed
add recursive_to_device
1 parent 08ce310 commit be967d4

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

optimum/intel/ipex/modeling_base.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@
3939
GenerationConfig,
4040
GenerationMixin,
4141
PretrainedConfig,
42+
is_torch_xpu_available,
4243
)
4344
from transformers.dynamic_module_utils import get_class_from_dynamic_module
4445
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
4546
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
4647
from transformers.utils import WEIGHTS_NAME
47-
from transformers import is_torch_xpu_available
4848

4949
from optimum.exporters import TasksManager
5050
from optimum.modeling_base import OptimizedModel
@@ -53,7 +53,7 @@
5353
from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
5454
from ..generation.modeling import prepare_jit_inputs
5555
from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
56-
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
56+
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
5757

5858

5959
logger = logging.getLogger(__name__)
@@ -129,13 +129,12 @@ def __init__(
129129
**kwargs,
130130
):
131131
OptimizedModel.__init__(self, model=model, config=config)
132-
if device_map is None:
133-
if is_torch_xpu_available(check_device=True):
134-
self._device = torch.device("xpu:0")
135-
elif torch.cuda.is_available():
136-
self._device = torch.device("cuda:0")
137-
else:
138-
self._device = torch.device("cpu")
132+
if is_torch_xpu_available(check_device=True):
133+
self._device = torch.device("xpu:0")
134+
elif torch.cuda.is_available():
135+
self._device = torch.device("cuda:0")
136+
else:
137+
self._device = torch.device("cpu")
139138
self.model.to(self._device)
140139
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
141140
self.model_save_dir = model_save_dir
@@ -326,7 +325,7 @@ def _init_warmup(self):
326325
use_cache = "past_key_values" in self.input_names
327326
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
328327
if "cpu" not in str(self._device):
329-
dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()}
328+
dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
330329
for _ in range(2):
331330
self(**dummy_inputs)
332331

optimum/intel/utils/modeling_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,16 @@ def get_model_device(model: torch.nn.Module) -> torch.device:
169169
# The model had no parameters at all, doesn't matter which device to choose
170170
device = torch.device("cpu")
171171
return device
172+
173+
174+
def recursive_to_device(value, device):
    """
    Recursively move every tensor contained in `value` to `device`.

    Tuples (including namedtuples), lists, and dicts are traversed and
    rebuilt with the same container type; `torch.Tensor` leaves are moved
    via `Tensor.to(device)`; any other leaf is returned unchanged.

    Args:
        value: An arbitrarily nested structure of tuples/lists/dicts that
            may contain `torch.Tensor` objects (e.g. model dummy inputs).
        device: Target `torch.device` (or device string) for the tensors.

    Returns:
        A structure mirroring `value` with all tensors moved to `device`.
    """
    # Namedtuples are tuple subclasses whose constructor takes the fields
    # positionally, not a single iterable, so they need the splat form.
    if isinstance(value, tuple) and hasattr(value, "_fields"):
        return type(value)(*(recursive_to_device(v, device) for v in value))
    if isinstance(value, (tuple, list)):
        return type(value)(recursive_to_device(v, device) for v in value)
    elif isinstance(value, dict):
        # NOTE(review): assumes the dict subclass accepts a mapping in its
        # constructor (plain dict / OrderedDict do; defaultdict would not).
        return type(value)({k: recursive_to_device(v, device) for k, v in value.items()})
    elif isinstance(value, torch.Tensor):
        return value.to(device)
    return value

0 commit comments

Comments
 (0)