@@ -897,23 +897,31 @@ def __init__(
897
897
):
898
898
self .task = task
899
899
self .vocab_size = normalized_config .vocab_size
900
- self .text_encoder_projection_dim = normalized_config . text_encoder_projection_dim
901
- self .time_ids = 5 if normalized_config . requires_aesthetics_score else 6
900
+ self .text_encoder_projection_dim = getattr ( normalized_config , " text_encoder_projection_dim" , None )
901
+ self .time_ids = 5 if getattr ( normalized_config , " requires_aesthetics_score" , False ) else 6
902
902
if random_batch_size_range :
903
903
low , high = random_batch_size_range
904
904
self .batch_size = random .randint (low , high )
905
905
else :
906
906
self .batch_size = batch_size
907
- self .time_cond_proj_dim = normalized_config .config . time_cond_proj_dim
907
+ self .time_cond_proj_dim = getattr ( normalized_config .config , " time_cond_proj_dim" , None )
908
908
909
909
def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
910
910
if input_name == "timestep" :
911
911
shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture)
912
912
return self .random_float_tensor (shape , max_value = self .vocab_size , framework = framework , dtype = float_dtype )
913
913
914
914
if input_name == "text_embeds" :
915
+ if self .text_encoder_projection_dim is None :
916
+ raise ValueError (
917
+ "Unable to infer the value of `text_encoder_projection_dim` for generating `text_embeds`, please double check the config of your model."
918
+ )
915
919
dim = self .text_encoder_projection_dim
916
920
elif input_name == "timestep_cond" :
921
+ if self .time_cond_proj_dim is None :
922
+ raise ValueError (
923
+ "Unable to infer the value of `time_cond_proj_dim` for generating `timestep_cond`, please double check the config of your model."
924
+ )
917
925
dim = self .time_cond_proj_dim
918
926
else :
919
927
dim = self .time_ids
0 commit comments