Skip to content

Commit 5fa9602

Browse files
authored
Fix warmup for xpu (#1090)
* fix crash in warmup for xpu Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * use recursive_to_device instead Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 87c431c commit 5fa9602

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

optimum/intel/generation/modeling.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from ..utils.constant import _TASK_ALIASES
3636
from ..utils.import_utils import is_torch_version
37-
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
37+
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device
3838

3939

4040
logger = logging.getLogger(__name__)
@@ -63,7 +63,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
6363

6464
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
6565

66-
return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
66+
return {
67+
key: recursive_to_device(dummy_inputs[key], model.device)
68+
for key in signature.parameters
69+
if dummy_inputs.get(key, None) is not None
70+
}
6771

6872

6973
def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):

0 commit comments

Comments
 (0)