Skip to content

Commit 8361e45

Browse files
authored
Avoid extra reshaping to max_model_length for unet (#1164)
1 parent 82bb652 commit 8361e45

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

optimum/exporters/openvino/convert.py

+3
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,9 @@ def get_diffusion_models_for_export_ext(
10251025
is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")
10261026

10271027
if is_sd or is_sdxl or is_lcm:
1028+
tokenizer = pipeline.tokenizer_2 if is_sdxl else pipeline.tokenizer
1029+
model_max_length = getattr(tokenizer, "model_max_length", None)
1030+
pipeline.unet.config.model_max_length = model_max_length
10281031
models_for_export = get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
10291032
if is_sdxl and pipeline.vae.config.force_upcast:
10301033
models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}

optimum/exporters/openvino/model_configs.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1864,18 +1864,49 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
18641864
return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
18651865

18661866

1867+
class DummyUnetEncoderInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
    """Dummy-input generator for the UNet `encoder_hidden_states` input.

    Extends ``DummySeq2SeqDecoderTextInputGenerator`` so that, when the
    normalized config carries a tokenizer-derived ``model_max_length``, the
    dummy sequence length is pinned to that value instead of the generic
    default — avoiding an extra reshape of the UNet to an arbitrary length.

    Args:
        task: Export task name, forwarded to the base generator.
        normalized_config: Normalized model config; ``model_max_length`` is
            read from ``normalized_config.config`` when present.
        batch_size / sequence_length / num_choices: Default dummy shapes,
            forwarded to the base generator.
        random_*_range: Optional randomization ranges, forwarded unchanged.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        random_num_choices_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(
            task,
            normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_choices=num_choices,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
            random_num_choices_range=random_num_choices_range,
            **kwargs,
        )
        # convert.py attaches model_max_length via getattr(tokenizer,
        # "model_max_length", None), so the attribute may exist but be None
        # (tokenizer without a model_max_length). The previous hasattr-only
        # check would then set sequence_length to None and break dummy
        # tensor generation downstream — only override for a real value.
        model_max_length = getattr(normalized_config.config, "model_max_length", None)
        if model_max_length is not None:
            self.sequence_length = model_max_length
18671895
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
18681896
@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
18691897
class UNetOpenVINOConfig(UNetOnnxConfig):
18701898
DUMMY_INPUT_GENERATOR_CLASSES = (
18711899
DummyUnetVisionInputGenerator,
18721900
DummyUnetTimestepInputGenerator,
1873-
) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]
1901+
DummyUnetEncoderInputGenerator,
1902+
)
18741903

18751904
@property
18761905
def inputs(self) -> Dict[str, Dict[int, str]]:
18771906
common_inputs = super().inputs
18781907
common_inputs["timestep"] = {0: "batch_size"}
1908+
if hasattr(self._normalized_config.config, "model_max_length"):
1909+
common_inputs["encoder_hidden_states"] = {0: "batch_size"}
18791910
return common_inputs
18801911

18811912

0 commit comments

Comments (0)