@@ -1256,6 +1256,14 @@ def patch_model_for_export(
1256
1256
) -> ModelPatcher :
1257
1257
return ModelPatcher (self , model , model_kwargs = model_kwargs )
1258
1258
1259
+ def generate_dummy_inputs (self , framework : str = "pt" , ** kwargs ):
1260
+ dummy_inputs = super ().generate_dummy_inputs (framework = framework , ** kwargs )
1261
+ # TODO: fix should be by casting inputs during inference and not export
1262
+ if framework == "pt" :
1263
+ import torch
1264
+ dummy_inputs ["input_ids" ] = dummy_inputs ["input_ids" ].to (dtype = torch .int32 )
1265
+ return dummy_inputs
1266
+
1259
1267
1260
1268
@register_in_tasks_manager ("clip-text-with-projection" , * ["feature-extraction" ], library_name = "transformers" )
1261
1269
@register_in_tasks_manager ("clip-text-with-projection" , * ["feature-extraction" ], library_name = "diffusers" )
@@ -1795,9 +1803,17 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1795
1803
)
1796
1804
1797
1805
1806
+ class DummyUnetTimestepInputGenerator (DummyTimestepInputGenerator ):
1807
+ def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
1808
+ if input_name != "timestep" :
1809
+ return super ().generate (input_name , framework , int_dtype , float_dtype )
1810
+ shape = [self .batch_size ]
1811
+ return self .random_int_tensor (shape , max_value = self .vocab_size , framework = framework , dtype = int_dtype )
1812
+
1813
+
1798
1814
@register_in_tasks_manager ("unet" , * ["semantic-segmentation" ], library_name = "diffusers" )
1799
1815
class UnetOpenVINOConfig (UNetOnnxConfig ):
1800
- DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator ,) + UNetOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES [1 :]
1816
+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator , DummyUnetTimestepInputGenerator ) + UNetOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES [2 :]
1801
1817
1802
1818
1803
1819
@register_in_tasks_manager ("sd3-transformer" , * ["semantic-segmentation" ], library_name = "diffusers" )
0 commit comments