From 369c01f679c466c54893cc39f114e89a8a702fff Mon Sep 17 00:00:00 2001
From: "Lin, Fanli" <fanli.lin@intel.com>
Date: Sat, 11 May 2024 03:51:21 -0700
Subject: [PATCH 1/4] add xpu support

---
 optimum/intel/ipex/modeling_base.py | 36 ++++++++++++++++++++++++++---
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 2b739ea502..3a37059a19 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -44,6 +44,7 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 from transformers.utils import WEIGHTS_NAME
+from transformers import is_torch_xpu_available
 
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -128,10 +129,37 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        # To do: add XPU support
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
+        device_map = kwargs.pop("device_map", None)
+        if device_map is None:
+            if is_torch_xpu_available(check_device=True):
+                self._device = torch.device("xpu:0")
+            elif torch.cuda.is_available():
+                self._device = torch.device("cuda:0")
+            else:
+                self._device = torch.device("cpu")
+        else:
+            if isinstance(device_map, torch.device):
+                self._device = device_map
+            elif isinstance(device_map, str):
+                if device_map in ["auto", "balanced", "balanced_low_0", "sequential"]:
+                    raise ValueError(
+                        "When passing device_map as a string, the value needs to be a device name (e.g. cpu, xpu:0). "
+                        f"'auto', 'balanced', 'balanced_low_0', 'sequential' are not supported."
+                    )
+                self._device = torch.device(device_map)
+            elif isinstance(device_map, int):
+                if is_torch_xpu_available(check_device=True):
+                    self._device = torch.device(f"xpu:{device_map}")
+                elif torch.cuda.is_available():
+                    self._device = torch.device(f"cuda:{device_map}")
+                else:
+                    self._device = torch.device("cpu")
+            else:
+                raise ValueError(
+                    f"device_map should be either be a string, an integer or a torch.device object, but found {type(device_map)}"
+                )
         self.model.to(self._device)
+        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
         self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
 
@@ -319,6 +347,8 @@ def _init_warmup(self):
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            if "cpu" not in str(self._device):
+                dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()}
             for _ in range(2):
                 self(**dummy_inputs)
 

From 08ce3103b5af35f8bb1b7c54c2877637e03cfd12 Mon Sep 17 00:00:00 2001
From: Fanli Lin <fanli0116@gmail.com>
Date: Wed, 15 May 2024 13:48:01 +0800
Subject: [PATCH 2/4] Apply suggestions from code review

no device_map

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
---
 optimum/intel/ipex/modeling_base.py | 22 ----------------------
 1 file changed, 22 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 3a37059a19..7ca3a1fc43 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -129,7 +129,6 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        device_map = kwargs.pop("device_map", None)
         if device_map is None:
             if is_torch_xpu_available(check_device=True):
                 self._device = torch.device("xpu:0")
@@ -137,27 +136,6 @@ def __init__(
                 self._device = torch.device("cuda:0")
             else:
                 self._device = torch.device("cpu")
-        else:
-            if isinstance(device_map, torch.device):
-                self._device = device_map
-            elif isinstance(device_map, str):
-                if device_map in ["auto", "balanced", "balanced_low_0", "sequential"]:
-                    raise ValueError(
-                        "When passing device_map as a string, the value needs to be a device name (e.g. cpu, xpu:0). "
-                        f"'auto', 'balanced', 'balanced_low_0', 'sequential' are not supported."
-                    )
-                self._device = torch.device(device_map)
-            elif isinstance(device_map, int):
-                if is_torch_xpu_available(check_device=True):
-                    self._device = torch.device(f"xpu:{device_map}")
-                elif torch.cuda.is_available():
-                    self._device = torch.device(f"cuda:{device_map}")
-                else:
-                    self._device = torch.device("cpu")
-            else:
-                raise ValueError(
-                    f"device_map should be either be a string, an integer or a torch.device object, but found {type(device_map)}"
-                )
         self.model.to(self._device)
         self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir

From be967d4d9bf2ea450fd9384837eb3f18123622b0 Mon Sep 17 00:00:00 2001
From: "Lin, Fanli" <fanli.lin@intel.com>
Date: Wed, 15 May 2024 00:47:09 -0700
Subject: [PATCH 3/4] add recursive_to_device

---
 optimum/intel/ipex/modeling_base.py   | 19 +++++++++----------
 optimum/intel/utils/modeling_utils.py | 13 +++++++++++++
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 7ca3a1fc43..f67b39bca3 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -39,12 +39,12 @@
     GenerationConfig,
     GenerationMixin,
     PretrainedConfig,
+    is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 from transformers.utils import WEIGHTS_NAME
-from transformers import is_torch_xpu_available
 
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -53,7 +53,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
 
 
 logger = logging.getLogger(__name__)
@@ -129,13 +129,12 @@ def __init__(
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
-        if device_map is None:
-            if is_torch_xpu_available(check_device=True):
-                self._device = torch.device("xpu:0")
-            elif torch.cuda.is_available():
-                self._device = torch.device("cuda:0")
-            else:
-                self._device = torch.device("cpu")
+        if is_torch_xpu_available(check_device=True):
+            self._device = torch.device("xpu:0")
+        elif torch.cuda.is_available():
+            self._device = torch.device("cuda:0")
+        else:
+            self._device = torch.device("cpu")
         self.model.to(self._device)
         self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
@@ -326,7 +325,7 @@ def _init_warmup(self):
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
             if "cpu" not in str(self._device):
-                dummy_inputs = {name: tensor.to(self._device) for name, tensor in dummy_inputs.items()}
+                dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
                 self(**dummy_inputs)
 
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 99ad42aafa..aeed66e63c 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -169,3 +169,16 @@ def get_model_device(model: torch.nn.Module) -> torch.device:
         # The model had no parameters at all, doesn't matter which device to choose
         device = torch.device("cpu")
     return device
+
+
+def recursive_to_device(value, device):
+    """
+    Recursively move any tensor elements in `value` to `device`
+    """
+    if isinstance(value, (tuple, list)):
+        return type(value)(recursive_to_device(v, device) for v in value)
+    elif isinstance(value, dict):
+        return type(value)({k: recursive_to_device(v, device) for k, v in value.items()})
+    elif isinstance(value, torch.Tensor):
+        return value.to(device)
+    return value

From c617464f58947d0641624dbe6f4c84d6f840d093 Mon Sep 17 00:00:00 2001
From: Fanli Lin <fanli0116@gmail.com>
Date: Wed, 15 May 2024 22:01:47 +0800
Subject: [PATCH 4/4] Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
---
 optimum/intel/ipex/modeling_base.py   | 2 +-
 optimum/intel/utils/modeling_utils.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index f67b39bca3..c6719ed9cb 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -324,7 +324,7 @@ def _init_warmup(self):
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
-            if "cpu" not in str(self._device):
+            if self._device.type != "cpu":
                 dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
                 self(**dummy_inputs)
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index aeed66e63c..a2cd728354 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -178,7 +178,7 @@ def recursive_to_device(value, device):
     if isinstance(value, (tuple, list)):
         return type(value)(recursive_to_device(v, device) for v in value)
     elif isinstance(value, dict):
-        return type(value)({k: recursive_to_device(v, device) for k, v in value.items()})
+        return {k: recursive_to_device(v, device) for k, v in value.items()}
     elif isinstance(value, torch.Tensor):
         return value.to(device)
     return value