|
19 | 19 | from transformers.utils import is_tf_available
|
20 | 20 |
|
21 | 21 | from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
|
22 |
| -from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig, PhiOnnxConfig |
| 22 | +from optimum.exporters.onnx.model_configs import ( |
| 23 | + FalconOnnxConfig, |
| 24 | + GemmaOnnxConfig, |
| 25 | + LlamaOnnxConfig, |
| 26 | + PhiOnnxConfig, |
| 27 | + UNetOnnxConfig, |
| 28 | + VaeDecoderOnnxConfig, |
| 29 | + VaeEncoderOnnxConfig, |
| 30 | +) |
23 | 31 | from optimum.exporters.tasks import TasksManager
|
24 | 32 | from optimum.utils import DEFAULT_DUMMY_SHAPES
|
25 | 33 | from optimum.utils.input_generators import (
|
@@ -510,3 +518,59 @@ class FalconOpenVINOConfig(FalconOnnxConfig):
|
510 | 518 | OVFalconDummyPastKeyValuesGenerator,
|
511 | 519 | ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
|
512 | 520 | DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
|
| 521 | + |
| 522 | + |
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetOpenVINOConfig(UNetOnnxConfig):
    """Export configuration registered for the diffusers ``unet`` model type.

    Overrides the ``inputs`` and ``outputs`` properties inherited from
    ``UNetOnnxConfig``; each property maps a tensor name to a
    ``{axis index: axis name}`` dictionary.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of every model input.

        ``sample``, ``timestep`` and ``encoder_hidden_states`` are always
        present. Further entries are appended depending on attributes read
        from ``self._normalized_config``.
        """
        batch_only = "batch_size"

        axes_by_input = {}
        axes_by_input["sample"] = {0: batch_only, 2: "height", 3: "width"}
        axes_by_input["timestep"] = {0: "steps"}
        axes_by_input["encoder_hidden_states"] = {0: batch_only, 1: "sequence_length"}

        # TODO : add text_image, image and image_embeds
        embed_type = getattr(self._normalized_config, "addition_embed_type", None)
        if embed_type == "text_time":
            for extra_input in ("text_embeds", "time_ids"):
                axes_by_input[extra_input] = {0: batch_only}

        cond_proj_dim = getattr(self._normalized_config, "time_cond_proj_dim", None)
        if cond_proj_dim is not None:
            axes_by_input["timestep_cond"] = {0: batch_only}

        return axes_by_input

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of the single output, ``out_sample``."""
        out_sample_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"out_sample": out_sample_axes}
| 547 | + |
| 548 | + |
@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig):
    """Export configuration registered for the diffusers ``vae-encoder`` model type.

    Each property maps a tensor name to a ``{axis index: axis name}``
    dictionary.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of the single input, ``sample``."""
        sample_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"sample": sample_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of the single output, ``latent_sample``."""
        latent_axes = {0: "batch_size", 2: "height_latent", 3: "width_latent"}
        return {"latent_sample": latent_axes}
| 562 | + |
| 563 | + |
@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"], library_name="diffusers")
class VaeDecoderOpenVINOConfig(VaeDecoderOnnxConfig):
    """Export configuration registered for the diffusers ``vae-decoder`` model type.

    Each property maps a tensor name to a ``{axis index: axis name}``
    dictionary.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of the single input, ``latent_sample``."""
        latent_axes = {0: "batch_size", 2: "height_latent", 3: "width_latent"}
        return {"latent_sample": latent_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of the single output, ``sample``."""
        sample_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"sample": sample_axes}
0 commit comments