Skip to content

Commit a6c696c

Browse files
Generate guidance for flux (#2104)
generate guidance
1 parent 65a8a94 commit a6c696c

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

optimum/onnxruntime/modeling_diffusion.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,21 @@ def to(self, device: Union[torch.device, str, int]):
437437
def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs):
438438
return cls.load_config(config_name_or_path, **kwargs)
439439

440-
def _save_config(self, save_directory):
441-
self.save_config(save_directory)
440+
def _save_config(self, save_directory: Union[str, Path]):
441+
model_dir = (
442+
self.model_save_dir
443+
if not isinstance(self.model_save_dir, TemporaryDirectory)
444+
else self.model_save_dir.name
445+
)
446+
save_dir = Path(save_directory)
447+
original_config = Path(model_dir) / self.config_name
448+
if original_config.exists():
449+
if not save_dir.exists():
450+
save_dir.mkdir(parents=True)
451+
452+
shutil.copy(original_config, save_dir)
453+
else:
454+
self.save_config(save_directory)
442455

443456
@property
444457
def components(self) -> Dict[str, Any]:

optimum/utils/input_generators.py

+4
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,7 @@ class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator)
15081508
SUPPORTED_INPUT_NAMES = (
15091509
"encoder_hidden_states",
15101510
"pooled_projections",
1511+
"guidance",
15111512
"txt_ids",
15121513
)
15131514

@@ -1519,5 +1520,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
15191520
else [self.batch_size, self.sequence_length, 3]
15201521
)
15211522
return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype)
1523+
elif input_name == "guidance":
1524+
shape = [self.batch_size]
1525+
return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype)
15221526

15231527
return super().generate(input_name, framework, int_dtype, float_dtype)

0 commit comments

Comments
 (0)