|
39 | 39 | GenerationConfig,
|
40 | 40 | GenerationMixin,
|
41 | 41 | PretrainedConfig,
|
| 42 | + is_torch_xpu_available, |
42 | 43 | )
|
43 | 44 | from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
44 | 45 | from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
|
45 | 46 | from transformers.models.auto.auto_factory import _get_model_class as get_model_class
|
46 | 47 | from transformers.utils import WEIGHTS_NAME
|
47 |
| -from transformers import is_torch_xpu_available |
48 | 48 |
|
49 | 49 | from optimum.exporters import TasksManager
|
50 | 50 | from optimum.modeling_base import OptimizedModel
|
|
53 | 53 | from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
|
54 | 54 | from ..generation.modeling import prepare_jit_inputs
|
55 | 55 | from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
|
56 |
| -from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask |
| 56 | +from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device |
57 | 57 |
|
58 | 58 |
|
59 | 59 | logger = logging.getLogger(__name__)
|
@@ -129,13 +129,12 @@ def __init__(
|
129 | 129 | **kwargs,
|
130 | 130 | ):
|
131 | 131 | OptimizedModel.__init__(self, model=model, config=config)
|
132 |
| - if device_map is None: |
133 |
| - if is_torch_xpu_available(check_device=True): |
134 |
| - self._device = torch.device("xpu:0") |
135 |
| - elif torch.cuda.is_available(): |
136 |
| - self._device = torch.device("cuda:0") |
137 |
| - else: |
138 |
| - self._device = torch.device("cpu") |
| 132 | + if is_torch_xpu_available(check_device=True): |
| 133 | + self._device = torch.device("xpu:0") |
| 134 | + elif torch.cuda.is_available(): |
| 135 | + self._device = torch.device("cuda:0") |
| 136 | + else: |
| 137 | + self._device = torch.device("cpu") |
139 | 138 | self.model.to(self._device)
|
140 | 139 | self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
|
141 | 140 | self.model_save_dir = model_save_dir
|
@@ -326,7 +325,7 @@ def _init_warmup(self):
|
326 | 325 | use_cache = "past_key_values" in self.input_names
|
327 | 326 | dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
|
328 | 327 | if "cpu" not in str(self._device):
|
329 |
| - dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()} |
| 328 | + dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device) |
330 | 329 | for _ in range(2):
|
331 | 330 | self(**dummy_inputs)
|
332 | 331 |
|
|
0 commit comments