
Commit 0245619

ariG23498 and xenova authored
Add ViTPose ONNX export (#2183)
* Add ONNX export support for ViTPose
* Build dummy inputs for ViTPose
* Move the VitPose config to a custom class
* Move input generators
* Patch VitPose models with num_experts > 1
* Formatting
* Add ViTPose export unit tests

Co-authored-by: Joshua Lochner <admin@xenova.com>
1 parent 414afab commit 0245619

File tree

7 files changed: +123 −70 lines changed

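Taken together, these changes register ViTPose under a new keypoint-detection export task. A minimal sketch of the resulting export flow, assuming an optimum build that includes this commit (the output directory name is illustrative):

from optimum.exporters.onnx import main_export

# Export a ViTPose checkpoint to ONNX via the newly registered task.
main_export(
    "usyd-community/vitpose-plus-small",
    output="vitpose_onnx",
    task="keypoint-detection",
)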

optimum/exporters/onnx/base.py

+3 −0

@@ -159,6 +159,9 @@ class OnnxConfig(ExportConfig, ABC):
         "image-to-image": OrderedDict(
             {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
         ),
+        "keypoint-detection": OrderedDict(
+            {"heatmaps": {0: "batch_size", 1: "num_keypoints", 2: "height", 3: "width"}}
+        ),
         "mask-generation": OrderedDict({"logits": {0: "batch_size"}}),
         "masked-im": OrderedDict(
             {"reconstruction" if is_transformers_version(">=", "4.29.0") else "logits": {0: "batch_size"}}

optimum/exporters/onnx/model_configs.py

+20 −70

@@ -24,6 +24,7 @@
 from ...utils import (
     DEFAULT_DUMMY_SHAPES,
     BloomDummyPastKeyValuesGenerator,
+    Dinov2DummyInputGenerator,
     DummyAudioInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
     DummyDecisionTransformerInputGenerator,
@@ -63,6 +64,8 @@
     NormalizedTextConfigWithGQA,
     NormalizedTimeSeriesForecastingConfig,
     NormalizedVisionConfig,
+    PerceiverDummyInputGenerator,
+    VitPoseDummyInputGenerator,
     is_diffusers_available,
     is_diffusers_version,
     is_transformers_version,
@@ -93,6 +96,7 @@
     SentenceTransformersTransformerPatcher,
     SpeechT5ModelPatcher,
     VisionEncoderDecoderPatcher,
+    VitPoseModelPatcher,
     WavLMModelPatcher,
 )

@@ -847,6 +851,22 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         return common_outputs


+class VitPoseOnnxConfig(ViTOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (VitPoseDummyInputGenerator,)
+    ATOL_FOR_VALIDATION = 1e-4
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {"pixel_values": {0: "batch_size"}}
+
+    # Some VitPose models use multiple experts, which requires dataset_index to be provided.
+    # So, we need to patch the model for export to provide the dataset_index.
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return VitPoseModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class CvTOnnxConfig(ViTOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     ATOL_FOR_VALIDATION = 1e-2
@@ -892,41 +912,6 @@ class VitMSNOnnxConfig(ViTOnnxConfig):
     DEFAULT_ONNX_OPSET = 14


-class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
-    def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedVisionConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
-        width: int = DEFAULT_DUMMY_SHAPES["width"],
-        height: int = DEFAULT_DUMMY_SHAPES["height"],
-        **kwargs,
-    ):
-        super().__init__(
-            task=task,
-            normalized_config=normalized_config,
-            batch_size=batch_size,
-            num_channels=num_channels,
-            width=width,
-            height=height,
-            **kwargs,
-        )
-
-        from transformers.onnx.utils import get_preprocessor
-
-        preprocessor = get_preprocessor(normalized_config._name_or_path)
-        if preprocessor is not None and hasattr(preprocessor, "crop_size"):
-            self.height = preprocessor.crop_size.get("height", self.height)
-            self.width = preprocessor.crop_size.get("width", self.width)
-
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        input_ = super().generate(
-            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
-        )
-        return input_
-
-
 class Dinov2OnnxConfig(ViTOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,)

@@ -1606,41 +1591,6 @@ class Data2VecAudioOnnxConfig(AudioOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedConfig


-class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
-    def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedVisionConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
-        width: int = DEFAULT_DUMMY_SHAPES["width"],
-        height: int = DEFAULT_DUMMY_SHAPES["height"],
-        **kwargs,
-    ):
-        super().__init__(
-            task=task,
-            normalized_config=normalized_config,
-            batch_size=batch_size,
-            num_channels=num_channels,
-            width=width,
-            height=height,
-            **kwargs,
-        )
-
-        from transformers.onnx.utils import get_preprocessor
-
-        preprocessor = get_preprocessor(normalized_config._name_or_path)
-        if preprocessor is not None and hasattr(preprocessor, "size"):
-            self.height = preprocessor.size.get("height", self.height)
-            self.width = preprocessor.size.get("width", self.width)
-
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        input_ = super().generate(
-            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
-        )
-        return input_
-
-
 class PerceiverOnnxConfig(TextAndVisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DUMMY_INPUT_GENERATOR_CLASSES = (
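A quick way to inspect what VitPoseOnnxConfig exposes, following the OnnxConfig(config, task=...) construction pattern used throughout this file (a sketch, not part of the commit):

from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import VitPoseOnnxConfig

config = AutoConfig.from_pretrained("usyd-community/vitpose-plus-small")
onnx_config = VitPoseOnnxConfig(config, task="keypoint-detection")
print(onnx_config.inputs)   # {"pixel_values": {0: "batch_size"}}
print(onnx_config.outputs)  # the "keypoint-detection" mapping added in base.py above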

optimum/exporters/onnx/model_patcher.py

+15 −0

@@ -1338,3 +1338,18 @@ def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if is_transformers_version(">=", "4.43"):
             CLIPSdpaAttention.forward = self.original_sdpa_forward
+
+
+class VitPoseModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Set dataset_index (defaulting to COCO=0), otherwise we will get an error like:
+        # ValueError: dataset_index must be provided when using multiple experts (num_experts=6). Please provide dataset_index to the forward pass.
+        if model.config.backbone_config.num_experts > 1:
+            model_kwargs["dataset_index"] = torch.tensor(0, device=model.device)
+
+        super().__init__(config, model, model_kwargs)
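The patcher bakes a default dataset_index into the traced forward pass so tracing succeeds on mixture-of-experts checkpoints. In eager mode the same argument must be passed explicitly; a hedged sketch (the input resolution and per-sample index shape are assumptions):

import torch
from transformers import VitPoseForPoseEstimation

model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-small")
pixel_values = torch.randn(1, 3, 256, 192)
# Without dataset_index, checkpoints with num_experts > 1 raise the ValueError quoted above;
# index 0 selects the COCO expert, mirroring the patcher's default.
outputs = model(pixel_values=pixel_values, dataset_index=torch.tensor([0]))
print(outputs.heatmaps.shape)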

optimum/exporters/tasks.py

+2 −0

@@ -329,6 +329,7 @@ class TasksManager:
         ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
         # VisionEncoderDecoderModel is not registered in AutoModelForDocumentQuestionAnswering
         ("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"),
+        ("pt", "vitpose", "keypoint-detection"): ("transformers", "VitPoseForPoseEstimation"),
     }

     _ENCODER_DECODER_TASKS = (
@@ -1241,6 +1242,7 @@ class TasksManager:
             "image-classification",
             onnx="VitMSNOnnxConfig",
         ),
+        "vitpose": supported_tasks_mapping("feature-extraction", "keypoint-detection", onnx="VitPoseOnnxConfig"),
         "vits": supported_tasks_mapping(
             "text-to-audio",
             onnx="VitsOnnxConfig",

optimum/utils/__init__.py

+3 −0

@@ -56,6 +56,7 @@
     DEFAULT_DUMMY_SHAPES,
     DTYPE_MAPPER,
     BloomDummyPastKeyValuesGenerator,
+    Dinov2DummyInputGenerator,
     DummyAudioInputGenerator,
     DummyBboxInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
@@ -90,6 +91,8 @@
     MCTCTDummyAudioInputGenerator,
     MistralDummyPastKeyValuesGenerator,
     MultiQueryPastKeyValuesGenerator,
+    PerceiverDummyInputGenerator,
+    VitPoseDummyInputGenerator,
 )
 from .modeling_utils import recurse_getattr, recurse_setattr
 from .normalized_config import (
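These re-exports make the relocated generators importable from the package root, e.g.:

from optimum.utils import (
    Dinov2DummyInputGenerator,
    PerceiverDummyInputGenerator,
    VitPoseDummyInputGenerator,
)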

optimum/utils/input_generators.py

+78 −0

@@ -1592,3 +1592,81 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
             return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework, dtype=float_dtype)

         return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype)
+
+
+class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = DEFAULT_DUMMY_SHAPES["width"],
+        height: int = DEFAULT_DUMMY_SHAPES["height"],
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            num_channels=num_channels,
+            width=width,
+            height=height,
+            **kwargs,
+        )
+
+        from transformers.onnx.utils import get_preprocessor
+
+        preprocessor = get_preprocessor(normalized_config._name_or_path)
+        if preprocessor is not None and hasattr(preprocessor, "crop_size"):
+            self.height = preprocessor.crop_size.get("height", self.height)
+            self.width = preprocessor.crop_size.get("width", self.width)
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        input_ = super().generate(
+            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
+        )
+        return input_
+
+
+class DummyVisionStaticInputGenerator(DummyVisionInputGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = DEFAULT_DUMMY_SHAPES["width"],
+        height: int = DEFAULT_DUMMY_SHAPES["height"],
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            num_channels=num_channels,
+            width=width,
+            height=height,
+            **kwargs,
+        )
+
+        from transformers.onnx.utils import get_preprocessor
+
+        preprocessor = get_preprocessor(normalized_config._name_or_path)
+        if preprocessor is not None and hasattr(preprocessor, "size"):
+            self.height = preprocessor.size.get("height", self.height)
+            self.width = preprocessor.size.get("width", self.width)
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        input_ = super().generate(
+            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
+        )
+        return input_
+
+
+class PerceiverDummyInputGenerator(DummyVisionStaticInputGenerator):
+    pass
+
+
+class VitPoseDummyInputGenerator(DummyVisionStaticInputGenerator):
+    pass
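The new DummyVisionStaticInputGenerator pins the dummy height/width to whatever the checkpoint's image processor declares, instead of the generic DEFAULT_DUMMY_SHAPES. A rough sketch of the effect, assuming the checkpoint ships an image processor with a size dict (the NormalizedVisionConfig wrapping follows the pattern used elsewhere in optimum):

from transformers import AutoConfig
from optimum.utils import NormalizedVisionConfig, VitPoseDummyInputGenerator

config = AutoConfig.from_pretrained("usyd-community/vitpose-plus-small")
generator = VitPoseDummyInputGenerator(
    task="keypoint-detection",
    normalized_config=NormalizedVisionConfig(config),
)
dummy = generator.generate("pixel_values", framework="pt")
print(dummy.shape)  # height/width taken from the preprocessor's size dict when available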

tests/exporters/exporters_utils.py

+2 −0

@@ -174,6 +174,7 @@
     "vit-mae": "hf-internal-testing/tiny-random-ViTMAEModel",
     "vit-msn": "hf-internal-testing/tiny-random-ViTMSNForImageClassification",
     "vits": "echarlaix/tiny-random-vits",
+    "vitpose": "hf-internal-testing/tiny-random-VitPoseForPoseEstimation",
     "yolos": "hf-internal-testing/tiny-random-YolosModel",
     "whisper": "optimum-internal-testing/tiny-random-whisper",
     "hubert": "hf-internal-testing/tiny-random-HubertModel",
@@ -299,6 +300,7 @@
     "vit": "google/vit-base-patch16-224",
     "vit-mae": "facebook/vit-mae-base",
     "vit-msn": "facebook/vit-msn-small",
+    "vitpose": "usyd-community/vitpose-plus-small",
     "yolos": "hustvl/yolos-tiny",
     "whisper": "openai/whisper-tiny.en",
     "hubert": "facebook/hubert-base-ls960",
