Skip to content

Commit 6bdc704

Browse files
[Image Generation] Image2Image for FLUX (openvinotoolkit#1621)
![img2img_flux_lite](https://github.com/user-attachments/assets/00d860c9-5e1a-46c3-8403-12e47e20d6b3) ![img2img_flux_dev](https://github.com/user-attachments/assets/13b966f4-6753-45b9-9a3f-2d5f6928f895) ![img2img_flux_schnell](https://github.com/user-attachments/assets/b6b675f5-1e37-4390-a678-b69c77bedc61) --------- Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
1 parent cc3b65a commit 6bdc704

12 files changed

+190
-70
lines changed

SUPPORTED_MODELS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ The pipeline can work with other similar topologies produced by `optimum-intel`
242242
<tr>
243243
<td><code>Flux</code></td>
244244
<td>Supported</td>
245-
<td>Not supported</td>
245+
<td>Supported</td>
246246
<td>Not supported</td>
247247
<td>
248248
<ul>

src/cpp/include/openvino/genai/image_generation/image2image_pipeline.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ class OPENVINO_GENAI_EXPORTS Image2ImagePipeline {
4949
const UNet2DConditionModel& unet,
5050
const AutoencoderKL& vae);
5151

52+
// Creates an image-to-image Flux pipeline from already constructed building blocks:
// a scheduler, a CLIP text encoder, a T5 text encoder, a Flux transformer and a VAE.
// NOTE(review): `t5_encoder_model` is taken by value while every other model is taken by
// const reference - this looks unintentional; if it is changed, the out-of-line definition
// must be changed in the same commit so that the two signatures keep matching.
53+
static Image2ImagePipeline flux(
54+
const std::shared_ptr<Scheduler>& scheduler,
55+
const CLIPTextModel& clip_text_model,
56+
const T5EncoderModel t5_encoder_model,
57+
const FluxTransformer2DModel& transformer,
58+
const AutoencoderKL& vae);
59+
5260
ImageGenerationConfig get_generation_config() const;
5361
void set_generation_config(const ImageGenerationConfig& generation_config);
5462

src/cpp/src/image_generation/flux_pipeline.hpp

+85-65
Large diffs are not rendered by default.

src/cpp/src/image_generation/image2image_pipeline.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "image_generation/stable_diffusion_pipeline.hpp"
1111
#include "image_generation/stable_diffusion_xl_pipeline.hpp"
12+
#include "image_generation/flux_pipeline.hpp"
1213

1314
#include "utils.hpp"
1415

@@ -22,6 +23,8 @@ Image2ImagePipeline::Image2ImagePipeline(const std::filesystem::path& root_dir)
2223
m_impl = std::make_shared<StableDiffusionPipeline>(PipelineType::IMAGE_2_IMAGE, root_dir);
2324
} else if (class_name == "StableDiffusionXLPipeline") {
2425
m_impl = std::make_shared<StableDiffusionXLPipeline>(PipelineType::IMAGE_2_IMAGE, root_dir);
26+
} else if (class_name == "FluxPipeline") {
27+
m_impl = std::make_shared<FluxPipeline>(PipelineType::IMAGE_2_IMAGE, root_dir);
2528
} else {
2629
OPENVINO_THROW("Unsupported image to image generation pipeline '", class_name, "'");
2730
}
@@ -34,6 +37,8 @@ Image2ImagePipeline::Image2ImagePipeline(const std::filesystem::path& root_dir,
3437
m_impl = std::make_shared<StableDiffusionPipeline>(PipelineType::IMAGE_2_IMAGE, root_dir, device, properties);
3538
} else if (class_name == "StableDiffusionXLPipeline") {
3639
m_impl = std::make_shared<StableDiffusionXLPipeline>(PipelineType::IMAGE_2_IMAGE, root_dir, device, properties);
40+
} else if (class_name == "FluxPipeline") {
41+
m_impl = std::make_shared<FluxPipeline>(PipelineType::IMAGE_2_IMAGE, root_dir, device, properties);
3742
} else {
3843
OPENVINO_THROW("Unsupported image to image generation pipeline '", class_name, "'");
3944
}
@@ -44,6 +49,8 @@ Image2ImagePipeline::Image2ImagePipeline(const InpaintingPipeline& pipe) {
4449
m_impl = std::make_shared<StableDiffusionXLPipeline>(PipelineType::IMAGE_2_IMAGE, *stable_diffusion_xl);
4550
} else if (auto stable_diffusion = std::dynamic_pointer_cast<StableDiffusionPipeline>(pipe.m_impl); stable_diffusion != nullptr) {
4651
m_impl = std::make_shared<StableDiffusionPipeline>(PipelineType::IMAGE_2_IMAGE, *stable_diffusion);
52+
} else if (auto flux = std::dynamic_pointer_cast<FluxPipeline>(pipe.m_impl); flux != nullptr) {
53+
m_impl = std::make_shared<FluxPipeline>(PipelineType::IMAGE_2_IMAGE, *flux);
4754
} else {
4855
// The first argument of OPENVINO_ASSERT is the condition; a string literal is always truthy,
// so the previous check could never fire and m_impl was silently left unset.
// Throw unconditionally, like the other constructors do for unsupported pipelines.
OPENVINO_THROW("Cannot convert specified InpaintingPipeline to Image2ImagePipeline");
4956
}
@@ -94,6 +101,20 @@ Image2ImagePipeline Image2ImagePipeline::stable_diffusion_xl(
94101
return Image2ImagePipeline(impl);
95102
}
96103

104+
// Builds an image-to-image Flux pipeline from already constructed models.
//
// scheduler        - scheduler installed into the new pipeline; must not be null
// clip_text_model  - CLIP text encoder
// t5_encoder_model - T5 text encoder (by value, to match the declaration in the public header)
// transformer      - Flux transformer
// vae              - VAE
//
// Returns an Image2ImagePipeline wrapping the newly created FluxPipeline.
// Throws (via OPENVINO_ASSERT) when `scheduler` is null.
Image2ImagePipeline Image2ImagePipeline::flux(
    const std::shared_ptr<Scheduler>& scheduler,
    const CLIPTextModel& clip_text_model,
    const T5EncoderModel t5_encoder_model,
    const FluxTransformer2DModel& transformer,
    const AutoencoderKL& vae) {
    // A plain assert() is compiled out under NDEBUG, which would let a null scheduler
    // through in release builds; validate the argument unconditionally and fail fast,
    // before any pipeline object is constructed.
    OPENVINO_ASSERT(scheduler != nullptr, "Scheduler must not be null when creating Flux Image2ImagePipeline");

    auto impl = std::make_shared<FluxPipeline>(PipelineType::IMAGE_2_IMAGE, clip_text_model, t5_encoder_model, transformer, vae);
    impl->set_scheduler(scheduler);

    return Image2ImagePipeline(impl);
}
117+
97118
ImageGenerationConfig Image2ImagePipeline::get_generation_config() const {
98119
return m_impl->get_generation_config();
99120
}

src/cpp/src/image_generation/image_processor.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ IImageProcessor::IImageProcessor(const std::string& device) :
3232
}
3333

3434
ov::Tensor IImageProcessor::execute(ov::Tensor image) {
35+
OPENVINO_ASSERT(m_request, "ImageProcessor model must be compiled first. Cannot infer non-compiled model");
3536
m_request.set_input_tensor(image);
3637
m_request.infer();
3738
return m_request.get_output_tensor();
@@ -124,6 +125,7 @@ ImageResizer::ImageResizer(const std::string& device, ov::element::Type type, ov
124125
}
125126

126127
ov::Tensor ImageResizer::execute(ov::Tensor image, int64_t dst_height, int64_t dst_width) {
128+
OPENVINO_ASSERT(m_request, "ImageResizer model must be compiled first. Cannot infer non-compiled model");
127129
ov::Tensor target_spatial_tensor(ov::element::i64, ov::Shape{2});
128130
target_spatial_tensor.data<int64_t>()[0] = dst_height;
129131
target_spatial_tensor.data<int64_t>()[1] = dst_width;

src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ std::map<std::string, ov::Tensor> EulerAncestralDiscreteScheduler::step(ov::Tens
208208
return {{"latent", prev_sample}, {"denoised", pred_original_sample}};
209209
}
210210

211-
size_t EulerAncestralDiscreteScheduler::_index_for_timestep(int64_t timestep) const{
211+
size_t EulerAncestralDiscreteScheduler::_index_for_timestep(int64_t timestep) const {
212212
for (size_t i = 0; i < m_schedule_timesteps.size(); ++i) {
213213
if (timestep == m_schedule_timesteps[i]) {
214214
return i;

src/cpp/src/image_generation/schedulers/flow_match_euler_discrete.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,43 @@ void FlowMatchEulerDiscreteScheduler::add_noise(ov::Tensor init_latent, ov::Tens
146146
OPENVINO_THROW("Not implemented");
147147
}
148148

149+
size_t FlowMatchEulerDiscreteScheduler::_index_for_timestep(float timestep) {
150+
if (m_schedule_timesteps.empty()) {
151+
m_schedule_timesteps = m_timesteps;
152+
}
153+
154+
for (size_t i = 0; i < m_schedule_timesteps.size(); ++i) {
155+
if (timestep == m_schedule_timesteps[i]) {
156+
return i;
157+
}
158+
}
159+
160+
OPENVINO_THROW("Failed to find index for timestep ", timestep);
161+
}
162+
163+
// Mixes `noise` into `sample` in place:
//     sample[i] = sigma * noise[i] + (1 - sigma) * sample[i]
// where sigma is taken from m_sigmas at an index chosen as follows:
//   - begin index not set            -> index looked up from `timestep`
//   - begin index set, step index set -> current step index
//   - begin index set, step index not set -> begin index
//
// sample   - tensor updated in place, accessed as float
// timestep - timestep used for the lookup when no begin index is set
// noise    - noise tensor, accessed as float; must hold at least as many elements as `sample`
//
// Throws when the timestep assertion fails, when the selected sigma index is out of range,
// or when `noise` is smaller than `sample`.
void FlowMatchEulerDiscreteScheduler::scale_noise(ov::Tensor sample, float timestep, ov::Tensor noise) {
    // NOTE(review): this condition only passes for timestep == -1, while the message reads as
    // if -1 were the error case. Kept exactly as is to preserve current behavior - confirm the
    // intended contract against the callers before touching it.
    OPENVINO_ASSERT(timestep == -1, "Timestep is not computed yet");

    // m_begin_index / m_step_index are size_t; "-1" is the not-set sentinel, i.e. the max value.
    constexpr size_t index_not_set = static_cast<size_t>(-1);

    size_t index_for_timestep;
    if (m_begin_index == index_not_set) {
        index_for_timestep = _index_for_timestep(timestep);
    } else if (m_step_index != index_not_set) {
        index_for_timestep = m_step_index;
    } else {
        index_for_timestep = m_begin_index;
    }

    // Guard the lookup: an out-of-range index would otherwise read past the end of m_sigmas.
    OPENVINO_ASSERT(index_for_timestep < m_sigmas.size(),
                    "Sigma index ", index_for_timestep, " is out of range, number of sigmas is ", m_sigmas.size());
    const float sigma = m_sigmas[index_for_timestep];

    // The loop below reads one noise element per sample element.
    const size_t sample_size = sample.get_size();
    OPENVINO_ASSERT(noise.get_size() >= sample_size,
                    "Noise tensor has ", noise.get_size(), " elements, but at least ", sample_size, " are required");

    float * sample_data = sample.data<float>();
    const float * noise_data = noise.data<float>();

    for (size_t i = 0; i < sample_size; ++i) {
        sample_data[i] = sigma * noise_data[i] + (1.0f - sigma) * sample_data[i];
    }
}
185+
149186
void FlowMatchEulerDiscreteScheduler::set_timesteps_with_sigma(std::vector<float> sigma, float mu) {
150187
m_timesteps.clear();
151188
m_sigmas.clear();
@@ -184,5 +221,13 @@ float FlowMatchEulerDiscreteScheduler::calculate_shift(size_t image_seq_len) {
184221
return mu;
185222
}
186223

224+
void FlowMatchEulerDiscreteScheduler::set_begin_index(size_t begin_index) {
225+
m_begin_index = begin_index;
226+
}
227+
228+
size_t FlowMatchEulerDiscreteScheduler::get_begin_index() {
229+
return m_begin_index;
230+
}
231+
187232
} // namespace genai
188233
} // namespace ov

src/cpp/src/image_generation/schedulers/flow_match_euler_discrete.hpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,27 @@ class FlowMatchEulerDiscreteScheduler : public IScheduler {
4242

4343
void add_noise(ov::Tensor init_latent, ov::Tensor noise, int64_t latent_timestep) const override;
4444

45+
void scale_noise(ov::Tensor sample, float timestep, ov::Tensor noise) override;
46+
4547
float calculate_shift(size_t image_seq_len) override;
4648

49+
void set_begin_index(size_t begin_index) override;
50+
51+
size_t get_begin_index() override;
52+
4753
private:
4854
Config m_config;
4955

5056
std::vector<float> m_sigmas;
51-
std::vector<float> m_timesteps;
57+
std::vector<float> m_timesteps, m_schedule_timesteps;
5258

5359
float m_sigma_min, m_sigma_max;
5460
size_t m_step_index, m_begin_index;
5561
size_t m_num_inference_steps;
5662

5763
void init_step_index();
5864
double sigma_to_t(double simga);
65+
size_t _index_for_timestep(float timestep);
5966
};
6067

6168
} // namespace genai

src/cpp/src/image_generation/schedulers/ischeduler.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ class IScheduler : public Scheduler {
4343
virtual std::vector<float> get_float_timesteps() const {
4444
OPENVINO_THROW("Scheduler doesn't support float timesteps");
4545
}
46+
47+
virtual void scale_noise(ov::Tensor sample, float timestep, ov::Tensor noise) {
48+
OPENVINO_THROW("Scheduler doesn't support `scale_noise` method");
49+
}
50+
51+
// Default implementation ignores the request (no-op); schedulers that keep a begin index
// override it. The stray ';' after the body was removed (it triggers -Wextra-semi).
virtual void set_begin_index(size_t begin_index) {}
52+
53+
virtual size_t get_begin_index() {
54+
OPENVINO_THROW("Scheduler doesn't support `get_begin_index` method");
55+
}
4656
};
4757

4858
} // namespace genai

src/python/openvino_genai/py_openvino_genai.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,9 @@ class Image2ImagePipeline:
774774
This class is used for generation with image-to-image models.
775775
"""
776776
@staticmethod
777+
def flux(scheduler: Scheduler, clip_text_model: CLIPTextModel, t5_encoder_model: T5EncoderModel, transformer: FluxTransformer2DModel, vae: AutoencoderKL) -> Image2ImagePipeline:
778+
...
779+
@staticmethod
777780
def latent_consistency_model(scheduler: Scheduler, clip_text_model: CLIPTextModel, unet: UNet2DConditionModel, vae: AutoencoderKL) -> Image2ImagePipeline:
778781
...
779782
@staticmethod

src/python/py_image_generation_pipelines.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ void init_image_generation_pipelines(py::module_& m) {
330330
.def_static("stable_diffusion", &ov::genai::Image2ImagePipeline::stable_diffusion, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae"))
331331
.def_static("latent_consistency_model", &ov::genai::Image2ImagePipeline::latent_consistency_model, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae"))
332332
.def_static("stable_diffusion_xl", &ov::genai::Image2ImagePipeline::stable_diffusion_xl, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("clip_text_model_with_projection"), py::arg("unet"), py::arg("vae"))
333+
.def_static("flux", &ov::genai::Image2ImagePipeline::flux, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("t5_encoder_model"), py::arg("transformer"), py::arg("vae"))
333334
.def(
334335
"compile",
335336
[](ov::genai::Image2ImagePipeline& pipe,

tools/who_what_benchmark/tests/test_cli_image.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,11 @@ def test_image_model_types(model_id, model_type, backend):
103103
])),
104104
)
105105
def test_image_model_genai(model_id, model_type):
106-
if ("flux" in model_id or "stable-diffusion-3" in model_id) and model_type != "text-to-image":
107-
pytest.skip(reason="FLUX or SD3 are supported as text to image only")
106+
if ("stable-diffusion-3" in model_id) and model_type != "text-to-image":
107+
pytest.skip(reason="SD3 is supported as text to image only")
108+
109+
if ("flux" in model_id) and model_type == "image-inpainting":
110+
pytest.skip(reason="FLUX is not yet supported as image inpainting")
108111

109112
with tempfile.TemporaryDirectory() as temp_dir:
110113
GT_FILE = os.path.join(temp_dir, "gt.csv")

0 commit comments

Comments
 (0)