
Commit 3cfbc38

faaany and echarlaix authored
add XPU support for IPEXModel.from_pretrained (#704)
* add xpu support
* Apply suggestions from code review: no device_map
* add recursive_to_device
* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 2b902bb commit 3cfbc38

2 files changed: +24 −4 lines changed


optimum/intel/ipex/modeling_base.py (+11 −4)
@@ -39,6 +39,7 @@
     GenerationConfig,
     GenerationMixin,
     PretrainedConfig,
+    is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
@@ -52,7 +53,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device


 logger = logging.getLogger(__name__)
@@ -128,10 +129,14 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        # To do: add XPU support
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
+        if is_torch_xpu_available(check_device=True):
+            self._device = torch.device("xpu:0")
+        elif torch.cuda.is_available():
+            self._device = torch.device("cuda:0")
+        else:
+            self._device = torch.device("cpu")
         self.model.to(self._device)
+        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
         self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
@@ -321,6 +326,8 @@ def _init_warmup(self):
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            if self._device.type != "cpu":
+                dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
                 self(**dummy_inputs)
optimum/intel/utils/modeling_utils.py (+13)
@@ -169,3 +169,16 @@ def get_model_device(model: torch.nn.Module) -> torch.device:
         # The model had no parameters at all, doesn't matter which device to choose
         device = torch.device("cpu")
     return device
+
+
+def recursive_to_device(value, device):
+    """
+    Recursively move the tensor elements in `value` to `device`
+    """
+    if isinstance(value, (tuple, list)):
+        return type(value)(recursive_to_device(v, device) for v in value)
+    elif isinstance(value, dict):
+        return {k: recursive_to_device(v, device) for k, v in value.items()}
+    elif isinstance(value, torch.Tensor):
+        return value.to(device)
+    return value
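The helper walks arbitrarily nested tuples, lists, and dicts, moves only torch.Tensor leaves, and passes every other value through unchanged, which is what the warmup path above relies on when past_key_values are present. A short illustrative sketch (the input structure is made up, and the target device is left on CPU so it runs anywhere):

import torch
from optimum.intel.utils.modeling_utils import recursive_to_device

dummy_inputs = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "past_key_values": [(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4))],
    "use_cache": True,  # non-tensor leaves are returned as-is
}
moved = recursive_to_device(value=dummy_inputs, device=torch.device("cpu"))
print(moved["input_ids"].device, moved["use_cache"])  # cpu True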

0 commit comments
