|
44 | 44 | from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
|
45 | 45 | from transformers.models.auto.auto_factory import _get_model_class as get_model_class
|
46 | 46 | from transformers.utils import WEIGHTS_NAME
|
| 47 | +from transformers import is_torch_xpu_available |
47 | 48 |
|
48 | 49 | from optimum.exporters import TasksManager
|
49 | 50 | from optimum.modeling_base import OptimizedModel
|
@@ -128,10 +129,37 @@ def __init__(
|
128 | 129 | **kwargs,
|
129 | 130 | ):
|
130 | 131 | OptimizedModel.__init__(self, model=model, config=config)
|
131 |
| - # To do: add XPU support |
132 |
| - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
133 |
| - self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 |
| 132 | + device_map = kwargs.pop("device_map", None) |
| 133 | + if device_map is None: |
| 134 | + if is_torch_xpu_available(check_device=True): |
| 135 | + self._device = torch.device("xpu:0") |
| 136 | + elif torch.cuda.is_available(): |
| 137 | + self._device = torch.device("cuda:0") |
| 138 | + else: |
| 139 | + self._device = torch.device("cpu") |
| 140 | + else: |
| 141 | + if isinstance(device_map, torch.device): |
| 142 | + self._device = device_map |
| 143 | + elif isinstance(device_map, str): |
| 144 | + if device_map in ["auto", "balanced", "balanced_low_0", "sequential"]: |
| 145 | + raise ValueError( |
| 146 | + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, xpu:0). " |
| 147 | +                        "'auto', 'balanced', 'balanced_low_0', 'sequential' are not supported." |
| 148 | + ) |
| 149 | + self._device = torch.device(device_map) |
| 150 | + elif isinstance(device_map, int): |
| 151 | + if is_torch_xpu_available(check_device=True): |
| 152 | + self._device = torch.device(f"xpu:{device_map}") |
| 153 | + elif torch.cuda.is_available(): |
| 154 | + self._device = torch.device(f"cuda:{device_map}") |
| 155 | + else: |
| 156 | + self._device = torch.device("cpu") |
| 157 | + else: |
| 158 | + raise ValueError( |
| 159 | +                f"device_map should either be a string, an integer or a torch.device object, but found {type(device_map)}" |
| 160 | + ) |
134 | 161 | self.model.to(self._device)
|
| 162 | + self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 |
135 | 163 | self.model_save_dir = model_save_dir
|
136 | 164 | self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
|
137 | 165 |
|
@@ -319,6 +347,8 @@ def _init_warmup(self):
|
319 | 347 | if not self._is_ipex_exported:
|
320 | 348 | use_cache = "past_key_values" in self.input_names
|
321 | 349 | dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
|
| 350 | + if "cpu" not in str(self._device): |
| 351 | + dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()} |
322 | 352 | for _ in range(2):
|
323 | 353 | self(**dummy_inputs)
|
324 | 354 |
|
|
0 commit comments