Skip to content

Commit cda4908

Browse files
authored
Fix reshaping unet if timestep is 0d tensor (#1083)
1 parent 106a5b7 commit cda4908

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

optimum/intel/openvino/modeling_diffusion.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,8 @@ def _reshape_unet(
657657
for inputs in model.inputs:
658658
shapes[inputs] = inputs.get_partial_shape()
659659
if inputs.get_any_name() == "timestep":
660-
shapes[inputs][0] = 1
660+
if shapes[inputs].rank == 1:
661+
shapes[inputs][0] = 1
661662
elif inputs.get_any_name() == "sample":
662663
in_channels = self.unet.config.get("in_channels", None)
663664
if in_channels is None:

0 commit comments

Comments
 (0)