Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add XPU support for IPEXModel.from_pretrained #704

Merged
merged 4 commits into from
May 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME
from transformers import is_torch_xpu_available

from optimum.exporters import TasksManager
from optimum.modeling_base import OptimizedModel
Expand Down Expand Up @@ -128,10 +129,37 @@ def __init__(
**kwargs,
):
OptimizedModel.__init__(self, model=model, config=config)
# To do: add XPU support
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
device_map = kwargs.pop("device_map", None)
if device_map is None:
if is_torch_xpu_available(check_device=True):
self._device = torch.device("xpu:0")
elif torch.cuda.is_available():
self._device = torch.device("cuda:0")
else:
self._device = torch.device("cpu")
else:
if isinstance(device_map, torch.device):
self._device = device_map
elif isinstance(device_map, str):
if device_map in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, xpu:0). "
f"'auto', 'balanced', 'balanced_low_0', 'sequential' are not supported."
)
self._device = torch.device(device_map)
elif isinstance(device_map, int):
if is_torch_xpu_available(check_device=True):
self._device = torch.device(f"xpu:{device_map}")
elif torch.cuda.is_available():
self._device = torch.device(f"cuda:{device_map}")
else:
self._device = torch.device("cpu")
else:
raise ValueError(
f"device_map should be either be a string, an integer or a torch.device object, but found {type(device_map)}"
)
self.model.to(self._device)
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
self.model_save_dir = model_save_dir
self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)

Expand Down Expand Up @@ -319,6 +347,8 @@ def _init_warmup(self):
if not self._is_ipex_exported:
use_cache = "past_key_values" in self.input_names
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
if "cpu" not in str(self._device):
dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()}
for _ in range(2):
self(**dummy_inputs)

Expand Down
Loading