diff --git a/optimum/exporters/base.py b/optimum/exporters/base.py
index 17e1265e74..63246f9387 100644
--- a/optimum/exporters/base.py
+++ b/optimum/exporters/base.py
@@ -14,8 +14,238 @@
 # limitations under the License.
 """Base exporters config."""
 
-from abc import ABC
+import copy
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from transformers.utils import is_torch_available
+
+from ..utils import (
+    DEFAULT_DUMMY_SHAPES,
+    DummyInputGenerator,
+    logging,
+)
+from ..utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
+from ..utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
+from ..utils.doc import add_dynamic_docstring
+from ..utils.import_utils import is_torch_version, is_transformers_version
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+logger = logging.get_logger(__name__)
+
+
+GENERATE_DUMMY_DOCSTRING = r"""
+        Generates the dummy inputs necessary for tracing the model. If not explicitly specified, default input shapes are used.
+
+        Args:
+            framework (`str`, defaults to `"pt"`):
+                The framework for which to create the dummy inputs.
+            batch_size (`int`, defaults to {batch_size}):
+                The batch size to use in the dummy inputs.
+            sequence_length (`int`, defaults to {sequence_length}):
+                The sequence length to use in the dummy inputs.
+            num_choices (`int`, defaults to {num_choices}):
+                The number of candidate answers provided for the multiple choice task.
+            image_width (`int`, defaults to {width}):
+                The width to use in the dummy inputs for vision tasks.
+            image_height (`int`, defaults to {height}):
+                The height to use in the dummy inputs for vision tasks.
+            num_channels (`int`, defaults to {num_channels}):
+                The number of channels to use in the dummy inputs for vision tasks.
+            feature_size (`int`, defaults to {feature_size}):
+                The number of features to use in the dummy inputs for audio tasks in case it is not raw audio.
+                This is for example the number of STFT bins or MEL bins.
+            nb_max_frames (`int`, defaults to {nb_max_frames}):
+                The number of frames to use in the dummy inputs for audio tasks in case the input is not raw audio.
+            audio_sequence_length (`int`, defaults to {audio_sequence_length}):
+                The number of frames to use in the dummy inputs for audio tasks in case the input is raw audio.
+
+        Returns:
+            `Dict[str, Union[tf.Tensor, torch.Tensor]]`: A dictionary mapping the input names to dummy tensors in the proper framework format.
+"""
 
 
 class ExportConfig(ABC):
-    pass
+    """
+    Base class describing metadata on how to export the model.
+
+    Class attributes:
+
+    - NORMALIZED_CONFIG_CLASS (`Type`) -- A class derived from [`~optimum.utils.NormalizedConfig`] specifying how to
+    normalize the model config.
+    - DUMMY_INPUT_GENERATOR_CLASSES (`Tuple[Type]`) -- A tuple of classes derived from
+    [`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
+    - ATOL_FOR_VALIDATION (`Union[float, Dict[str, float]]`) -- A float or a dictionary mapping task names to float,
+    where the float values represent the absolute tolerance value to use during model conversion validation.
+    - MIN_TORCH_VERSION (`packaging.version.Version`, defaults to [`~optimum.exporters.utils.TORCH_MINIMUM_VERSION`]) -- The
+    minimum torch version supporting the export of the model.
+    - MIN_TRANSFORMERS_VERSION (`packaging.version.Version`, defaults to
+    [`~optimum.exporters.utils.TRANSFORMERS_MINIMUM_VERSION`]) -- The minimum transformers version supporting the
+    export of the model. Not always up-to-date or accurate. This is more for internal use.
+    - PATCHING_SPECS (`Optional[List[PatchingSpec]]`, defaults to `None`) -- Specify which operators / modules should be
+    patched before performing the export, and how. This is useful when some operator is not supported, for instance.
+
+    Args:
+        config (`transformers.PretrainedConfig`):
+            The model configuration.
+        task (`str`):
+            The task the model should be exported for.
+        int_dtype (`str`, defaults to `"int64"`):
+            The data type of integer tensors; one of `"int64"`, `"int32"` or `"int8"`.
+        float_dtype (`str`, defaults to `"fp32"`):
+            The data type of float tensors; one of `"fp32"`, `"fp16"` or `"bf16"`.
+    """
+
+    NORMALIZED_CONFIG_CLASS = None
+    DUMMY_INPUT_GENERATOR_CLASSES = ()
+    ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5
+    MIN_TORCH_VERSION = GLOBAL_MIN_TORCH_VERSION
+    MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
+    _TASK_TO_COMMON_OUTPUTS = {
+        "audio-classification": ["logits"],
+        "audio-frame-classification": ["logits"],
+        "automatic-speech-recognition": ["logits"],
+        "audio-xvector": ["logits"],  # for onnx: ["logits", "embeddings"]
+        "depth-estimation": ["predicted_depth"],
+        "document-question-answering": ["logits"],
+        "feature-extraction": ["last_hidden_state"],  # for neuron: ["last_hidden_state", "pooler_output"]
+        "fill-mask": ["logits"],
+        "image-classification": ["logits"],
+        "image-segmentation": ["logits"],
+        "image-to-text": ["logits"],
+        "image-to-image": ["reconstruction"],
+        "mask-generation": ["logits"],
+        "masked-im": ["reconstruction"],
+        "multiple-choice": ["logits"],
+        "object-detection": ["logits", "pred_boxes"],
+        "question-answering": ["start_logits", "end_logits"],
+        "semantic-segmentation": ["logits"],
+        "text2text-generation": ["logits"],
+        "text-classification": ["logits"],
+        "text-generation": ["logits"],
+        "time-series-forecasting": ["prediction_outputs"],
+        "token-classification": ["logits"],
+        "visual-question-answering": ["logits"],
+        "zero-shot-image-classification": ["logits_per_image", "logits_per_text", "text_embeds", "image_embeds"],
+        "zero-shot-object-detection": ["logits", "pred_boxes", "text_embeds", "image_embeds"],
+    }
+    # TODO: add _MODEL_PATCHER + patch_model_for_export
+    # _MODEL_PATCHER = ModelPatcher
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str,
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+    ):
+        self.task = task
+        self._config = config
+        self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
+        self.int_dtype = int_dtype
+        self.float_dtype = float_dtype
+
+    def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
+        """
+        Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.
+        Each generator is instantiated with the same task, normalized config and shape kwargs, so they all
+        produce inputs with consistent dimensions. Override this method for custom behavior.
+ """ + return [cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES] + + @property + @abstractmethod + def inputs(self) -> Dict[str, Dict[int, str]]: + """ + Dict containing the axis definition of the input tensors to provide to the model. + + Returns: + `Dict[str, Dict[int, str]]`: A mapping of each input name to a mapping of axis position to the axes symbolic name. + """ + raise NotImplementedError() + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + """ + Dict containing the axis definition of the output tensors to provide to the model. + + Returns: + `Dict[str, Dict[int, str]]`: A mapping of each output name to a mapping of axis position to the axes symbolic name. + """ + common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task] + return copy.deepcopy(common_outputs) + + @property + def values_override(self) -> Optional[Dict[str, Any]]: + """ + Dictionary of keys to override in the model's config before exporting. + + Returns: + `Optional[Dict[str, Any]]`: A dictionary specifying the configuration items to override. + """ + if hasattr(self._config, "use_cache"): + return {"use_cache": False} + + return None + + @property + def is_transformers_support_available(self) -> bool: + """ + Whether the installed version of Transformers allows for the ONNX export. + + Returns: + `bool`: Whether the install version of Transformers is compatible with the model. + + """ + return is_transformers_version(">=", self.MIN_TRANSFORMERS_VERSION.base_version) + + @property + def is_torch_support_available(self) -> bool: + """ + Whether the installed version of PyTorch allows for the ONNX export. + + Returns: + `bool`: Whether the installed version of PyTorch is compatible with the model. + """ + if is_torch_available(): + return is_torch_version(">=", self.MIN_TORCH_VERSION.base_version) + + return False + + @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES) + def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict: + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + dummy_inputs = {} + for input_name in self.inputs: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = dummy_input_gen.generate( + input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to ' + "the model exporters config." + ) + return dummy_inputs + + @classmethod + def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Flatten nested structure in dummy inputs, e.g `addition_embed_type` of unet model. 
+ """ + flatten = {} + for name, value in inputs.items(): + if isinstance(value, dict): + for sub_name, sub_value in value.items(): + flatten[sub_name] = sub_value + else: + flatten[name] = value + return flatten diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 43468a15c0..b64de5df28 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -21,7 +21,7 @@ import itertools import os import re -from abc import ABC, abstractmethod +from abc import ABC from collections import OrderedDict from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union @@ -41,18 +41,15 @@ is_diffusers_available, logging, ) -from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION -from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION from ...utils.doc import add_dynamic_docstring from ...utils.import_utils import ( is_onnx_available, is_onnxruntime_available, - is_torch_version, is_transformers_version, ) from ..base import ExportConfig from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME -from .model_patcher import ModelPatcher, Seq2SeqModelPatcher +from .model_patcher import DecoderModelPatcher, ModelPatcher, Seq2SeqModelPatcher # TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization @@ -63,10 +60,11 @@ if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel + from .model_patcher import PatchingSpec + if is_diffusers_available(): from diffusers import ModelMixin - from .model_patcher import PatchingSpec logger = logging.get_logger(__name__) @@ -103,47 +101,13 @@ class OnnxConfig(ExportConfig, ABC): - """ - Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format. - - Class attributes: - - - NORMALIZED_CONFIG_CLASS (`Type`) -- A class derived from [`~optimum.utils.NormalizedConfig`] specifying how to - normalize the model config. - - DUMMY_INPUT_GENERATOR_CLASSES (`Tuple[Type]`) -- A tuple of classes derived from - [`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs. - - ATOL_FOR_VALIDATION (`Union[float, Dict[str, float]]`) -- A float or a dictionary mapping task names to float, - where the float values represent the absolute tolerance value to use during model conversion validation. - - DEFAULT_ONNX_OPSET (`int`, defaults to 11) -- The default ONNX opset to use for the ONNX export. - - MIN_TORCH_VERSION (`packaging.version.Version`, defaults to [`~optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION`]) -- The - minimum torch version supporting the export of the model to ONNX. - - MIN_TRANSFORMERS_VERSION (`packaging.version.Version`, defaults to - [`~optimum.exporters.onnx.utils.TRANSFORMERS_MINIMUM_VERSION`] -- The minimum transformers version supporting the - export of the model to ONNX. Not always up-to-date or accurate. This is more for internal use. - - PATCHING_SPECS (`Optional[List[PatchingSpec]]`, defaults to `None`) -- Specify which operators / modules should be - patched before performing the export, and how. This is useful when some operator is not supported in ONNX for - instance. - - Args: - config (`transformers.PretrainedConfig`): - The model configuration. - task (`str`, defaults to `"feature-extraction"`): - The task the model should be exported for. 
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 43468a15c0..b64de5df28 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -21,7 +21,7 @@
 import itertools
 import os
 import re
-from abc import ABC, abstractmethod
+from abc import ABC
 from collections import OrderedDict
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -41,18 +41,15 @@
     is_diffusers_available,
     logging,
 )
-from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
-from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
 from ...utils.doc import add_dynamic_docstring
 from ...utils.import_utils import (
     is_onnx_available,
     is_onnxruntime_available,
-    is_torch_version,
     is_transformers_version,
 )
 from ..base import ExportConfig
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
-from .model_patcher import ModelPatcher, Seq2SeqModelPatcher
+from .model_patcher import DecoderModelPatcher, ModelPatcher, Seq2SeqModelPatcher
 
 
 # TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
@@ -63,10 +60,11 @@
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
 
+    from .model_patcher import PatchingSpec
+
     if is_diffusers_available():
         from diffusers import ModelMixin
 
-    from .model_patcher import PatchingSpec
 
 logger = logging.get_logger(__name__)
 
@@ -103,47 +101,13 @@
 class OnnxConfig(ExportConfig, ABC):
-    """
-    Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
-
-    Class attributes:
-
-    - NORMALIZED_CONFIG_CLASS (`Type`) -- A class derived from [`~optimum.utils.NormalizedConfig`] specifying how to
-    normalize the model config.
-    - DUMMY_INPUT_GENERATOR_CLASSES (`Tuple[Type]`) -- A tuple of classes derived from
-    [`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
-    - ATOL_FOR_VALIDATION (`Union[float, Dict[str, float]]`) -- A float or a dictionary mapping task names to float,
-    where the float values represent the absolute tolerance value to use during model conversion validation.
-    - DEFAULT_ONNX_OPSET (`int`, defaults to 11) -- The default ONNX opset to use for the ONNX export.
-    - MIN_TORCH_VERSION (`packaging.version.Version`, defaults to [`~optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION`]) -- The
-    minimum torch version supporting the export of the model to ONNX.
-    - MIN_TRANSFORMERS_VERSION (`packaging.version.Version`, defaults to
-    [`~optimum.exporters.onnx.utils.TRANSFORMERS_MINIMUM_VERSION`] -- The minimum transformers version supporting the
-    export of the model to ONNX. Not always up-to-date or accurate. This is more for internal use.
-    - PATCHING_SPECS (`Optional[List[PatchingSpec]]`, defaults to `None`) -- Specify which operators / modules should be
-    patched before performing the export, and how. This is useful when some operator is not supported in ONNX for
-    instance.
-
-    Args:
-        config (`transformers.PretrainedConfig`):
-            The model configuration.
-        task (`str`, defaults to `"feature-extraction"`):
-            The task the model should be exported for.
-        int_dtype (`str`, defaults to `"int64"`):
-            The data type of integer tensors, could be ["int64", "int32", "int8"], default to "int64".
-        float_dtype (`str`, defaults to `"fp32"`):
-            The data type of float tensors, could be ["fp32", "fp16", "bf16"], default to "fp32".
-    """
-
-    NORMALIZED_CONFIG_CLASS = None
-    DUMMY_INPUT_GENERATOR_CLASSES = ()
     DEFAULT_ONNX_OPSET = 11
-    ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5
-    MIN_TORCH_VERSION = GLOBAL_MIN_TORCH_VERSION
-    MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
-    PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
     VARIANTS = {"default": "The default ONNX variant."}
     DEFAULT_VARIANT = "default"
+    # TODO: move PATCHING_SPECS to ExportConfig
+    PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
+    _MODEL_PATCHER = ModelPatcher
+
     _TASK_TO_COMMON_OUTPUTS = {
         "audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
@@ -210,53 +174,12 @@ def __init__(
         float_dtype: str = "fp32",
         legacy: bool = False,
     ):
-        self.task = task
-        self.int_dtype = int_dtype
-        self.float_dtype = float_dtype
+        super().__init__(config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype)
 
-        self._config = config
-        self._preprocessors = preprocessors
-        self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
         self.variant = "default"
+        self._preprocessors = preprocessors
         self.legacy = legacy
 
-    def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
-        """
-        Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.
-        Each dummy input generator is independent, so this method instantiates the first generator, and
-        forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch
-        size. Override this method for custom behavior.
-        """
-        first_inputs_gen = self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config, **kwargs)
-        dummy_inputs_generators = [
-            cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:]
-        ]
-        dummy_inputs_generators.insert(0, first_inputs_gen)
-
-        return dummy_inputs_generators
-
-    @property
-    @abstractmethod
-    def inputs(self) -> Dict[str, Dict[int, str]]:
-        """
-        Dict containing the axis definition of the input tensors to provide to the model.
-
-        Returns:
-            `Dict[str, Dict[int, str]]`: A mapping of each input name to a mapping of axis position to the axes symbolic name.
-        """
-        raise NotImplementedError()
-
-    @property
-    def outputs(self) -> Dict[str, Dict[int, str]]:
-        """
-        Dict containing the axis definition of the output tensors to provide to the model.
-
-        Returns:
-            `Dict[str, Dict[int, str]]`: A mapping of each output name to a mapping of axis position to the axes symbolic name.
-        """
-        common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
-        return copy.deepcopy(common_outputs)
-
     @property
     def variant(self) -> str:
         """
@@ -354,48 +277,6 @@ def fix_dynamic_axes(
         del onnx_model
         gc.collect()
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> ModelPatcher:
-        return ModelPatcher(self, model, model_kwargs=model_kwargs)
-
-    @property
-    def values_override(self) -> Optional[Dict[str, Any]]:
-        """
-        Dictionary of keys to override in the model's config before exporting.
-
-        Returns:
-            `Optional[Dict[str, Any]]`: A dictionary specifying the configuration items to override.
-        """
-        if hasattr(self._config, "use_cache"):
-            return {"use_cache": False}
-
-        return None
-
-    @property
-    def is_transformers_support_available(self) -> bool:
-        """
-        Whether the installed version of Transformers allows for the ONNX export.
-
-        Returns:
-            `bool`: Whether the install version of Transformers is compatible with the model.
-
-        """
-        return is_transformers_version(">=", self.MIN_TRANSFORMERS_VERSION.base_version)
-
-    @property
-    def is_torch_support_available(self) -> bool:
-        """
-        Whether the installed version of PyTorch allows for the ONNX export.
-
-        Returns:
-            `bool`: Whether the installed version of PyTorch is compatible with the model.
-        """
-        if is_torch_available():
-            return is_torch_version(">=", self.MIN_TORCH_VERSION.base_version)
-
-        return False
-
     @property
     def torch_to_onnx_input_map(self) -> Dict[str, str]:
         """
@@ -461,27 +342,7 @@ def ordered_inputs(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -
             ordered_inputs[name] = dynamic_axes
         return ordered_inputs
 
-    @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES)
-    def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
-        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
-
-        dummy_inputs = {}
-        for input_name in self.inputs:
-            input_was_inserted = False
-            for dummy_input_gen in dummy_inputs_generators:
-                if dummy_input_gen.supports_input(input_name):
-                    dummy_inputs[input_name] = dummy_input_gen.generate(
-                        input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
-                    )
-                    input_was_inserted = True
-                    break
-            if not input_was_inserted:
-                raise RuntimeError(
-                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to '
-                    "the model ONNX config."
-                )
-        return dummy_inputs
-
+    # TODO: use flatten_inputs instead and remove this method
     @classmethod
     def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
         """
@@ -566,6 +427,11 @@ def post_process_exported_models(
 
         return models_and_onnx_configs, onnx_files_subpaths
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return self._MODEL_PATCHER(self, model, model_kwargs=model_kwargs)
+
 
 class OnnxConfigWithPast(OnnxConfig, ABC):
     """
@@ -574,6 +440,7 @@
 
     PAD_ATTENTION_MASK_TO_PAST: bool = False
     SUPPORTS_PAST: bool = True
+    _MODEL_PATCHER = DecoderModelPatcher
 
     def __init__(
         self,
@@ -786,6 +653,7 @@
     """
 
    DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator
+    _MODEL_PATCHER = Seq2SeqModelPatcher
 
     def __init__(
         self,
@@ -918,11 +786,6 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
             flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
             flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> ModelPatcher:
-        return Seq2SeqModelPatcher(self, model, model_kwargs=model_kwargs)
-
     def post_process_exported_models(
         self,
         path: Path,
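The core of the refactor: `patch_model_for_export` now lives once on `OnnxConfig` and dispatches through the `_MODEL_PATCHER` class attribute, so subclasses select a patcher by rebinding an attribute rather than overriding a method. A self-contained sketch of the pattern, using stand-in classes rather than the real `ModelPatcher` signatures:

```python
from abc import ABC


class BasePatcher:
    # Stand-in for ModelPatcher (the real one also acts as a context manager during export).
    def __init__(self, config, model, model_kwargs=None):
        self.config, self.model, self.model_kwargs = config, model, model_kwargs


class Seq2SeqPatcher(BasePatcher):
    pass


class BaseExportConfig(ABC):
    _MODEL_PATCHER = BasePatcher  # the single hook point

    def patch_model_for_export(self, model, model_kwargs=None):
        # `self._MODEL_PATCHER` resolves through the MRO, so subclasses only
        # rebind the attribute; the method body is shared.
        return self._MODEL_PATCHER(self, model, model_kwargs=model_kwargs)


class Seq2SeqExportConfig(BaseExportConfig):
    _MODEL_PATCHER = Seq2SeqPatcher  # replaces what used to be a method override


assert isinstance(Seq2SeqExportConfig().patch_model_for_export(model=None), Seq2SeqPatcher)
```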
diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py
index 69366d6be1..290e05d2b3 100644
--- a/optimum/exporters/onnx/config.py
+++ b/optimum/exporters/onnx/config.py
@@ -34,7 +34,6 @@
 )
 from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
-from .model_patcher import DecoderModelPatcher
 
 
 # TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
@@ -43,8 +42,6 @@
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel
 
-    from .model_patcher import ModelPatcher
-
     if is_tf_available():
         from transformers import TFPreTrainedModel
 
@@ -160,12 +157,6 @@
 
         return models_and_onnx_configs, onnx_files_subpaths
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        # Refer to DecoderModelPatcher.
-        return DecoderModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig):
     @property
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index f420ab39c6..0c51de3d35 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -104,8 +104,6 @@
     from transformers import PretrainedConfig
     from transformers.modeling_utils import PreTrainedModel
 
-    from .model_patcher import ModelPatcher
-
     if is_tf_available():
         from transformers.modeling_tf_utils import TFPreTrainedModel
 
@@ -410,11 +408,7 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
-
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return MistralModelPatcher(self, model, model_kwargs=model_kwargs)
+    _MODEL_PATCHER = MistralModelPatcher
 
 
 class MPTOnnxConfig(TextDecoderOnnxConfig):
@@ -502,6 +496,10 @@ class FalconOnnxConfig(TextDecoderOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator
 
+    # we need to set output_attentions=True in the model input to avoid calling
+    # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
+    _MODEL_PATCHER = FalconModelPatcher
+
     def __init__(
         self,
         config: "PretrainedConfig",
@@ -542,13 +540,6 @@
 
         return common_inputs
 
-    # we need to set output_attentions=True in the model input to avoid calling
-    # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return FalconModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
@@ -1061,6 +1052,8 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]:
 
 
 class MgpstrOnnxConfig(ViTOnnxConfig):
+    _MODEL_PATCHER = MgpstrModelPatcher
+
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         return {
@@ -1069,15 +1062,14 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "wp_logits": {0: "batch_size"},
         }
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return MgpstrModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DEFAULT_ONNX_OPSET = 14  # Some bottleneck transformers models require a specific ONNX opset to be successfully exported. We put a rather high opset here for the export to work for all architectures.
+    # we need to set output_attentions=True in the model input to avoid calling
+    # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
+    # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM
+    _MODEL_PATCHER = SentenceTransformersTransformerPatcher
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1093,14 +1085,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "sentence_embedding": {0: "batch_size"},
         }
 
-    # we need to set output_attentions=True in the model input to avoid calling
-    # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
-    # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return SentenceTransformersTransformerPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
     TEXT_CONFIG = "text_config"
@@ -1109,6 +1093,7 @@ class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
 
 
 class CLIPVisionModelOnnxConfig(VisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    _MODEL_PATCHER = CLIPModelPatcher
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1122,16 +1107,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 
         return common_outputs
 
-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class CLIPOnnxConfig(TextAndVisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
+    _MODEL_PATCHER = CLIPModelPatcher
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1150,15 +1129,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "image_embeds": {0: "image_batch_size"},
         }
 
-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
+    _MODEL_PATCHER = SentenceTransformersCLIPPatcher
+
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         return {
@@ -1166,11 +1140,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "image_embeds": {0: "image_batch_size"},
         }
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return SentenceTransformersCLIPPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-3
@@ -1183,6 +1152,7 @@ class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
         num_layers="num_hidden_layers",
         allow_new=True,
     )
+    _MODEL_PATCHER = CLIPModelPatcher
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1202,15 +1172,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 
         return common_outputs
 
-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
+    _MODEL_PATCHER = CLIPModelPatcher
+
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         common_outputs = {
@@ -1224,13 +1189,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 
         return common_outputs
 
-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class SiglipNormalizedConfig(CLIPNormalizedConfig):
     pass
@@ -1739,14 +1697,10 @@ class UniSpeechSATOnnxConfig(HubertOnnxConfig):
 
 class WavLMOnnxConfig(HubertOnnxConfig):
     DEFAULT_ONNX_OPSET = 12
-
     # we need to set output_attentions=True in the model input to avoid calling
     # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
     # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return WavLMModelPatcher(self, model, model_kwargs=model_kwargs)
+    _MODEL_PATCHER = WavLMModelPatcher
 
 
 class ASTDummyAudioInputGenerator(DummyAudioInputGenerator):
@@ -1848,6 +1802,7 @@ class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast):
         DummyIntGenerator,
     )
     DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator
+    _MODEL_PATCHER = MusicgenModelPatcher
 
     def __init__(
         self,
@@ -2024,11 +1979,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
                 2: "encoder_sequence_length_out",
             }
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return MusicgenModelPatcher(self, model, model_kwargs=model_kwargs)
-
     @property
     def torch_to_onnx_input_map(self) -> Dict[str, str]:
         if self._behavior is ConfigBehavior.DECODER:
@@ -2140,6 +2090,7 @@ class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
         "without-past": "The same as `with-past`, just without KV cache support. This is not a recommended export as slower than `with-past`.",
     }
     DEFAULT_VARIANT = "with-past"
+    _MODEL_PATCHER = SpeechT5ModelPatcher
 
     def __init__(
         self,
@@ -2220,11 +2171,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 
         return common_outputs
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs)
-
     @property
     def torch_to_onnx_input_map(self) -> Dict[str, str]:
         return {"encoder_outputs": "encoder_hidden_states"}
@@ -2358,6 +2304,7 @@ class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
 
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator)
+    _MODEL_PATCHER = VisionEncoderDecoderPatcher
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -2391,11 +2338,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         # so we can not initializer MBartONNXConfig with document-question-answering).
         return super().outputs
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return VisionEncoderDecoderPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class SamOnnxConfig(OnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
@@ -2409,6 +2351,7 @@ class SamOnnxConfig(OnnxConfig):
         "split": "The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_encoder_mask_decoder.onnx. This allows to encoder the image only once for multiple point queries.",
     }
     DEFAULT_VARIANT = "split"
+    _MODEL_PATCHER = SAMModelPatcher
 
     def __init__(
         self,
@@ -2463,11 +2406,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "pred_masks": {0: "batch_size", 1: "point_batch_size"},
         }
 
-    def patch_model_for_export(
-        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
-    ) -> "ModelPatcher":
-        return SAMModelPatcher(self, model, model_kwargs=model_kwargs)
-
 
 class Pix2StructNormalizedConfig(NormalizedSeq2SeqConfig):
     ENCODER_NUM_LAYERS = "vision_config.num_hidden_layers"
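For custom ONNX configs maintained outside optimum, the migration this diff performs on every in-tree class is mechanical. A before/after sketch — `MyOnnxConfigBefore`/`MyOnnxConfigAfter` and `MyModelPatcher` are hypothetical names, not optimum classes:

```python
from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.exporters.onnx.model_patcher import ModelPatcher


class MyModelPatcher(ModelPatcher):  # hypothetical custom patcher
    pass


# Before this PR: selecting a custom patcher meant overriding the method.
class MyOnnxConfigBefore(TextDecoderOnnxConfig):
    def patch_model_for_export(self, model, model_kwargs=None):
        return MyModelPatcher(self, model, model_kwargs=model_kwargs)


# After this PR: the same intent is a one-line class attribute, and the
# inherited patch_model_for_export picks it up automatically.
class MyOnnxConfigAfter(TextDecoderOnnxConfig):
    _MODEL_PATCHER = MyModelPatcher
```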
diff --git a/optimum/exporters/tflite/base.py b/optimum/exporters/tflite/base.py
index 3df230c33b..5780751cb0 100644
--- a/optimum/exporters/tflite/base.py
+++ b/optimum/exporters/tflite/base.py
@@ -14,12 +14,12 @@
 # limitations under the License.
 """TensorFlow Lite configuration base classes."""
 
-from abc import ABC, abstractmethod
+from abc import ABC
 from ctypes import ArgumentError
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from transformers.utils import is_tf_available
 
@@ -148,34 +148,11 @@ class TFLiteConfig(ExportConfig, ABC):
         They are required or not depending on the model the `TFLiteConfig` is designed for.
     """
 
-    NORMALIZED_CONFIG_CLASS: Type = None
-    DUMMY_INPUT_GENERATOR_CLASSES: Tuple[Type, ...] = ()
-    ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5
     MANDATORY_AXES = ()
     SUPPORTED_QUANTIZATION_APPROACHES: Union[
         Dict[str, Tuple[QuantizationApproach, ...]], Tuple[QuantizationApproach, ...]
     ] = tuple(approach for approach in QuantizationApproach)
 
-    _TASK_TO_COMMON_OUTPUTS = {
-        "text-generation": ["logits"],
-        "feature-extraction": ["last_hidden_state"],
-        "image-classification": ["logits"],
-        "image-segmentation": ["logits", "pred_boxes", "pred_masks"],
-        "masked-im": ["logits"],
-        "fill-mask": ["logits"],
-        "multiple-choice": ["logits"],
-        "object-detection": ["logits", "pred_boxes"],
-        "question-answering": ["start_logits", "end_logits"],
-        "semantic-segmentation": ["logits"],
-        "text2text-generation": ["logits", "encoder_last_hidden_state"],
-        "text-classification": ["logits"],
-        "token-classification": ["logits"],
-        "automatic-speech-recognition": ["logits"],
-        "audio-classification": ["logits"],
-        "audio-frame-classification": ["logits"],
-        "audio-xvector": ["logits"],
-    }
-
     def __init__(
         self,
         config: "PretrainedConfig",
@@ -192,12 +169,11 @@ def __init__(
         point_batch_size: Optional[int] = None,
         nb_points_per_image: Optional[int] = None,
     ):
-        self._config = config
-        self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
         self.mandatory_axes = ()
-        self.task = task
         self._axes: Dict[str, int] = {}
 
+        super().__init__(config=config, task=task, int_dtype="int64", float_dtype="fp32")
+
         # To avoid using **kwargs.
         axes_values = {
             "batch_size": batch_size,
@@ -266,65 +242,8 @@ def _create_dummy_input_generator_classes(self) -> List["DummyInputGenerator"]:
         self._validate_mandatory_axes()
         return [cls_(self.task, self._normalized_config, **self._axes) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES]
 
-    @property
-    def values_override(self) -> Optional[Dict[str, Any]]:
-        """
-        Dictionary of keys to override in the model's config before exporting.
-
-        Returns:
-            `Optional[Dict[str, Any]]`: A dictionary specifying the configuration items to override.
-        """
-        if hasattr(self._config, "use_cache"):
-            return {"use_cache": False}
-
-        return None
-
-    @property
-    @abstractmethod
-    def inputs(self) -> List[str]:
-        """
-        List containing the names of the inputs the exported model should take.
-
-        Returns:
-            `List[str]`: A list of input names.
-        """
-        raise NotImplementedError()
-
-    @property
-    def outputs(self) -> List[str]:
-        """
-        List containing the names of the outputs the exported model should have.
-
-        Returns:
-            `List[str]`: A list of output names.
-        """
-        return self._TASK_TO_COMMON_OUTPUTS[self.task]
-
     def generate_dummy_inputs(self) -> Dict[str, "tf.Tensor"]:
-        """
-        Generates dummy inputs that the exported model should be able to process.
-        This method is actually used to determine the input specs that are needed for the export.
-
-        Returns:
-            `Dict[str, tf.Tensor]`: A dictionary mapping input names to dummy tensors.
-        """
-        dummy_inputs_generators = self._create_dummy_input_generator_classes()
-        dummy_inputs = {}
-
-        for input_name in self.inputs:
-            input_was_inserted = False
-            for dummy_input_gen in dummy_inputs_generators:
-                if dummy_input_gen.supports_input(input_name):
-                    dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="tf")
-                    input_was_inserted = True
-                    break
-            if not input_was_inserted:
-                raise RuntimeError(
-                    f'Could not generate dummy inputs for "{input_name}". Try adding a proper dummy input generator '
-                    "to the model TFLite config."
-                )
-
-        return dummy_inputs
+        return super().generate_dummy_inputs(framework="tf")
 
     @property
     def inputs_specs(self) -> List["TensorSpec"]:
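After this change, `TFLiteConfig.generate_dummy_inputs` is a thin wrapper over the shared `ExportConfig` implementation with `framework="tf"` (the dtypes are pinned to `int64`/`fp32` in `__init__`). Illustrative usage only, assuming TensorFlow is installed and that `BertTFLiteConfig` accepts these axes for the task:

```python
from transformers import AutoConfig
from optimum.exporters.tflite.model_configs import BertTFLiteConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
tflite_config = BertTFLiteConfig(config, task="feature-extraction", batch_size=1, sequence_length=8)

# Now routed through the shared ExportConfig path, returning tf.Tensor values.
dummy_inputs = tflite_config.generate_dummy_inputs()
print({name: tuple(tensor.shape) for name, tensor in dummy_inputs.items()})
```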
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 3f497b5920..bb118f2c83 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -15,9 +15,11 @@
 """Normalization configuration classes."""
 
 import functools
-from typing import Callable, Dict, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, Type, Union
 
-from transformers import PretrainedConfig
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
 
 
 class NormalizedConfig:
@@ -29,7 +31,7 @@ class NormalizedConfig:
             The config to normalize.
     """
 
-    def __init__(self, config: Union[PretrainedConfig, Dict], allow_new: bool = False, **kwargs):
+    def __init__(self, config: Union["PretrainedConfig", Dict], allow_new: bool = False, **kwargs):
         self.config = config
         for key, value in kwargs.items():
             if allow_new or hasattr(self, key.upper()):
@@ -40,7 +42,7 @@ def __init__(self, config: Union[PretrainedConfig, Dict], allow_new: bool = Fals
             )
 
     @classmethod
-    def with_args(cls, allow_new: bool = False, **kwargs) -> Callable[[PretrainedConfig], "NormalizedConfig"]:
+    def with_args(cls, allow_new: bool = False, **kwargs) -> Callable[["PretrainedConfig"], "NormalizedConfig"]:
         return functools.partial(cls, allow_new=allow_new, **kwargs)
 
     def __getattr__(self, attr_name):
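The `TYPE_CHECKING` change in `normalized_config.py` is the standard trick for keeping a heavyweight dependency out of import time: the name only exists for static type checkers, so every annotation that uses it must be quoted, which is exactly what the hunks above do. A minimal, self-contained sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by mypy/pyright only, never at runtime, so importing this
    # module no longer pulls in transformers eagerly.
    from transformers import PretrainedConfig


def describe(config: "PretrainedConfig") -> str:
    # Quoted annotation: the name is undefined at runtime.
    return config.model_type
```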