Skip to content

Commit bd08f12

Browse files
authored
Unbundle inputs generated by DummyTimestepInputGenerator (#2107)
unbundle
1 parent a6c696c commit bd08f12

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

optimum/utils/input_generators.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -897,23 +897,31 @@ def __init__(
897897
):
898898
self.task = task
899899
self.vocab_size = normalized_config.vocab_size
900-
self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
901-
self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
900+
self.text_encoder_projection_dim = getattr(normalized_config, "text_encoder_projection_dim", None)
901+
self.time_ids = 5 if getattr(normalized_config, "requires_aesthetics_score", False) else 6
902902
if random_batch_size_range:
903903
low, high = random_batch_size_range
904904
self.batch_size = random.randint(low, high)
905905
else:
906906
self.batch_size = batch_size
907-
self.time_cond_proj_dim = normalized_config.config.time_cond_proj_dim
907+
self.time_cond_proj_dim = getattr(normalized_config.config, "time_cond_proj_dim", None)
908908

909909
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
910910
if input_name == "timestep":
911911
shape = [] # a scalar with no dimension (it can be int or float depending on the sd architecture)
912912
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
913913

914914
if input_name == "text_embeds":
915+
if self.text_encoder_projection_dim is None:
916+
raise ValueError(
917+
"Unable to infer the value of `text_encoder_projection_dim` for generating `text_embeds`, please double check the config of your model."
918+
)
915919
dim = self.text_encoder_projection_dim
916920
elif input_name == "timestep_cond":
921+
if self.time_cond_proj_dim is None:
922+
raise ValueError(
923+
"Unable to infer the value of `time_cond_proj_dim` for generating `timestep_cond`, please double check the config of your model."
924+
)
917925
dim = self.time_cond_proj_dim
918926
else:
919927
dim = self.time_ids

0 commit comments

Comments
 (0)