Skip to content

Commit e6cf59e

Browse files
committed
use recursive_to_device instead
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 3d807cb commit e6cf59e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

optimum/intel/generation/modeling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from ..utils.constant import _TASK_ALIASES
3636
from ..utils.import_utils import is_torch_version
37-
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
37+
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device
3838

3939

4040
logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
6464
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
6565

6666
return {
67-
key: dummy_inputs[key].to(model.device)
67+
key: recursive_to_device(dummy_inputs[key], model.device)
6868
for key in signature.parameters
6969
if dummy_inputs.get(key, None) is not None
7070
}

0 commit comments

Comments
 (0)