@@ -24,6 +24,7 @@
 from ...utils import (
     DEFAULT_DUMMY_SHAPES,
     BloomDummyPastKeyValuesGenerator,
+    Dinov2DummyInputGenerator,
     DummyAudioInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
     DummyDecisionTransformerInputGenerator,
@@ -63,6 +64,8 @@
     NormalizedTextConfigWithGQA,
     NormalizedTimeSeriesForecastingConfig,
     NormalizedVisionConfig,
+    PerceiverDummyInputGenerator,
+    VitPoseDummyInputGenerator,
     is_diffusers_available,
     is_diffusers_version,
     is_transformers_version,
@@ -93,6 +96,7 @@
     SentenceTransformersTransformerPatcher,
     SpeechT5ModelPatcher,
     VisionEncoderDecoderPatcher,
+    VitPoseModelPatcher,
     WavLMModelPatcher,
 )

@@ -847,6 +851,22 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         return common_outputs


+class VitPoseOnnxConfig(ViTOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (VitPoseDummyInputGenerator,)
+    ATOL_FOR_VALIDATION = 1e-4
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {"pixel_values": {0: "batch_size"}}
+
+    # Some VitPose models use multiple experts, which requires dataset_index to be provided.
+    # So, we need to patch the model for export to provide the dataset_index.
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return VitPoseModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class CvTOnnxConfig(ViTOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     ATOL_FOR_VALIDATION = 1e-2
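Reviewer note: `VitPoseModelPatcher` itself lives in the patcher module and is not part of this hunk. As a hedged sketch of the idea the comment above describes (the function name and the expert-0 default below are my assumptions, not the actual implementation): mixture-of-experts VitPose checkpoints take a `dataset_index` tensor selecting an expert per sample, which is not among the exported ONNX inputs, so the forward pass is wrapped to supply a default during tracing.

```python
# Hypothetical sketch only -- the real VitPoseModelPatcher may differ.
# Wrap forward() so MoE checkpoints receive a default dataset_index
# during ONNX tracing, since it is not one of the exported inputs.
import torch


def patch_vitpose_forward(model):
    orig_forward = model.forward

    def patched_forward(pixel_values, **kwargs):
        if kwargs.get("dataset_index") is None:
            # Assumption: expert 0 is an acceptable default for every
            # sample in the batch when no dataset_index is given.
            kwargs["dataset_index"] = torch.zeros(
                pixel_values.shape[0], dtype=torch.int64
            )
        return orig_forward(pixel_values, **kwargs)

    model.forward = patched_forward
    return model
```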
@@ -892,41 +912,6 @@ class VitMSNOnnxConfig(ViTOnnxConfig):
     DEFAULT_ONNX_OPSET = 14


-class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
-    def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedVisionConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
-        width: int = DEFAULT_DUMMY_SHAPES["width"],
-        height: int = DEFAULT_DUMMY_SHAPES["height"],
-        **kwargs,
-    ):
-        super().__init__(
-            task=task,
-            normalized_config=normalized_config,
-            batch_size=batch_size,
-            num_channels=num_channels,
-            width=width,
-            height=height,
-            **kwargs,
-        )
-
-        from transformers.onnx.utils import get_preprocessor
-
-        preprocessor = get_preprocessor(normalized_config._name_or_path)
-        if preprocessor is not None and hasattr(preprocessor, "crop_size"):
-            self.height = preprocessor.crop_size.get("height", self.height)
-            self.width = preprocessor.crop_size.get("width", self.width)
-
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        input_ = super().generate(
-            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
-        )
-        return input_
-
-
 class Dinov2OnnxConfig(ViTOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,)

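Reviewer note: the `Dinov2DummyInputGenerator` removed here (and the `PerceiverDummyInputGenerator` removed below) are not deleted outright; the import hunks at the top pull both from `...utils`, alongside the new `VitPoseDummyInputGenerator`. Assuming the moved code keeps the same shape, the pattern the generators share is letting the checkpoint's preprocessor size override the dummy-input defaults, roughly:

```python
# Hypothetical condensation of the shared pattern (the moved
# implementations in ...utils may differ): read the model's
# preprocessor and let its configured size override the defaults.
from transformers.onnx.utils import get_preprocessor


def preprocessor_size(name_or_path, default_height, default_width, key="size"):
    # key is "crop_size" for DINOv2, "size" for Perceiver.
    preprocessor = get_preprocessor(name_or_path)
    if preprocessor is not None and hasattr(preprocessor, key):
        size = getattr(preprocessor, key)  # e.g. {"height": 224, "width": 224}
        return size.get("height", default_height), size.get("width", default_width)
    return default_height, default_width
```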
@@ -1606,41 +1591,6 @@ class Data2VecAudioOnnxConfig(AudioOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedConfig


-class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
-    def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedVisionConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
-        width: int = DEFAULT_DUMMY_SHAPES["width"],
-        height: int = DEFAULT_DUMMY_SHAPES["height"],
-        **kwargs,
-    ):
-        super().__init__(
-            task=task,
-            normalized_config=normalized_config,
-            batch_size=batch_size,
-            num_channels=num_channels,
-            width=width,
-            height=height,
-            **kwargs,
-        )
-
-        from transformers.onnx.utils import get_preprocessor
-
-        preprocessor = get_preprocessor(normalized_config._name_or_path)
-        if preprocessor is not None and hasattr(preprocessor, "size"):
-            self.height = preprocessor.size.get("height", self.height)
-            self.width = preprocessor.size.get("width", self.width)
-
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        input_ = super().generate(
-            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
-        )
-        return input_
-
-
 class PerceiverOnnxConfig(TextAndVisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DUMMY_INPUT_GENERATOR_CLASSES = (
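Reviewer note: a quick way to exercise the new config end to end (untested sketch; the checkpoint name is illustrative, and `main_export` resolves the task automatically):

```python
# Illustrative usage, not part of this diff: exporting a VitPose
# checkpoint should now dispatch to VitPoseOnnxConfig and validate the
# ONNX outputs against PyTorch within ATOL_FOR_VALIDATION.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="usyd-community/vitpose-base-simple",  # example checkpoint
    output="vitpose_onnx/",
)
```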
|