add XPU support for IPEXModel.from_pretrained #704

Merged
merged 4 commits on May 15, 2024
Changes from all commits
15 changes: 11 additions & 4 deletions optimum/intel/ipex/modeling_base.py
@@ -39,6 +39,7 @@
     GenerationConfig,
     GenerationMixin,
     PretrainedConfig,
+    is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
@@ -52,7 +53,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device


 logger = logging.getLogger(__name__)
@@ -128,10 +129,14 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        # To do: add XPU support
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
+        if is_torch_xpu_available(check_device=True):
+            self._device = torch.device("xpu:0")
+        elif torch.cuda.is_available():
+            self._device = torch.device("cuda:0")
+        else:
+            self._device = torch.device("cpu")
+        self.model.to(self._device)
+        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
         self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
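
A note on the device-resolution change above: the new logic prefers an available Intel XPU, then CUDA, then falls back to CPU, and moves the wrapped model to whichever device wins. A standalone sketch of the same precedence check, assuming a transformers version that exports `is_torch_xpu_available` (as the import hunk above relies on):

```python
import torch
from transformers import is_torch_xpu_available

# Mirrors the patched __init__: prefer Intel XPU, then CUDA, then CPU.
# check_device=True verifies a usable XPU device is actually present,
# not just that the XPU runtime is importable.
if is_torch_xpu_available(check_device=True):
    device = torch.device("xpu:0")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f"IPEXModel would run on: {device}")
```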

@@ -319,6 +324,8 @@ def _init_warmup(self):
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            if self._device.type != "cpu":
+                dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
                 self(**dummy_inputs)
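
The guard above is needed because `prepare_jit_inputs` builds its dummy tensors on CPU, while the model may now live on XPU or CUDA; running the two warmup forward passes with mismatched devices would fail. A minimal illustration of the pattern with a hypothetical toy module (not the optimum-intel code path):

```python
import torch

# Toy stand-in for the exported model.
model = torch.nn.Linear(8, 8)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

dummy = torch.randn(1, 8)  # created on CPU, like prepare_jit_inputs' outputs
if device.type != "cpu":
    dummy = dummy.to(device)  # analogous to the recursive_to_device call above
for _ in range(2):  # same two-pass warmup as _init_warmup
    model(dummy)
```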

13 changes: 13 additions & 0 deletions optimum/intel/utils/modeling_utils.py
@@ -169,3 +169,16 @@ def get_model_device(model: torch.nn.Module) -> torch.device:
         # The model had no parameters at all, doesn't matter which device to choose
         device = torch.device("cpu")
     return device
+
+
+def recursive_to_device(value, device):
+    """
+    Recursively move the tensor elements in `value` to `device`.
+    """
+    if isinstance(value, (tuple, list)):
+        return type(value)(recursive_to_device(v, device) for v in value)
+    elif isinstance(value, dict):
+        return {k: recursive_to_device(v, device) for k, v in value.items()}
+    elif isinstance(value, torch.Tensor):
+        return value.to(device)
+    return value
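
A usage sketch for the new helper: nested dicts, lists, and tuples are rebuilt with the same container types, tensors are moved via `Tensor.to`, and any other leaf is returned unchanged. The example inputs below are illustrative, not taken from the PR:

```python
import torch

from optimum.intel.utils.modeling_utils import recursive_to_device

dummy_inputs = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "past_key_values": [(torch.zeros(1, 4), torch.zeros(1, 4))],
    "task": "text-generation",  # non-tensor leaf, passed through untouched
}
moved = recursive_to_device(dummy_inputs, torch.device("cpu"))
assert moved["past_key_values"][0][0].device.type == "cpu"
```

One design note: `type(value)(...)` preserves plain tuples and lists, which covers the JIT dummy-input structures here; NamedTuple subclasses would need special-casing, but none appear in this code path.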