@@ -1250,6 +1250,7 @@ def patch_model_for_export(
1250
1250
1251
1251
@register_in_tasks_manager ("clip-text-model" , * ["feature-extraction" ], library_name = "transformers" )
1252
1252
@register_in_tasks_manager ("clip-text-model" , * ["feature-extraction" ], library_name = "diffusers" )
1253
+ @register_in_tasks_manager ("clip-text" , * ["feature-extraction" ], library_name = "diffusers" )
1253
1254
class CLIPTextOpenVINOConfig (CLIPTextOnnxConfig ):
1254
1255
def patch_model_for_export (
1255
1256
self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
@@ -1795,12 +1796,31 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1795
1796
)
1796
1797
1797
1798
1799
class DummyUnetTimestepInputGenerator(DummyTimestepInputGenerator):
    """Dummy-input generator that special-cases the ``timestep`` input.

    Every other input name is handed to the parent generator unchanged.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Produce the dummy value for ``input_name``.

        ``timestep`` is built through ``self.random_int_tensor`` with the shape
        ``[self.batch_size]``, ``max_value=self.vocab_size`` and ``int_dtype``.
        Any other name is delegated to ``super().generate``.
        """
        if input_name == "timestep":
            timestep_shape = [self.batch_size]
            return self.random_int_tensor(
                timestep_shape,
                max_value=self.vocab_size,
                framework=framework,
                dtype=int_dtype,
            )
        return super().generate(input_name, framework, int_dtype, float_dtype)
1805
+
1806
+
1798
1807
@register_in_tasks_manager("unet", "semantic-segmentation", library_name="diffusers")
@register_in_tasks_manager("unet-2d-condition", "semantic-segmentation", library_name="diffusers")
class UnetOpenVINOConfig(UNetOnnxConfig):
    """UNet export config registered for the ``unet`` and ``unet-2d-condition`` model types.

    The first two dummy-input generators inherited from ``UNetOnnxConfig`` are
    replaced; the remaining ones (index 2 onwards) are kept as they are.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyUnetVisionInputGenerator,
        DummyUnetTimestepInputGenerator,
        *UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:],
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent's input axes with ``timestep`` axis 0 named ``batch_size``."""
        axes_by_input = super().inputs
        axes_by_input["timestep"] = {0: "batch_size"}
        return axes_by_input
1801
1820
1802
1821
1803
1822
@register_in_tasks_manager ("sd3-transformer" , * ["semantic-segmentation" ], library_name = "diffusers" )
1823
+ @register_in_tasks_manager ("sd3-transformer-2d" , * ["semantic-segmentation" ], library_name = "diffusers" )
1804
1824
class SD3TransformerOpenVINOConfig (UNetOnnxConfig ):
1805
1825
DUMMY_INPUT_GENERATOR_CLASSES = (
1806
1826
(DummyTransformerTimestpsInputGenerator ,)
@@ -1830,6 +1850,7 @@ def rename_ambiguous_inputs(self, inputs):
1830
1850
1831
1851
1832
1852
@register_in_tasks_manager("t5-encoder-model", "feature-extraction", library_name="diffusers")
@register_in_tasks_manager("t5-encoder", "feature-extraction", library_name="diffusers")
class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
    """Export config registered for the ``t5-encoder-model`` and ``t5-encoder`` model types.

    Adds nothing of its own; all behaviour is inherited from ``CLIPTextOpenVINOConfig``.
    """
1835
1856
@@ -1905,6 +1926,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1905
1926
1906
1927
1907
1928
@register_in_tasks_manager ("flux-transformer" , * ["semantic-segmentation" ], library_name = "diffusers" )
1929
+ @register_in_tasks_manager ("flux-transformer-2d" , * ["semantic-segmentation" ], library_name = "diffusers" )
1908
1930
class FluxTransformerOpenVINOConfig (SD3TransformerOpenVINOConfig ):
1909
1931
DUMMY_INPUT_GENERATOR_CLASSES = (
1910
1932
DummyTransformerTimestpsInputGenerator ,
0 commit comments