Skip to content

Commit 369c01f

Browse files
committed
add xpu support
1 parent e6fadb1 commit 369c01f

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

optimum/intel/ipex/modeling_base.py

+33 -3
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@
4444
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
4545
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
4646
from transformers.utils import WEIGHTS_NAME
47+
from transformers import is_torch_xpu_available
4748

4849
from optimum.exporters import TasksManager
4950
from optimum.modeling_base import OptimizedModel
@@ -128,10 +129,37 @@ def __init__(
128129
**kwargs,
129130
):
130131
OptimizedModel.__init__(self, model=model, config=config)
131-
# To do: add XPU support
132-
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
133-
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
132+
device_map = kwargs.pop("device_map", None)
133+
if device_map is None:
134+
if is_torch_xpu_available(check_device=True):
135+
self._device = torch.device("xpu:0")
136+
elif torch.cuda.is_available():
137+
self._device = torch.device("cuda:0")
138+
else:
139+
self._device = torch.device("cpu")
140+
else:
141+
if isinstance(device_map, torch.device):
142+
self._device = device_map
143+
elif isinstance(device_map, str):
144+
if device_map in ["auto", "balanced", "balanced_low_0", "sequential"]:
145+
raise ValueError(
146+
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, xpu:0). "
147+
f"'auto', 'balanced', 'balanced_low_0', 'sequential' are not supported."
148+
)
149+
self._device = torch.device(device_map)
150+
elif isinstance(device_map, int):
151+
if is_torch_xpu_available(check_device=True):
152+
self._device = torch.device(f"xpu:{device_map}")
153+
elif torch.cuda.is_available():
154+
self._device = torch.device(f"cuda:{device_map}")
155+
else:
156+
self._device = torch.device("cpu")
157+
else:
158+
raise ValueError(
159+
f"device_map should be either be a string, an integer or a torch.device object, but found {type(device_map)}"
160+
)
134161
self.model.to(self._device)
162+
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
135163
self.model_save_dir = model_save_dir
136164
self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
137165

@@ -319,6 +347,8 @@ def _init_warmup(self):
319347
if not self._is_ipex_exported:
320348
use_cache = "past_key_values" in self.input_names
321349
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
350+
if "cpu" not in str(self._device):
351+
dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()}
322352
for _ in range(2):
323353
self(**dummy_inputs)
324354

0 commit comments

Comments
 (0)