Skip to content

Commit 805335c

Browse files
authored
make unet and vae number of channels static (#692)
* make unet and vae number of channels static * resolve issue with pytest 8.2
1 parent d23ab0a commit 805335c

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

optimum/exporters/openvino/model_configs.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@
1919
from transformers.utils import is_tf_available
2020

2121
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
22-
from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig, PhiOnnxConfig
22+
from optimum.exporters.onnx.model_configs import (
23+
FalconOnnxConfig,
24+
GemmaOnnxConfig,
25+
LlamaOnnxConfig,
26+
PhiOnnxConfig,
27+
UNetOnnxConfig,
28+
VaeDecoderOnnxConfig,
29+
VaeEncoderOnnxConfig,
30+
)
2331
from optimum.exporters.tasks import TasksManager
2432
from optimum.utils import DEFAULT_DUMMY_SHAPES
2533
from optimum.utils.input_generators import (
@@ -510,3 +518,59 @@ class FalconOpenVINOConfig(FalconOnnxConfig):
510518
OVFalconDummyPastKeyValuesGenerator,
511519
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
512520
DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
521+
522+
523+
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
524+
class UNetOpenVINOConfig(UNetOnnxConfig):
525+
@property
526+
def inputs(self) -> Dict[str, Dict[int, str]]:
527+
common_inputs = {
528+
"sample": {0: "batch_size", 2: "height", 3: "width"},
529+
"timestep": {0: "steps"},
530+
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
531+
}
532+
533+
# TODO : add text_image, image and image_embeds
534+
if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
535+
common_inputs["text_embeds"] = {0: "batch_size"}
536+
common_inputs["time_ids"] = {0: "batch_size"}
537+
538+
if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
539+
common_inputs["timestep_cond"] = {0: "batch_size"}
540+
return common_inputs
541+
542+
@property
543+
def outputs(self) -> Dict[str, Dict[int, str]]:
544+
return {
545+
"out_sample": {0: "batch_size", 2: "height", 3: "width"},
546+
}
547+
548+
549+
@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
550+
class VaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig):
551+
@property
552+
def inputs(self) -> Dict[str, Dict[int, str]]:
553+
return {
554+
"sample": {0: "batch_size", 2: "height", 3: "width"},
555+
}
556+
557+
@property
558+
def outputs(self) -> Dict[str, Dict[int, str]]:
559+
return {
560+
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
561+
}
562+
563+
564+
@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"], library_name="diffusers")
565+
class VaeDecoderOpenVINOConfig(VaeDecoderOnnxConfig):
566+
@property
567+
def inputs(self) -> Dict[str, Dict[int, str]]:
568+
return {
569+
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
570+
}
571+
572+
@property
573+
def outputs(self) -> Dict[str, Dict[int, str]]:
574+
return {
575+
"sample": {0: "batch_size", 2: "height", 3: "width"},
576+
}

0 commit comments

Comments
 (0)