@@ -1864,18 +1864,49 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1864
1864
return self .random_int_tensor (shape , max_value = self .vocab_size , framework = framework , dtype = int_dtype )
1865
1865
1866
1866
1867
+ class DummyUnetEncoderInputGenerator (DummySeq2SeqDecoderTextInputGenerator ):
1868
+ def __init__ (
1869
+ self ,
1870
+ task : str ,
1871
+ normalized_config : NormalizedTextConfig ,
1872
+ batch_size : int = DEFAULT_DUMMY_SHAPES ["batch_size" ],
1873
+ sequence_length : int = DEFAULT_DUMMY_SHAPES ["sequence_length" ],
1874
+ num_choices : int = DEFAULT_DUMMY_SHAPES ["num_choices" ],
1875
+ random_batch_size_range : Optional [Tuple [int , int ]] = None ,
1876
+ random_sequence_length_range : Optional [Tuple [int , int ]] = None ,
1877
+ random_num_choices_range : Optional [Tuple [int , int ]] = None ,
1878
+ ** kwargs ,
1879
+ ):
1880
+ super ().__init__ (
1881
+ task ,
1882
+ normalized_config ,
1883
+ batch_size = batch_size ,
1884
+ sequence_length = sequence_length ,
1885
+ num_choices = num_choices ,
1886
+ random_batch_size_range = random_batch_size_range ,
1887
+ random_sequence_length_range = random_sequence_length_range ,
1888
+ random_num_choices_range = random_num_choices_range ,
1889
+ ** kwargs ,
1890
+ )
1891
+ if hasattr (normalized_config .config , "model_max_length" ):
1892
+ self .sequence_length = normalized_config .config .model_max_length
1893
+
1894
+
1867
1895
@register_in_tasks_manager ("unet" , * ["semantic-segmentation" ], library_name = "diffusers" )
1868
1896
@register_in_tasks_manager ("unet-2d-condition" , * ["semantic-segmentation" ], library_name = "diffusers" )
1869
1897
class UNetOpenVINOConfig (UNetOnnxConfig ):
1870
1898
DUMMY_INPUT_GENERATOR_CLASSES = (
1871
1899
DummyUnetVisionInputGenerator ,
1872
1900
DummyUnetTimestepInputGenerator ,
1873
- ) + UNetOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES [2 :]
1901
+ DummyUnetEncoderInputGenerator ,
1902
+ )
1874
1903
1875
1904
@property
1876
1905
def inputs (self ) -> Dict [str , Dict [int , str ]]:
1877
1906
common_inputs = super ().inputs
1878
1907
common_inputs ["timestep" ] = {0 : "batch_size" }
1908
+ if hasattr (self ._normalized_config .config , "model_max_length" ):
1909
+ common_inputs ["encoder_hidden_states" ] = {0 : "batch_size" }
1879
1910
return common_inputs
1880
1911
1881
1912
0 commit comments