diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index b3dc7e053c..cc60ec8310 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -98,6 +98,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - UniSpeech SAT
 - Vision Encoder Decoder
 - Vit
+- VitMatte
 - Wav2Vec2
 - Wav2Vec2 Conformer
 - WavLM
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 8cd94194ff..19ab7fab9b 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -151,6 +151,7 @@ class OnnxConfig(ExportConfig, ABC):
         "feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
         "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "image-classification": OrderedDict({"logits": {0: "batch_size"}}),
+        "image-matting": OrderedDict({"alphas": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}),
         "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
         "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "image-to-image": OrderedDict(
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index cd9d54eeca..5b42d4dd2c 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -756,6 +756,40 @@ class Swin2srOnnxConfig(SwinOnnxConfig):
 class Swin2srOnnxConfig(SwinOnnxConfig):
     pass
 
+
+class VitMatteDummyInputGenerator(DummyVisionInputGenerator):
+    """Dummy `pixel_values` generator for VitMatte ONNX export."""
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = DEFAULT_DUMMY_SHAPES["width"],
+        height: int = DEFAULT_DUMMY_SHAPES["height"],
+        **kwargs,
+    ):
+        # VitMatte inputs stack the image with its trimap, so the channel count
+        # must come from the backbone config rather than the generic default.
+        # `num_channels` stays in the signature for interface compatibility
+        # but is deliberately overridden below.
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            num_channels=normalized_config.backbone_config.num_channels,
+            width=width,
+            height=height,
+            **kwargs,
+        )
+
+
+class VitMatteOnnxConfig(ViTOnnxConfig):
+    """ONNX export config for VitMatte image-matting models."""
+
+    DEFAULT_ONNX_OPSET = 12
+    DUMMY_INPUT_GENERATOR_CLASSES = (VitMatteDummyInputGenerator,)
+
+
 class DptOnnxConfig(ViTOnnxConfig):
     pass
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 7ccb0d9c7b..1e35defb5c 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -168,6 +168,7 @@ class TasksManager:
         "feature-extraction": "AutoModel",
         "fill-mask": "AutoModelForMaskedLM",
         "image-classification": "AutoModelForImageClassification",
+        "image-matting": "VitMatteForImageMatting",
         "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
         "image-to-image": "AutoModelForImageToImage",
         "image-to-text": "AutoModelForVision2Seq",
@@ -1031,6 +1032,11 @@ class TasksManager:
         "vit": supported_tasks_mapping(
             "feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
         ),
+        "vitmatte": supported_tasks_mapping(
+            "feature-extraction",
+            "image-matting",
+            onnx="VitMatteOnnxConfig",
+        ),
         "wavlm": supported_tasks_mapping(
             "feature-extraction",
             "automatic-speech-recognition",