     MPTOnnxConfig,
     PhiOnnxConfig,
     UNetOnnxConfig,
+    VaeEncoderOnnxConfig,
     VisionOnnxConfig,
 )
 from optimum.exporters.onnx.model_patcher import ModelPatcher

     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
-    DummySeq2SeqDecoderTextInputGenerator
 )
 from optimum.utils.normalized_config import NormalizedConfig, NormalizedTextConfig, NormalizedVisionConfig

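For context on the generator classes touched below: in optimum's export pipeline, a config's DUMMY_INPUT_GENERATOR_CLASSES supplies the tracing inputs, and each export input name is served by the first generator that lists it in SUPPORTED_INPUT_NAMES. A minimal sketch of that dispatch, with simplified names and signatures assumed for illustration:

```python
# Minimal sketch (not optimum's actual code) of how dummy input generators
# are dispatched: for each export input name, the first generator that lists
# the name in SUPPORTED_INPUT_NAMES produces the dummy tensor.
def generate_dummy_inputs(export_config, generators, framework="pt"):
    dummy_inputs = {}
    for name in export_config.inputs:
        for generator in generators:
            if name in generator.SUPPORTED_INPUT_NAMES:
                dummy_inputs[name] = generator.generate(name, framework=framework)
                break
    return dummy_inputs
```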
@@ -1889,52 +1889,78 @@ def rename_ambiguous_inputs(self, inputs):
 class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
     pass

+
 @register_in_tasks_manager("gemma2-text-encoder", *["feature-extraction"], library_name="diffusers")
 class Gemma2TextEncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {
             "input_ids": {0: "batch_size", 1: "sequence_length"},
-            "attention_mask": {0: "batch_size", 1: "sequence_length"}
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
         }


-class DummySeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
+class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
     SUPPORTED_INPUT_NAMES = (
         "decoder_input_ids",
         "decoder_attention_mask",
         "encoder_outputs",
         "encoder_hidden_states",
-        "encoder_attention_mask"
+        "encoder_attention_mask",
     )


-class DummySanaTransformerVisionInputGenerator(DummyVisionInputGenerator):
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        if input_name not in ["sample", "latent_sample"]:
-            return super().generate(input_name, framework, int_dtype, float_dtype)
-        return self.random_float_tensor(
-            shape=[self.batch_size, self.num_channels, self.height, self.width],
-            framework=framework,
-            dtype=float_dtype,
-        )
+class DummySanaTransformerVisionInputGenerator(DummyUnetVisionInputGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        # Reduce the dummy image shape by a factor of 8 to lower memory usage during conversion
+        width: int = DEFAULT_DUMMY_SHAPES["width"] // 8,
+        height: int = DEFAULT_DUMMY_SHAPES["height"] // 8,
+        **kwargs,
+    ):
+        super().__init__(task, normalized_config, batch_size, num_channels, width=width, height=height, **kwargs)
+

 @register_in_tasks_manager("sana-transformer", *["semantic-segmentation"], library_name="diffusers")
 class SanaTransformerOpenVINOConfig(UNetOpenVINOConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
         image_size="sample_size",
         num_channels="in_channels",
-        hidden_size="cross_attention_dim",
+        hidden_size="caption_channels",
         vocab_size="attention_head_dim",
         allow_new=True,
     )
-    DUMMY_INPUT_GENERATOR_CLASSES = (DummySanaTransformerVisionInputGenerator, DummySeq2SeqDecoderTextWithEncMaskInputGenerator) + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1]
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummySanaTransformerVisionInputGenerator,
+        DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator,
+    ) + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1]
+
     @property
     def inputs(self):
         common_inputs = super().inputs
         common_inputs["encoder_attention_mask"] = {0: "batch_size", 1: "sequence_length"}
         return common_inputs

+    def rename_ambiguous_inputs(self, inputs):
+        # The model's forward signature expects `hidden_states` rather than `sample`, hence the export input name is updated.
+        hidden_states = inputs.pop("sample", None)
+        if hidden_states is not None:
+            inputs["hidden_states"] = hidden_states
+        return inputs
+
+
+@register_in_tasks_manager("dcae-encoder", *["semantic-segmentation"], library_name="diffusers")
+class DcaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig):
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "latent": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
+        }
+

 class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
     SUPPORTED_INPUT_NAMES = (
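A quick illustration of the `rename_ambiguous_inputs` hook added above, reproduced standalone so its effect is easy to verify (values are placeholders):

```python
def rename_ambiguous_inputs(inputs):
    # Same logic as the hook in SanaTransformerOpenVINOConfig: remap the
    # default "sample" export input to "hidden_states" to match the model's
    # forward signature.
    hidden_states = inputs.pop("sample", None)
    if hidden_states is not None:
        inputs["hidden_states"] = hidden_states
    return inputs

renamed = rename_ambiguous_inputs({"sample": 0, "timestep": 1})
assert renamed == {"timestep": 1, "hidden_states": 0}
```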
|