@@ -1678,43 +1678,17 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1678
1678
shape = [self .batch_size , (self .height // 2 ) * (self .width // 2 ), self .num_channels * 4 ]
1679
1679
return self .random_float_tensor (shape , framework = framework , dtype = float_dtype )
1680
1680
if input_name == "img_ids" :
1681
- return self .prepare_image_ids (framework , int_dtype , float_dtype )
1682
-
1683
- return super ().generate (input_name , framework , int_dtype , float_dtype )
1684
-
1685
- def prepare_image_ids (self , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
1686
- img_ids_height = self .height // 2
1687
- img_ids_width = self .width // 2
1688
- if framework == "pt" :
1689
- import torch
1690
-
1691
- latent_image_ids = torch .zeros (img_ids_height , img_ids_width , 3 )
1692
- latent_image_ids [..., 1 ] = latent_image_ids [..., 1 ] + torch .arange (img_ids_height )[:, None ]
1693
- latent_image_ids [..., 2 ] = latent_image_ids [..., 2 ] + torch .arange (img_ids_width )[None , :]
1694
-
1695
- latent_image_id_height , latent_image_id_width , latent_image_id_channels = latent_image_ids .shape
1696
-
1697
- latent_image_ids = latent_image_ids [None , :].repeat (self .batch_size , 1 , 1 , 1 )
1698
- latent_image_ids = latent_image_ids .reshape (
1699
- self .batch_size , latent_image_id_height * latent_image_id_width , latent_image_id_channels
1681
+ img_ids_height = self .height // 2
1682
+ img_ids_width = self .width // 2
1683
+ return self .random_int_tensor (
1684
+ [self .batch_size , img_ids_height * img_ids_width , 3 ],
1685
+ min_value = 0 ,
1686
+ max_value = min (img_ids_height , img_ids_width ),
1687
+ framework = framework ,
1688
+ dtype = float_dtype ,
1700
1689
)
1701
- latent_image_ids .to (DTYPE_MAPPER .pt (float_dtype ))
1702
- return latent_image_ids
1703
- if framework == "np" :
1704
- import numpy as np
1705
1690
1706
- latent_image_ids = np .zeros (img_ids_height , img_ids_width , 3 )
1707
- latent_image_ids [..., 1 ] = latent_image_ids [..., 1 ] + np .arange (img_ids_height )[:, None ]
1708
- latent_image_ids [..., 2 ] = latent_image_ids [..., 2 ] + np .arange (img_ids_width )[None , :]
1709
-
1710
- latent_image_id_height , latent_image_id_width , latent_image_id_channels = latent_image_ids .shape
1711
-
1712
- latent_image_ids = np .tile (latent_image_ids [None , :], (self .batch_size , 1 , 1 , 1 ))
1713
- latent_image_ids = latent_image_ids .reshape (
1714
- self .batch_size , latent_image_id_height * latent_image_id_width , latent_image_id_channels
1715
- )
1716
- latent_image_ids .astype (DTYPE_MAPPER .np [float_dtype ])
1717
- return latent_image_ids
1691
+ return super ().generate (input_name , framework , int_dtype , float_dtype )
1718
1692
1719
1693
1720
1694
class DummyFluxTextInputGenerator (DummySeq2SeqDecoderTextInputGenerator ):
@@ -1728,7 +1702,11 @@ class DummyFluxTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
1728
1702
1729
1703
def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
1730
1704
if input_name == "txt_ids" :
1731
- return self .constant_tensor ([self .batch_size , self .sequence_length , 3 ], 0 , DTYPE_MAPPER .pt (float_dtype ))
1705
+ import torch
1706
+
1707
+ shape = [self .batch_size , self .sequence_length , 3 ]
1708
+ dtype = DTYPE_MAPPER .pt (float_dtype )
1709
+ return torch .full (shape , 0 , dtype = dtype )
1732
1710
return super ().generate (input_name , framework , int_dtype , float_dtype )
1733
1711
1734
1712
0 commit comments