
Commit 7c535b2: add flux
1 parent be4624d commit 7c535b2
13 files changed: +373, -49 lines changed

optimum/commands/export/openvino.py (+6)

@@ -318,6 +318,12 @@ def run(self):
                 from optimum.intel import OVStableDiffusionPipeline

                 model_cls = OVStableDiffusionPipeline
+            elif class_name == "StableDiffusion3Pipeline":
+                from optimum.intel import OVStableDiffusion3Pipeline
+                model_cls = OVStableDiffusion3Pipeline
+            elif class_name == "FluxPipeline":
+                from optimum.intel import OVFluxPipeline
+                model_cls = OVFluxPipeline
             else:
                 raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")
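
The CLI-side change is just two new branches in an elif chain mapping diffusers class names to their OpenVINO pipeline wrappers. For readability, here is the same dispatch as a standalone sketch; the _HYBRID_QUANT_PIPELINES dict and the resolve_pipeline_class helper are illustrative names, not part of the commit, while the class names and the error message come straight from the diff.

from optimum.intel import OVFluxPipeline, OVStableDiffusion3Pipeline, OVStableDiffusionPipeline

# Illustrative table form of the if/elif chain in run().
_HYBRID_QUANT_PIPELINES = {
    "StableDiffusionPipeline": OVStableDiffusionPipeline,
    "StableDiffusion3Pipeline": OVStableDiffusion3Pipeline,
    "FluxPipeline": OVFluxPipeline,
}


def resolve_pipeline_class(class_name: str):
    # Unknown pipeline classes fail the same way the CLI branch does.
    if class_name not in _HYBRID_QUANT_PIPELINES:
        raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")
    return _HYBRID_QUANT_PIPELINES[class_name]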

optimum/exporters/openvino/convert.py (+95, -2)

@@ -917,9 +917,19 @@ def get_diffusion_models_for_export_ext(
     except ImportError:
         is_sd3 = False

-    if not is_sd3:
+    try:
+        from diffusers import FluxPipeline
+
+        is_flux = isinstance(pipeline, FluxPipeline)
+    except ImportError:
+        is_flux = False
+
+    if not is_sd3 and not is_flux:
         return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
-    models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
+    if is_sd3:
+        models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
+    else:
+        models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)

     return None, models_for_export

@@ -1021,3 +1031,86 @@ def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
     models_for_export["text_encoder_3"] = (text_encoder_3, export_config)

     return models_for_export
+
+
+def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
+    models_for_export = {}
+
+    # Text encoder
+    text_encoder = getattr(pipeline, "text_encoder", None)
+    if text_encoder is not None:
+        text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+            model=text_encoder,
+            exporter=exporter,
+            library_name="diffusers",
+            task="feature-extraction",
+            model_type="clip-text-model",
+        )
+        text_encoder_export_config = text_encoder_config_constructor(
+            pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+        )
+        models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
+
+    transformer = pipeline.transformer
+    transformer.config.text_encoder_projection_dim = transformer.config.joint_attention_dim
+    transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
+    transformer.config.time_cond_proj_dim = None
+    export_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=transformer,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="flux-transformer",
+    )
+    transformer_export_config = export_config_constructor(
+        pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["transformer"] = (transformer, transformer_export_config)
+
+    # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
+    vae_encoder = copy.deepcopy(pipeline.vae)
+    vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
+    vae_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=vae_encoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="vae-encoder",
+    )
+    vae_encoder_export_config = vae_config_constructor(
+        vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
+
+    # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
+    vae_decoder = copy.deepcopy(pipeline.vae)
+    vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+    vae_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=vae_decoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="vae-decoder",
+    )
+    vae_decoder_export_config = vae_config_constructor(
+        vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
+
+    text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
+    if text_encoder_2 is not None:
+        export_config_constructor = TasksManager.get_exporter_config_constructor(
+            model=text_encoder_2,
+            exporter=exporter,
+            library_name="diffusers",
+            task="feature-extraction",
+            model_type="t5-encoder-model",
+        )
+        export_config = export_config_constructor(
+            text_encoder_2.config,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+        )
+        models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
+
+    return models_for_export
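
The new get_flux_models_for_export mirrors the SD3 helper: every exported part is paired with a task-specific export config, and the VAE is deep-copied twice with its forward rebound to a lambda so the encode and decode paths trace as separate graphs. A minimal sketch of how the result is shaped, assuming a loaded diffusers FluxPipeline and that the helper is importable from the module shown above (the checkpoint name is illustrative):

from diffusers import FluxPipeline

from optimum.exporters.openvino.convert import get_flux_models_for_export

pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
models_for_export = get_flux_models_for_export(pipeline, "openvino", "int64", "fp32")

# One (submodel, export_config) pair per OpenVINO IR that will be produced:
# text_encoder, transformer, vae_encoder, vae_decoder, text_encoder_2.
for name, (submodel, export_config) in models_for_export.items():
    print(name, type(submodel).__name__, type(export_config).__name__)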

optimum/exporters/openvino/model_configs.py (+120, -5)

@@ -43,8 +43,10 @@
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
+    DTYPE_MAPPER,
     DummyInputGenerator,
     DummyPastKeyValuesGenerator,
+    DummySeq2SeqDecoderTextInputGenerator,
     DummyTextInputGenerator,
     DummyTimestepInputGenerator,
     DummyVisionInputGenerator,
@@ -63,6 +65,7 @@
     DBRXModelPatcher,
     DeciLMModelPatcher,
     FalconModelPatcher,
+    FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
@@ -96,9 +99,9 @@ def init_model_configs():
         "transformers",
         "LlavaNextForConditionalGeneration",
     )
-    TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
-        "image-text-to-text"
-    ] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
+    TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["image-text-to-text"] = (
+        TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
+    )

     supported_model_types = [
         "_SUPPORTED_MODEL_TYPE",
@@ -1576,7 +1579,7 @@ def patch_model_for_export(


 class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
-    SUPPORTED_INPUT_NAMES = "pooled_projections"
+    SUPPORTED_INPUT_NAMES = ["pooled_projections"]

     def __init__(
         self,
@@ -1600,8 +1603,10 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


 class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("timestep", "text_embeds", "time_ids", "timestep_cond", "guidance")
+
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        if input_name == "timestep":
+        if input_name in ["timestep", "guidance"]:
             shape = [self.batch_size]
             return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
         return super().generate(input_name, framework, int_dtype, float_dtype)
@@ -1642,3 +1647,113 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> ModelPatcher:
         return ModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
+    SUPPORTED_INPUT_NAMES = (
+        "pixel_values",
+        "pixel_mask",
+        "sample",
+        "latent_sample",
+        "hidden_states",
+        "img_ids",
+    )
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = DEFAULT_DUMMY_SHAPES["width"],
+        height: int = DEFAULT_DUMMY_SHAPES["height"],
+    ):
+
+        super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
+        if getattr(normalized_config, "in_channels", None):
+            self.num_channels = normalized_config.in_channels // 4
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name in ["hidden_states", "sample"]:
+            shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels * 4]
+            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
+        if input_name == "img_ids":
+            return self.prepare_image_ids(framework, int_dtype, float_dtype)

+        return super().generate(input_name, framework, int_dtype, float_dtype)
+
+    def prepare_image_ids(self, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        img_ids_height = self.height // 2
+        img_ids_width = self.width // 2
+        if framework == "pt":
+            import torch
+
+            latent_image_ids = torch.zeros(img_ids_height, img_ids_width, 3)
+            latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(img_ids_height)[:, None]
+            latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(img_ids_width)[None, :]
+
+            latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+            latent_image_ids = latent_image_ids[None, :].repeat(self.batch_size, 1, 1, 1)
+            latent_image_ids = latent_image_ids.reshape(
+                self.batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+            )
+            latent_image_ids = latent_image_ids.to(DTYPE_MAPPER.pt(float_dtype))
+            return latent_image_ids
+        if framework == "np":
+            import numpy as np
+
+            latent_image_ids = np.zeros((img_ids_height, img_ids_width, 3))
+            latent_image_ids[..., 1] = latent_image_ids[..., 1] + np.arange(img_ids_height)[:, None]
+            latent_image_ids[..., 2] = latent_image_ids[..., 2] + np.arange(img_ids_width)[None, :]
+
+            latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+            latent_image_ids = np.tile(latent_image_ids[None, :], (self.batch_size, 1, 1, 1))
+            latent_image_ids = latent_image_ids.reshape(
+                self.batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+            )
+            latent_image_ids = latent_image_ids.astype(DTYPE_MAPPER.np(float_dtype))
+            return latent_image_ids
+
+
+class DummyFluxTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = (
+        "decoder_input_ids",
+        "decoder_attention_mask",
+        "encoder_outputs",
+        "encoder_hidden_states",
+        "txt_ids",
+    )
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "txt_ids":
+            return self.constant_tensor([self.batch_size, self.sequence_length, 3], 0, DTYPE_MAPPER.pt(float_dtype))
+        return super().generate(input_name, framework, int_dtype, float_dtype)
+
+
+@register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers")
+class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestpsInputGenerator,
+        DummyFluxTransformerInputGenerator,
+        DummyFluxTextInputGenerator,
+        PooledProjectionsDummyInputGenerator,
+    )
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs.pop("sample", None)
+        common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
+        common_inputs["txt_ids"] = {0: "batch_size", 1: "sequence_length"}
+        common_inputs["img_ids"] = {0: "batch_size", 1: "packed_height_width"}
+        if getattr(self._normalized_config, "guidance_embeds", False):
+            common_inputs["guidance"] = {0: "batch_size"}
+        return common_inputs
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return FluxTransfromerModelPatcher(self, model, model_kwargs=model_kwargs)
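
The dummy generators above encode Flux's packed-latent geometry: a (num_channels, height, width) latent is packed into height/2 * width/2 tokens of num_channels * 4 features each, and img_ids carries one (unused, y, x) coordinate row per token. A self-contained shape check, with sizes picked for illustration only:

import torch

batch_size, num_channels, height, width = 2, 16, 64, 64  # latent-space sizes, assumed for the example

# Packed hidden states: one token per 2x2 latent patch, channels folded into features.
hidden_states = torch.randn(batch_size, (height // 2) * (width // 2), num_channels * 4)

# img_ids as built by prepare_image_ids: channel 1 holds the row index, channel 2 the column index.
img_ids = torch.zeros(height // 2, width // 2, 3)
img_ids[..., 1] += torch.arange(height // 2)[:, None]
img_ids[..., 2] += torch.arange(width // 2)[None, :]
img_ids = img_ids[None].repeat(batch_size, 1, 1, 1).reshape(batch_size, -1, 3)

assert hidden_states.shape[1] == img_ids.shape[1]  # one coordinate row per packed token
print(hidden_states.shape, img_ids.shape)  # torch.Size([2, 1024, 64]) torch.Size([2, 1024, 3])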

optimum/exporters/openvino/model_patcher.py (+42, -6)

@@ -411,9 +411,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )

         if (
             self.config._attn_implementation == "sdpa"
@@ -1979,9 +1979,9 @@ def _dbrx_update_causal_mask_legacy(
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )

         if (
             self.config._attn_implementation == "sdpa"
@@ -2705,3 +2705,39 @@ def __init__(
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+def _embednb_forward(self, ids: torch.Tensor) -> torch.Tensor:
+    def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+        assert dim % 2 == 0, "The dimension must be even."
+
+        scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
+        omega = 1.0 / (theta**scale)
+
+        batch_size, seq_length = pos.shape
+        out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
+        cos_out = torch.cos(out)
+        sin_out = torch.sin(out)
+
+        stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+        out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+        return out.float()
+
+    n_axes = ids.shape[-1]
+    emb = torch.cat(
+        [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+        dim=-3,
+    )
+    return emb.unsqueeze(1)
+
+
+class FluxTransfromerModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        self._model.pos_embed._orig_forward = self._model.pos_embed.forward
+        self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        self._model.pos_embed.forward = self._model.pos_embed._orig_forward
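
FluxTransfromerModelPatcher (the misspelling is in the commit itself) temporarily rebinds pos_embed.forward to _embednb_forward while the model is traced and restores the original on exit; the replacement computes rotary position embeddings with purely real-valued cos/sin tensors, presumably to avoid operations in the stock diffusers implementation that the OpenVINO tracer handles poorly. A small numeric check, not from the commit, that the stacked [cos, -sin, sin, cos] blocks really are 2x2 rotation matrices:

import torch

theta, dim = 10000, 4
pos = torch.tensor([[3.0]])  # batch_size=1, seq_length=1
scale = torch.arange(0, dim, 2, dtype=torch.float32) / dim
omega = 1.0 / (theta**scale)
angles = pos.unsqueeze(-1) * omega  # shape (1, 1, dim // 2)

# Same stacking as rope() inside _embednb_forward.
rot = torch.stack(
    [torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)], dim=-1
).view(1, 1, dim // 2, 2, 2)

# Applying the first frequency band to a unit vector rotates it by angles[0, 0, 0].
v = torch.tensor([1.0, 0.0])
rotated = rot[0, 0, 0] @ v
expected = torch.tensor([torch.cos(angles[0, 0, 0]), torch.sin(angles[0, 0, 0])])
assert torch.allclose(rotated, expected)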

optimum/intel/__init__.py (+2)

@@ -105,6 +105,7 @@
         "OVStableDiffusion3InpaintPipeline",
         "OVLatentConsistencyModelPipeline",
         "OVLatentConsistencyModelImg2ImgPipeline",
+        "OVFluxPipeline",
         "OVPipelineForImage2Image",
         "OVPipelineForText2Image",
         "OVPipelineForInpainting",
@@ -124,6 +125,7 @@
         "OVStableDiffusion3InpaintPipeline",
         "OVLatentConsistencyModelPipeline",
         "OVLatentConsistencyModelImg2ImgPipeline",
+        "OVFluxPipeline",
         "OVPipelineForImage2Image",
         "OVPipelineForText2Image",
         "OVPipelineForInpainting",

optimum/intel/openvino/__init__.py (+1)

@@ -82,6 +82,7 @@
 if is_diffusers_available():
     from .modeling_diffusion import (
         OVDiffusionPipeline,
+        OVFluxPipeline,
         OVLatentConsistencyModelImg2ImgPipeline,
         OVLatentConsistencyModelPipeline,
         OVPipelineForImage2Image,
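
With these re-exports in place, the new pipeline is usable like the other OV diffusion wrappers. A usage sketch, assuming OVFluxPipeline follows the same from_pretrained(..., export=True) convention as the existing OVDiffusionPipeline subclasses (checkpoint name and step count are illustrative):

from optimum.intel import OVFluxPipeline

# export=True converts the diffusers checkpoint to OpenVINO IR on the fly.
pipe = OVFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", export=True)
image = pipe("A cat holding a sign that says hello world", num_inference_steps=4).images[0]
image.save("flux_openvino.png")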
