Skip to content

Commit 3d807cb

Browse files
committed
fix crash in warmup for xpu
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent ea6fa42 commit 3d807cb

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

optimum/intel/generation/modeling.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
6363

6464
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
6565

66-
return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
66+
return {
67+
key: dummy_inputs[key].to(model.device)
68+
for key in signature.parameters
69+
if dummy_inputs.get(key, None) is not None
70+
}
6771

6872

6973
def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):

0 commit comments

Comments
 (0)