Commit fae9029
Text2ImagePipeline heterogeneous compile (openvinotoolkit#1768)

To simplify creation of a heterogeneous stable diffusion txt2image pipeline, this adds a new API to the `Text2ImagePipeline` class:

```cpp
/**
 * Compiles image generation pipeline for given devices for text encoding, denoising, and vae decoding.
 * @param text_encode_device A device to compile text encoder(s) with
 * @param denoise_device A device to compile denoiser (e.g. UNet, SD3 Transformer, etc.) with
 * @param vae_decode_device A device to compile VAE decoder(s) with
 * @param properties A map of properties which affect models compilation
 * @note If pipeline was compiled before, an exception is thrown.
 */
void compile(const std::string& text_encode_device,
             const std::string& denoise_device,
             const std::string& vae_decode_device,
             const ov::AnyMap& properties = {});
```

(Feedback welcome here, especially on whether we technically need three sets of properties, one per device.)

This API greatly simplifies heterogeneous pipeline setup to this:

```cpp
ov::genai::Text2ImagePipeline pipe(models_path);
pipe.reshape(1, width, height, pipe.get_generation_config().guidance_scale);
pipe.compile(text_encoder_device, unet_device, vae_decoder_device);
```

With these changes, the heterogeneous stable diffusion sample can support all variants of stable diffusion (SD1.5, LCM, XL, SD3, etc.) with the same code. With the old method (creating sub-components and assembling the pipeline object), this would have been difficult to achieve.

This PR is tested and working with the following pipelines (with NPU running denoise):

* SD1.5 / LCM
* SDXL

TODO:

* ~~Add Python bindings for the new API~~
* ~~Update Python heterogeneous sample~~

**FUTURE WORK** (outside the scope of this PR):

* Add support for SD3 (this will be a separate PR)
  * In general, this requires fixes for openvinotoolkit/openvino#29113
  * There is also some weirdness in the current reshape() path that I need to figure out.
  * For NPU, this requires a 'batch 1' implementation for Transformer2D, similar to what we did for UNet.
* Add support for FLUX (this will be a separate PR)
* Add an equivalent API for IMAGE2IMAGE / INPAINTING (separate PRs)

---------

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
1 parent 10bfc72 · commit fae9029

15 files changed: +255 −130 lines
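Since the commit message shows only the C++ flow, here is a minimal sketch of the same flow through the Python bindings added by this PR; the model path and device names are illustrative placeholders, not part of the change:

```python
import openvino_genai

# Minimal sketch of the new heterogeneous compile flow via the Python bindings.
# The model path and device names below are illustrative placeholders.
pipe = openvino_genai.Text2ImagePipeline("./stable-diffusion-v1-5")
pipe.reshape(1, 512, 512, pipe.get_generation_config().guidance_scale)
pipe.compile("CPU", "NPU", "GPU", config={"CACHE_DIR": "./cache"})
```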

samples/cpp/image_generation/heterogeneous_stable_diffusion.cpp (+17 −55)

```diff
@@ -18,7 +18,6 @@ int32_t main(int32_t argc, char* argv[]) try {
 
     const int width = 512;
     const int height = 512;
-    const float guidance_scale = 7.5f;
     const int number_of_images_to_generate = 1;
     const int number_of_inference_steps_per_image = 20;
 
@@ -37,73 +36,36 @@ int32_t main(int32_t argc, char* argv[]) try {
     std::string ov_cache_dir = "./cache";
 
     //
-    // Step 1: Prepare each Text2Image subcomponent (scheduler, text encoder, unet, vae) separately.
+    // Step 1: Create the initial Text2ImagePipeline, given the model path
     //
+    ov::genai::Text2ImagePipeline pipe(models_path);
 
-    // Create the scheduler from the details listed in the json.
-    auto scheduler = ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json");
-
-    // Note that we could have created the scheduler by specifying specific type (for example EULER_DISCRETE), like
-    // this: auto scheduler = ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json",
-    //                                                          ov::genai::Scheduler::Type::EULER_DISCRETE);
-    // This can be useful when a particular type of Scheduler is not yet supported natively by OpenVINO GenAI.
-    // (even though we are actively working to support most commonly used ones)
-
-    // Create unet object
-    auto unet = ov::genai::UNet2DConditionModel(root_dir / "unet");
-
-    // Set batch size based on classifier free guidance condition.
-    int unet_batch_size = unet.do_classifier_free_guidance(guidance_scale) ? 2 : 1;
-
-    // Create the text encoder.
-    auto text_encoder = ov::genai::CLIPTextModel(root_dir / "text_encoder");
-
-    // In case of NPU, we need to reshape the model to have static shapes
-    if (text_encoder_device == "NPU") {
-        text_encoder.reshape(unet_batch_size);
-    }
-
-    // Compile text encoder for the specified device
-    text_encoder.compile(text_encoder_device, ov::cache_dir(ov_cache_dir));
-
-    // In case of NPU, we need to reshape the model to have static shapes
-    if (unet_device == "NPU") {
-        // The max_postiion_embeddings config from text encoder will be used as a parameter to unet reshape.
-        int max_position_embeddings = text_encoder.get_config().max_position_embeddings;
-
-        unet.reshape(unet_batch_size, height, width, max_position_embeddings);
-    }
-
-    // Compile unet for specified device
-    unet.compile(unet_device, ov::cache_dir(ov_cache_dir));
+    //
+    // Step 2: Reshape the pipeline given number of images, width, height, and guidance scale.
+    //
+    pipe.reshape(1, width, height, pipe.get_generation_config().guidance_scale);
 
-    // Create the vae decoder.
-    auto vae = ov::genai::AutoencoderKL(root_dir / "vae_decoder");
+    //
+    // Step 3: Compile the pipeline with the specified devices, and properties (like cache dir)
+    //
+    ov::AnyMap properties = {ov::cache_dir(ov_cache_dir)};
 
-    // In case of NPU, we need to reshape the model to have static shapes
-    if (vae_decoder_device == "NPU") {
-        // We set batch-size to '1' here, as we're configuring our pipeline to return 1 image per 'generate' call.
-        vae.reshape(1, height, width);
-    }
+    // Note that if there are device-specific properties that are needed, they can
+    // be added using ov::device::properties groups, like this:
+    //ov::AnyMap properties = {ov::device::properties("CPU", ov::cache_dir("cpu_cache")),
+    //                         ov::device::properties("GPU", ov::cache_dir("gpu_cache")),
+    //                         ov::device::properties("NPU", ov::cache_dir("npu_cache"))};
 
-    // Compile vae decoder for the specified device
-    vae.compile(vae_decoder_device, ov::cache_dir(ov_cache_dir));
+    pipe.compile(text_encoder_device, unet_device, vae_decoder_device, properties);
 
-    //
-    // Step 2: Create a Text2ImagePipeline from the individual subcomponents
-    //
-    auto pipe = ov::genai::Text2ImagePipeline::stable_diffusion(scheduler, text_encoder, unet, vae);
 
     //
-    // Step 3: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
+    // Step 4: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
     //
     for (int imagei = 0; imagei < number_of_images_to_generate; imagei++) {
         std::cout << "Generating image " << imagei << std::endl;
 
         ov::Tensor image = pipe.generate(prompt,
-                                         ov::genai::width(width),
-                                         ov::genai::height(height),
-                                         ov::genai::guidance_scale(guidance_scale),
                                          ov::genai::num_inference_steps(number_of_inference_steps_per_image),
                                          ov::genai::callback(progress_bar));
 
```
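For context, the per-device properties variant that the updated sample leaves commented out would look like this when enabled; a minimal sketch, assuming placeholder device names and cache paths:

```cpp
// Minimal sketch of per-device property groups with the new compile() API.
// Device names and cache directories are illustrative placeholders.
ov::AnyMap properties = {ov::device::properties("CPU", ov::cache_dir("cpu_cache")),
                         ov::device::properties("GPU", ov::cache_dir("gpu_cache")),
                         ov::device::properties("NPU", ov::cache_dir("npu_cache"))};

// Each stage picks up the property group matching the device it compiles on.
pipe.compile("CPU", "NPU", "GPU", properties);
```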

samples/python/image_generation/heterogeneous_stable_diffusion.py (+21 −53)

```diff
@@ -23,7 +23,6 @@ def main():
 
     width = 512
     height = 512
-    guidance_scale = 7.5
     number_of_images_to_generate = 1
     number_of_inference_steps_per_image = 20
 
@@ -36,72 +35,41 @@ def main():
     ov_cache_dir = "./cache"
 
     #
-    # Step 1: Prepare each Text2Image subcomponent (scheduler, text encoder, unet, vae) separately.
+    # Step 1: Create the initial Text2ImagePipeline, given the model path
     #
+    pipe = openvino_genai.Text2ImagePipeline(args.model_dir)
 
-    # Create the scheduler from the details listed in the json.
-    scheduler = openvino_genai.Scheduler.from_config(args.model_dir + "/scheduler/scheduler_config.json")
-
-    # Note that we can also create the scheduler by specifying specific type (for example EULER_DISCRETE), like this:
-    # scheduler = openvino_genai.Scheduler.from_config(args.model_dir + "/scheduler/scheduler_config.json",
-    #                                                  openvino_genai.Scheduler.Type.EULER_DISCRETE)
-    # This can be useful when a particular type of Scheduler is not yet supported natively by OpenVINO GenAI.
-    # (even though we are actively working to support most commonly used ones)
-
-    # Create unet object
-    unet = openvino_genai.UNet2DConditionModel(args.model_dir + "/unet")
-
-    # Set batch size based on classifier free guidance condition.
-    unet_batch_size = 2 if unet.do_classifier_free_guidance(guidance_scale) else 1
-
-    # Create the text encoder
-    text_encoder = openvino_genai.CLIPTextModel(args.model_dir + "/text_encoder")
-
-    # In case of NPU, we need to reshape the model to have static shapes
-    if args.text_encoder_device == "NPU":
-        text_encoder.reshape(unet_batch_size)
-
-    # Compile text encoder for the specified device
-    text_encoder.compile(args.text_encoder_device, CACHE_DIR=ov_cache_dir)
-
-    # In case of NPU, we need to reshape the unet model to have static shapes
-    if args.unet_device == "NPU":
-        # The max_postion_embeddings config from text encoder will be used as a parameter to unet reshape.
-        max_position_embeddings = text_encoder.get_config().max_position_embeddings
-
-        unet.reshape(unet_batch_size, height, width, max_position_embeddings)
-
-    # Compile unet for specified device
-    unet.compile(args.unet_device, CACHE_DIR=ov_cache_dir)
-
-    # Create the decoder
-    vae = openvino_genai.AutoencoderKL(args.model_dir + "/vae_decoder")
-
-    # In case of NPU, we need to reshape the vae model to have static shapes
-    if args.vae_decoder_device == "NPU":
-        vae.reshape(1, height, width)
-
-    # Compile vae decoder for the specified device
-    vae.compile(args.vae_decoder_device, CACHE_DIR=ov_cache_dir)
+    #
+    # Step 2: Reshape the pipeline given number of images, width, height, and guidance scale.
+    #
+    pipe.reshape(1, width, height, pipe.get_generation_config().guidance_scale)
 
     #
-    # Step 2: Create a Text2ImagePipeline from the individual subcomponents
+    # Step 3: Compile the pipeline given the specified devices, and properties (like cache dir)
     #
+    properties = {"CACHE_DIR": "cache"}
+
+    # Note that if there are device-specific properties that are needed, they can
+    # be added using a "DEVICE_PROPERTIES" entry, like this:
+    #properties = {
+    #    "DEVICE_PROPERTIES":
+    #    {
+    #        "CPU": {"CACHE_DIR": "cpu_cache"},
+    #        "GPU": {"CACHE_DIR": "gpu_cache"},
+    #        "NPU": {"CACHE_DIR": "npu_cache"}
+    #    }
+    #}
 
-    pipe = openvino_genai.Text2ImagePipeline.stable_diffusion(scheduler, text_encoder, unet, vae)
+    pipe.compile(args.text_encoder_device, args.unet_device, args.vae_decoder_device, config=properties)
 
     #
-    # Step 3: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
+    # Step 4: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
     #
 
     for imagei in range(0, number_of_images_to_generate):
         image_tensor = pipe.generate(
             args.prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
             num_inference_steps=number_of_inference_steps_per_image,
-            num_images_per_prompt=1
         )
 
         image = Image.fromarray(image_tensor.data[0])
```
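Similarly, the Python sample's commented-out per-device alternative, written out in active form (cache paths are placeholders):

```python
# Per-device property groups via the "DEVICE_PROPERTIES" entry described in
# the sample's comments. Cache directories here are illustrative placeholders.
properties = {
    "DEVICE_PROPERTIES": {
        "CPU": {"CACHE_DIR": "cpu_cache"},
        "GPU": {"CACHE_DIR": "gpu_cache"},
        "NPU": {"CACHE_DIR": "npu_cache"},
    }
}

pipe.compile(args.text_encoder_device, args.unet_device, args.vae_decoder_device, config=properties)
```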

src/cpp/include/openvino/genai/image_generation/clip_text_model.hpp (+2)

```diff
@@ -95,6 +95,8 @@ class OPENVINO_GENAI_EXPORTS CLIPTextModel {
     std::shared_ptr<ov::Model> m_model;
 
     Tokenizer m_clip_tokenizer;
+
+    bool m_slice_batch1_output = false;
 };
 
 }  // namespace genai
```

src/cpp/include/openvino/genai/image_generation/clip_text_model_with_projection.hpp (+2)

```diff
@@ -95,6 +95,8 @@ class OPENVINO_GENAI_EXPORTS CLIPTextModelWithProjection {
     std::shared_ptr<ov::Model> m_model;
 
     Tokenizer m_clip_tokenizer;
+
+    bool m_slice_batch1_output = false;
 };
 
 }  // namespace genai
```

src/cpp/include/openvino/genai/image_generation/text2image_pipeline.hpp (+23)

```diff
@@ -192,13 +192,36 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {
      */
     void compile(const std::string& device, const ov::AnyMap& properties = {});
 
+    /**
+     * Compiles image generation pipeline for given devices for text encoding, denoising, and vae decoding.
+     * @param text_encode_device A device to compile text encoder(s) with
+     * @param denoise_device A device to compile denoiser (e.g. UNet, SD3 Transformer, etc.) with
+     * @param vae_device A device to compile VAE decoder(s) with
+     * @param properties A map of properties which affect models compilation
+     * @note If pipeline was compiled before, an exception is thrown.
+     */
+    void compile(const std::string& text_encode_device,
+                 const std::string& denoise_device,
+                 const std::string& vae_device,
+                 const ov::AnyMap& properties = {});
+
     template <typename... Properties>
     ov::util::EnableIfAllStringAny<void, Properties...> compile(
             const std::string& device,
             Properties&&... properties) {
         return compile(device, ov::AnyMap{std::forward<Properties>(properties)...});
     }
 
+    template <typename... Properties>
+    ov::util::EnableIfAllStringAny<void, Properties...> compile(const std::string& text_encode_device,
+                                                                const std::string& denoise_device,
+                                                                const std::string& vae_device,
+                                                                Properties&&... properties) {
+        return compile(text_encode_device,
+                       denoise_device,
+                       vae_device, ov::AnyMap{std::forward<Properties>(properties)...});
+    }
+
     /**
      * Generates image(s) based on prompt and other image generation parameters
      * @param positive_prompt Prompt to generate image(s) from
```
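The new variadic template overload mirrors the existing single-device one, so properties can be passed inline without building an ov::AnyMap first; a minimal usage sketch, with placeholder devices and cache path:

```cpp
// Inline properties via the new variadic overload (devices are placeholders).
ov::genai::Text2ImagePipeline pipe(models_path);
pipe.reshape(1, 512, 512, pipe.get_generation_config().guidance_scale);
pipe.compile("CPU", "NPU", "GPU", ov::cache_dir("./cache"));
```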

src/cpp/src/image_generation/diffusion_pipeline.hpp (+9 −1)

```diff
@@ -100,7 +100,15 @@ class DiffusionPipeline {
 
     virtual void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) = 0;
 
-    virtual void compile(const std::string& device, const ov::AnyMap& properties) = 0;
+    virtual void compile(const std::string& device, const ov::AnyMap& properties)
+    {
+        compile(device, device, device, properties);
+    }
+
+    virtual void compile(const std::string& text_encode_device,
+                         const std::string& denoise_device,
+                         const std::string& vae_device,
+                         const ov::AnyMap& properties) = 0;
 
     virtual std::tuple<ov::Tensor, ov::Tensor, ov::Tensor, ov::Tensor> prepare_latents(ov::Tensor initial_image, const ImageGenerationConfig& generation_config) = 0;
```
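With this default, the existing single-device compile() funnels into the new three-device pure virtual, so each concrete pipeline only has to implement the latter. A hypothetical sketch of such an override follows; the member names (m_clip_text_encoder, m_unet, m_vae) are assumptions for illustration, not the exact fields used by the pipelines changed in this PR:

```cpp
// Hypothetical derived-pipeline override (member names are assumed here):
// each stage is compiled on its own device with the shared property map.
void compile(const std::string& text_encode_device,
             const std::string& denoise_device,
             const std::string& vae_device,
             const ov::AnyMap& properties) override {
    m_clip_text_encoder->compile(text_encode_device, properties);  // text encoding stage
    m_unet->compile(denoise_device, properties);                   // denoising stage
    m_vae->compile(vae_device, properties);                        // VAE decoding stage
}
```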

src/cpp/src/image_generation/flux_pipeline.hpp (+7)

```diff
@@ -256,6 +256,13 @@ class FluxPipeline : public DiffusionPipeline {
         m_transformer->compile(device, *updated_properties);
     }
 
+    void compile(const std::string& text_encode_device,
+                 const std::string& denoise_device,
+                 const std::string& vae_device,
+                 const ov::AnyMap& properties) override {
+        OPENVINO_THROW("not supported yet.");
+    }
+
     void compute_hidden_states(const std::string& positive_prompt, const ImageGenerationConfig& generation_config) override {
         // encode_prompt
         std::string prompt_2_str = generation_config.prompt_2 != std::nullopt ? *generation_config.prompt_2 : positive_prompt;
```

src/cpp/src/image_generation/models/clip_text_model.cpp (+32 −4)

```diff
@@ -124,12 +124,26 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string
         }
     };
 
+    ov::PartialShape compiled_input_partial_shape = m_request.get_compiled_model().inputs()[0].get_partial_shape();
+
     ov::Tensor input_ids = m_request.get_input_tensor();
-    input_ids.set_shape({text_embedding_batch_size, m_config.max_position_embeddings});
+
+    if (compiled_input_partial_shape.is_dynamic()) {
+        input_ids.set_shape({text_embedding_batch_size, m_config.max_position_embeddings});
+    } else {
+        auto compiled_input_shape = input_ids.get_shape();
+        OPENVINO_ASSERT(compiled_input_shape.size() == 2, "CLIP text encoder model input must have rank of 2");
+        OPENVINO_ASSERT(text_embedding_batch_size <= compiled_input_shape[0],
+                        "text_embedding_batch_size (", text_embedding_batch_size,
+                        ") > CLIP text encoder model batch size (", compiled_input_shape[0], ").");
+        OPENVINO_ASSERT(m_config.max_position_embeddings == compiled_input_shape[1],
+                        "max_position_embeddings (", m_config.max_position_embeddings,
+                        ") != what CLIP text encoder model was compiled for (", compiled_input_shape[1], ").");
+    }
 
     size_t current_batch_idx = 0;
 
-    if (do_classifier_free_guidance) {
+    if (input_ids.get_shape()[0] == 2) {
         perform_tokenization(neg_prompt,
                              ov::Tensor(input_ids, {current_batch_idx, 0},
                                         {current_batch_idx + 1, m_config.max_position_embeddings}));
@@ -145,11 +159,25 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string
     // text embeddings
     m_request.infer();
 
-    return m_request.get_output_tensor(0);
+    // This is true when text_embedding_batch_size is 1, but model was reshaped / compiled as batch size 2.
+    m_slice_batch1_output = (text_embedding_batch_size != input_ids.get_shape()[0]);
+
+    return get_output_tensor(0);
 }
 
 ov::Tensor CLIPTextModel::get_output_tensor(const size_t idx) {
-    return m_request.get_output_tensor(idx);
+    auto infer_out_tensor = m_request.get_output_tensor(idx);
+    if (m_slice_batch1_output) {
+        // Slice and return batch index 1 output.
+        auto out_shape = infer_out_tensor.get_shape();
+        auto begin_coord = ov::Coordinate(out_shape.size(), 0);
+        begin_coord[0] = 1;
+        auto end_coord = ov::Coordinate(out_shape);
+        auto sliced_out_tensor = ov::Tensor(infer_out_tensor, begin_coord, end_coord);
+        return sliced_out_tensor;
+    } else {
+        return infer_out_tensor;
+    }
 }
 
 }  // namespace genai
```
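The batch-1 slice in get_output_tensor() relies on OpenVINO's zero-copy ROI tensor constructor. A standalone sketch of that mechanism, with illustrative shapes:

```cpp
#include <openvino/runtime/tensor.hpp>

// Given a batch-2 embedding tensor (e.g. shape {2, 77, 768}), return a view
// of batch index 1 only (shape {1, 77, 768}) without copying data. This
// mirrors the slicing done above when m_slice_batch1_output is set.
ov::Tensor slice_batch1(const ov::Tensor& out) {
    ov::Shape shape = out.get_shape();
    ov::Coordinate begin(shape.size(), 0);
    begin[0] = 1;                        // start at batch index 1
    ov::Coordinate end(shape);           // end at the full extents
    return ov::Tensor(out, begin, end);  // ROI view over the same memory
}
```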
