55
55
)
56
56
from optimum .utils .normalized_config import NormalizedConfig , NormalizedTextConfig , NormalizedVisionConfig
57
57
58
- from ...intel .utils .import_utils import _transformers_version , is_transformers_version
58
+ from ...intel .utils .import_utils import _transformers_version , is_diffusers_version , is_transformers_version
59
59
from .model_patcher import (
60
60
AquilaModelPatcher ,
61
61
ArcticModelPatcher ,
@@ -1681,7 +1681,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1681
1681
img_ids_height = self .height // 2
1682
1682
img_ids_width = self .width // 2
1683
1683
return self .random_int_tensor (
1684
- [self .batch_size , img_ids_height * img_ids_width , 3 ],
1684
+ [self .batch_size , img_ids_height * img_ids_width , 3 ]
1685
+ if is_diffusers_version ("<" , "0.31.0" )
1686
+ else [img_ids_height * img_ids_width , 3 ],
1685
1687
min_value = 0 ,
1686
1688
max_value = min (img_ids_height , img_ids_width ),
1687
1689
framework = framework ,
@@ -1704,7 +1706,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1704
1706
if input_name == "txt_ids" :
1705
1707
import torch
1706
1708
1707
- shape = [self .batch_size , self .sequence_length , 3 ]
1709
+ shape = (
1710
+ [self .batch_size , self .sequence_length , 3 ]
1711
+ if is_diffusers_version ("<" , "0.31.0" )
1712
+ else [self .sequence_length , 3 ]
1713
+ )
1708
1714
dtype = DTYPE_MAPPER .pt (float_dtype )
1709
1715
return torch .full (shape , 0 , dtype = dtype )
1710
1716
return super ().generate (input_name , framework , int_dtype , float_dtype )
@@ -1724,8 +1730,14 @@ def inputs(self):
1724
1730
common_inputs = super ().inputs
1725
1731
common_inputs .pop ("sample" , None )
1726
1732
common_inputs ["hidden_states" ] = {0 : "batch_size" , 1 : "packed_height_width" }
1727
- common_inputs ["txt_ids" ] = {0 : "batch_size" , 1 : "sequence_length" }
1728
- common_inputs ["img_ids" ] = {0 : "batch_size" , 1 : "packed_height_width" }
1733
+ common_inputs ["txt_ids" ] = (
1734
+ {0 : "batch_size" , 1 : "sequence_length" } if is_diffusers_version ("<" , "0.31.0" ) else {0 : "sequence_length" }
1735
+ )
1736
+ common_inputs ["img_ids" ] = (
1737
+ {0 : "batch_size" , 1 : "packed_height_width" }
1738
+ if is_diffusers_version ("<" , "0.31.0" )
1739
+ else {0 : "packed_height_width" }
1740
+ )
1729
1741
if getattr (self ._normalized_config , "guidance_embeds" , False ):
1730
1742
common_inputs ["guidance" ] = {0 : "batch_size" }
1731
1743
return common_inputs
0 commit comments