43
43
from optimum .exporters .tasks import TasksManager
44
44
from optimum .utils import DEFAULT_DUMMY_SHAPES
45
45
from optimum .utils .input_generators import (
46
+ DTYPE_MAPPER ,
46
47
DummyInputGenerator ,
47
48
DummyPastKeyValuesGenerator ,
49
+ DummySeq2SeqDecoderTextInputGenerator ,
48
50
DummyTextInputGenerator ,
49
51
DummyTimestepInputGenerator ,
50
52
DummyVisionInputGenerator ,
63
65
DBRXModelPatcher ,
64
66
DeciLMModelPatcher ,
65
67
FalconModelPatcher ,
68
+ FluxTransfromerModelPatcher ,
66
69
Gemma2ModelPatcher ,
67
70
GptNeoxJapaneseModelPatcher ,
68
71
GptNeoxModelPatcher ,
@@ -96,9 +99,9 @@ def init_model_configs():
96
99
"transformers" ,
97
100
"LlavaNextForConditionalGeneration" ,
98
101
)
99
- TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS [
100
- "image- text-to-text"
101
- ] = TasksManager . _TRANSFORMERS_TASKS_TO_MODEL_LOADERS [ "text-generation" ]
102
+ TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["image-text-to-text" ] = (
103
+ TasksManager . _TRANSFORMERS_TASKS_TO_MODEL_LOADERS [ " text-generation" ]
104
+ )
102
105
103
106
supported_model_types = [
104
107
"_SUPPORTED_MODEL_TYPE" ,
@@ -1576,7 +1579,7 @@ def patch_model_for_export(
1576
1579
1577
1580
1578
1581
class PooledProjectionsDummyInputGenerator (DummyInputGenerator ):
1579
- SUPPORTED_INPUT_NAMES = "pooled_projections"
1582
+ SUPPORTED_INPUT_NAMES = [ "pooled_projections" ]
1580
1583
1581
1584
def __init__ (
1582
1585
self ,
@@ -1600,8 +1603,10 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1600
1603
1601
1604
1602
1605
class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
    """Timestep-style dummy inputs for diffusion transformers.

    Extends the base timestep generator with the extra conditioning inputs used
    by transformer-based diffusion models, and treats "guidance" like
    "timestep": a per-sample 1-D float tensor.
    """

    SUPPORTED_INPUT_NAMES = ("timestep", "text_embeds", "time_ids", "timestep_cond", "guidance")

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        # Everything except "timestep"/"guidance" is handled by the base generator.
        if input_name not in ("timestep", "guidance"):
            return super().generate(input_name, framework, int_dtype, float_dtype)
        # One scalar per batch element; max_value mirrors the base class behavior.
        return self.random_float_tensor(
            [self.batch_size], max_value=self.vocab_size, framework=framework, dtype=float_dtype
        )
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        """Wrap *model* in the generic ModelPatcher for export (no model-specific patching)."""
        return ModelPatcher(self, model, model_kwargs=model_kwargs)
1650
+
1651
+
1652
class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
    """Dummy image-side inputs for exporting the Flux transformer.

    Flux consumes latents packed into 2x2 patches, so "hidden_states"/"sample"
    are generated already packed with shape
    (batch, (height // 2) * (width // 2), num_channels * 4), and "img_ids" are
    the corresponding packed positional ids.
    """

    SUPPORTED_INPUT_NAMES = (
        "pixel_values",
        "pixel_mask",
        "sample",
        "latent_sample",
        "hidden_states",
        "img_ids",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = DEFAULT_DUMMY_SHAPES["width"],
        height: int = DEFAULT_DUMMY_SHAPES["height"],
        **kwargs,
    ):
        super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
        # The config's in_channels counts packed channels (4 latent channels per
        # 2x2 patch), so the unpacked per-patch channel count is in_channels // 4.
        if getattr(normalized_config, "in_channels", None):
            self.num_channels = normalized_config.in_channels // 4

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name in ["hidden_states", "sample"]:
            # Packed latent: one token per 2x2 spatial patch, channels folded in.
            shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels * 4]
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
        if input_name == "img_ids":
            return self.prepare_image_ids(framework, int_dtype, float_dtype)

        return super().generate(input_name, framework, int_dtype, float_dtype)

    def prepare_image_ids(self, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Return packed latent image ids of shape (batch, (h//2)*(w//2), 3).

        Channel 0 is zero, channel 1 holds the patch row index, channel 2 the
        patch column index. Only "pt" and "np" frameworks are supported; any
        other value returns None (matching the original control flow).
        """
        img_ids_height = self.height // 2
        img_ids_width = self.width // 2
        if framework == "pt":
            import torch

            latent_image_ids = torch.zeros(img_ids_height, img_ids_width, 3)
            latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(img_ids_height)[:, None]
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(img_ids_width)[None, :]

            latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

            latent_image_ids = latent_image_ids[None, :].repeat(self.batch_size, 1, 1, 1)
            latent_image_ids = latent_image_ids.reshape(
                self.batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
            )
            # BUG FIX: Tensor.to() is out-of-place — the converted tensor was
            # previously discarded, so the requested float dtype was never applied.
            latent_image_ids = latent_image_ids.to(DTYPE_MAPPER.pt(float_dtype))
            return latent_image_ids
        if framework == "np":
            import numpy as np

            # BUG FIX: np.zeros takes the shape as a single tuple; passing the
            # dimensions as separate positional arguments raises TypeError.
            latent_image_ids = np.zeros((img_ids_height, img_ids_width, 3))
            latent_image_ids[..., 1] = latent_image_ids[..., 1] + np.arange(img_ids_height)[:, None]
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + np.arange(img_ids_width)[None, :]

            latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

            latent_image_ids = np.tile(latent_image_ids[None, :], (self.batch_size, 1, 1, 1))
            latent_image_ids = latent_image_ids.reshape(
                self.batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
            )
            # BUG FIX: astype() returns a new array (result was discarded), and
            # DTYPE_MAPPER.np is invoked like DTYPE_MAPPER.pt elsewhere in this
            # file, not subscripted — assumed callable; confirm against
            # optimum.utils.input_generators.DTYPE_MAPPER.
            latent_image_ids = latent_image_ids.astype(DTYPE_MAPPER.np(float_dtype))
            return latent_image_ids
1719
+
1720
+
1721
class DummyFluxTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
    """Dummy text-side inputs for exporting the Flux transformer.

    Adds "txt_ids" on top of the seq2seq decoder text inputs: Flux expects
    all-zero text position ids of shape (batch, sequence_length, 3).
    """

    SUPPORTED_INPUT_NAMES = (
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "encoder_hidden_states",
        "txt_ids",
    )

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name == "txt_ids":
            # BUG FIX: honor the requested framework instead of always emitting a
            # torch tensor with a torch dtype (the original hard-coded
            # DTYPE_MAPPER.pt and relied on constant_tensor's framework default).
            # DTYPE_MAPPER.np assumed callable like DTYPE_MAPPER.pt — confirm
            # against optimum.utils.input_generators.
            dtype = DTYPE_MAPPER.pt(float_dtype) if framework == "pt" else DTYPE_MAPPER.np(float_dtype)
            return self.constant_tensor([self.batch_size, self.sequence_length, 3], 0, dtype, framework)
        return super().generate(input_name, framework, int_dtype, float_dtype)
1734
+
1735
+
1736
@register_in_tasks_manager("flux-transformer", "semantic-segmentation", library_name="diffusers")
class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
    """OpenVINO export config for the Flux transformer.

    Reuses the SD3 transformer config but swaps the unpacked "sample" input for
    Flux's packed "hidden_states" plus the "txt_ids"/"img_ids" position inputs,
    and adds "guidance" when the model was trained with guidance embeddings.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTransformerTimestpsInputGenerator,
        DummyFluxTransformerInputGenerator,
        DummyFluxTextInputGenerator,
        PooledProjectionsDummyInputGenerator,
    )

    @property
    def inputs(self):
        """Dynamic-axes spec for the exported inputs."""
        flux_inputs = super().inputs
        # Flux consumes pre-packed latents, so the base "sample" input is
        # replaced by "hidden_states" (tolerate its absence).
        flux_inputs.pop("sample", None)
        flux_inputs.update(
            {
                "hidden_states": {0: "batch_size", 1: "packed_height_width"},
                "txt_ids": {0: "batch_size", 1: "sequence_length"},
                "img_ids": {0: "batch_size", 1: "packed_height_width"},
            }
        )
        if getattr(self._normalized_config, "guidance_embeds", False):
            flux_inputs["guidance"] = {0: "batch_size"}
        return flux_inputs

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        """Wrap *model* in the Flux-specific patcher for export."""
        return FluxTransfromerModelPatcher(self, model, model_kwargs=model_kwargs)
0 commit comments