@@ -39,6 +39,7 @@
     GenerationConfig,
     GenerationMixin,
     PretrainedConfig,
+    is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
@@ -52,7 +53,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
 
 
 logger = logging.getLogger(__name__)
@@ -128,10 +129,14 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        # To do: add XPU support
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
+        if is_torch_xpu_available(check_device=True):
+            self._device = torch.device("xpu:0")
+        elif torch.cuda.is_available():
+            self._device = torch.device("cuda:0")
+        else:
+            self._device = torch.device("cpu")
         self.model.to(self._device)
+        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
         self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
 
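For reference, `is_torch_xpu_available` is imported from `transformers` above; with `check_device=True` it probes for a responding XPU device rather than only checking that the XPU stack is importable, so the `xpu:0` branch is taken only on hosts that actually expose one. A hedged usage sketch (the `IPEXModelForCausalLM` entry point and model ID are illustrative, not part of this diff):

```python
# Sketch: with this patch, loading on an Intel GPU host should place the model
# on xpu:0, fall back to cuda:0 on NVIDIA hosts, and otherwise stay on cpu.
from optimum.intel import IPEXModelForCausalLM

model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True)
print(model._device)  # e.g. device(type='xpu', index=0) on an XPU host
```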
@@ -321,6 +326,8 @@ def _init_warmup(self):
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            if self._device.type != "cpu":
+                dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
                 self(**dummy_inputs)
 
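`recursive_to_device` is imported from `..utils.modeling_utils`, and its body is not shown in these hunks. The move is presumably needed because `prepare_jit_inputs` builds the dummy tensors on CPU, so a warmup forward on XPU or CUDA would otherwise hit a device mismatch. A minimal sketch of what such a helper plausibly looks like, assuming it walks nested containers of tensors (an illustration, not the PR's actual implementation):

```python
import torch


def recursive_to_device(value, device):
    """Recursively move any tensors nested inside dicts/lists/tuples to `device`."""
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {k: recursive_to_device(value=v, device=device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        # Plain tuples/lists only; namedtuples would need dedicated handling.
        return type(value)(recursive_to_device(value=v, device=device) for v in value)
    return value  # non-tensor leaves (ints, strings, None, ...) pass through unchanged
```

The keyword-style call in the hunk (`recursive_to_device(value=..., device=...)`) suggests exactly this `(value, device)` signature.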