Commit 33fc7b9

Merge branch 'master' into at/uniform-llm
2 parents: bb56540 + 91dc0ce

29 files changed: +491 −529 lines

samples/cpp/image_generation/heterogeneous_stable_diffusion.cpp

+17 −55
@@ -18,7 +18,6 @@ int32_t main(int32_t argc, char* argv[]) try {
 
     const int width = 512;
     const int height = 512;
-    const float guidance_scale = 7.5f;
     const int number_of_images_to_generate = 1;
     const int number_of_inference_steps_per_image = 20;
 
@@ -37,73 +36,36 @@ int32_t main(int32_t argc, char* argv[]) try {
     std::string ov_cache_dir = "./cache";
 
     //
-    // Step 1: Prepare each Text2Image subcomponent (scheduler, text encoder, unet, vae) separately.
+    // Step 1: Create the initial Text2ImagePipeline, given the model path
    //
+    ov::genai::Text2ImagePipeline pipe(models_path);
 
-    // Create the scheduler from the details listed in the json.
-    auto scheduler = ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json");
-
-    // Note that we could have created the scheduler by specifying specific type (for example EULER_DISCRETE), like
-    // this: auto scheduler = ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json",
-    //                                                          ov::genai::Scheduler::Type::EULER_DISCRETE);
-    // This can be useful when a particular type of Scheduler is not yet supported natively by OpenVINO GenAI.
-    // (even though we are actively working to support most commonly used ones)
-
-    // Create unet object
-    auto unet = ov::genai::UNet2DConditionModel(root_dir / "unet");
-
-    // Set batch size based on classifier free guidance condition.
-    int unet_batch_size = unet.do_classifier_free_guidance(guidance_scale) ? 2 : 1;
-
-    // Create the text encoder.
-    auto text_encoder = ov::genai::CLIPTextModel(root_dir / "text_encoder");
-
-    // In case of NPU, we need to reshape the model to have static shapes
-    if (text_encoder_device == "NPU") {
-        text_encoder.reshape(unet_batch_size);
-    }
-
-    // Compile text encoder for the specified device
-    text_encoder.compile(text_encoder_device, ov::cache_dir(ov_cache_dir));
-
-    // In case of NPU, we need to reshape the model to have static shapes
-    if (unet_device == "NPU") {
-        // The max_postiion_embeddings config from text encoder will be used as a parameter to unet reshape.
-        int max_position_embeddings = text_encoder.get_config().max_position_embeddings;
-
-        unet.reshape(unet_batch_size, height, width, max_position_embeddings);
-    }
-
-    // Compile unet for specified device
-    unet.compile(unet_device, ov::cache_dir(ov_cache_dir));
+    //
+    // Step 2: Reshape the pipeline given number of images, width, height, and guidance scale.
+    //
+    pipe.reshape(1, width, height, pipe.get_generation_config().guidance_scale);
 
-    // Create the vae decoder.
-    auto vae = ov::genai::AutoencoderKL(root_dir / "vae_decoder");
+    //
+    // Step 3: Compile the pipeline with the specified devices, and properties (like cache dir)
+    //
+    ov::AnyMap properties = {ov::cache_dir(ov_cache_dir)};
 
-    // In case of NPU, we need to reshape the model to have static shapes
-    if (vae_decoder_device == "NPU") {
-        // We set batch-size to '1' here, as we're configuring our pipeline to return 1 image per 'generate' call.
-        vae.reshape(1, height, width);
-    }
+    // Note that if there are device-specific properties that are needed, they can
+    // be added using ov::device::properties groups, like this:
+    //ov::AnyMap properties = {ov::device::properties("CPU", ov::cache_dir("cpu_cache")),
+    //                         ov::device::properties("GPU", ov::cache_dir("gpu_cache")),
+    //                         ov::device::properties("NPU", ov::cache_dir("npu_cache"))};
 
-    // Compile vae decoder for the specified device
-    vae.compile(vae_decoder_device, ov::cache_dir(ov_cache_dir));
+    pipe.compile(text_encoder_device, unet_device, vae_decoder_device, properties);
 
-    //
-    // Step 2: Create a Text2ImagePipeline from the individual subcomponents
-    //
-    auto pipe = ov::genai::Text2ImagePipeline::stable_diffusion(scheduler, text_encoder, unet, vae);
 
     //
-    // Step 3: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
+    // Step 4: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
     //
     for (int imagei = 0; imagei < number_of_images_to_generate; imagei++) {
         std::cout << "Generating image " << imagei << std::endl;
 
         ov::Tensor image = pipe.generate(prompt,
-                                         ov::genai::width(width),
-                                         ov::genai::height(height),
-                                         ov::genai::guidance_scale(guidance_scale),
                                          ov::genai::num_inference_steps(number_of_inference_steps_per_image),
                                          ov::genai::callback(progress_bar));

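Taken together, the changes above replace the per-subcomponent setup in the C++ sample with the unified pipeline API. A minimal standalone sketch of the resulting flow follows; it is not part of the commit, and the model directory, device names, and prompt are placeholders:

#include "openvino/genai/image_generation/text2image_pipeline.hpp"

int main() {
    // Placeholder model directory exported for Stable Diffusion.
    ov::genai::Text2ImagePipeline pipe("./stable_diffusion_ov");

    // One image per generate() call at 512x512, reusing the guidance scale
    // already stored in the pipeline's generation config.
    pipe.reshape(1, 512, 512, pipe.get_generation_config().guidance_scale);

    // Text encoder, denoiser (UNet) and VAE decoder may each go to a different device.
    pipe.compile("CPU", "GPU", "NPU", ov::AnyMap{ov::cache_dir("./cache")});

    ov::Tensor image = pipe.generate("a photo of a red sports car",
                                     ov::genai::num_inference_steps(20));
    return 0;
}
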
samples/python/image_generation/heterogeneous_stable_diffusion.py

+21 −53
@@ -23,7 +23,6 @@ def main():
 
     width = 512
     height = 512
-    guidance_scale = 7.5
     number_of_images_to_generate = 1
     number_of_inference_steps_per_image = 20
 
@@ -36,72 +35,41 @@ def main():
     ov_cache_dir = "./cache"
 
     #
-    # Step 1: Prepare each Text2Image subcomponent (scheduler, text encoder, unet, vae) separately.
+    # Step 1: Create the initial Text2ImagePipeline, given the model path
     #
+    pipe = openvino_genai.Text2ImagePipeline(args.model_dir)
 
-    # Create the scheduler from the details listed in the json.
-    scheduler = openvino_genai.Scheduler.from_config(args.model_dir + "/scheduler/scheduler_config.json")
-
-    # Note that we can also create the scheduler by specifying specific type (for example EULER_DISCRETE), like this:
-    # scheduler = openvino_genai.Scheduler.from_config(args.model_dir + "/scheduler/scheduler_config.json",
-    #                                                  openvino_genai.Scheduler.Type.EULER_DISCRETE)
-    # This can be useful when a particular type of Scheduler is not yet supported natively by OpenVINO GenAI.
-    # (even though we are actively working to support most commonly used ones)
-
-    # Create unet object
-    unet = openvino_genai.UNet2DConditionModel(args.model_dir + "/unet")
-
-    # Set batch size based on classifier free guidance condition.
-    unet_batch_size = 2 if unet.do_classifier_free_guidance(guidance_scale) else 1
-
-    # Create the text encoder
-    text_encoder = openvino_genai.CLIPTextModel(args.model_dir + "/text_encoder")
-
-    # In case of NPU, we need to reshape the model to have static shapes
-    if args.text_encoder_device == "NPU":
-        text_encoder.reshape(unet_batch_size)
-
-    # Compile text encoder for the specified device
-    text_encoder.compile(args.text_encoder_device, CACHE_DIR=ov_cache_dir)
-
-    # In case of NPU, we need to reshape the unet model to have static shapes
-    if args.unet_device == "NPU":
-        # The max_postion_embeddings config from text encoder will be used as a parameter to unet reshape.
-        max_position_embeddings = text_encoder.get_config().max_position_embeddings
-
-        unet.reshape(unet_batch_size, height, width, max_position_embeddings)
-
-    # Compile unet for specified device
-    unet.compile(args.unet_device, CACHE_DIR=ov_cache_dir)
-
-    # Create the decoder
-    vae = openvino_genai.AutoencoderKL(args.model_dir + "/vae_decoder")
-
-    # In case of NPU, we need to reshape the vae model to have static shapes
-    if args.vae_decoder_device == "NPU":
-        vae.reshape(1, height, width)
-
-    # Compile vae decoder for the specified device
-    vae.compile(args.vae_decoder_device, CACHE_DIR=ov_cache_dir)
+    #
+    # Step 2: Reshape the pipeline given number of images, width, height, and guidance scale.
+    #
+    pipe.reshape(1, width, height, pipe.get_generation_config().guidance_scale)
 
     #
-    # Step 2: Create a Text2ImagePipeline from the individual subcomponents
+    # Step 3: Compile the pipeline given the specified devices, and properties (like cache dir)
     #
+    properties = {"CACHE_DIR": "cache"}
+
+    # Note that if there are device-specific properties that are needed, they can
+    # be added using a "DEVICE_PROPERTIES" entry, like this:
+    #properties = {
+    #    "DEVICE_PROPERTIES":
+    #    {
+    #        "CPU": {"CACHE_DIR": "cpu_cache"},
+    #        "GPU": {"CACHE_DIR": "gpu_cache"},
+    #        "NPU": {"CACHE_DIR": "npu_cache"}
+    #    }
+    #}
 
-    pipe = openvino_genai.Text2ImagePipeline.stable_diffusion(scheduler, text_encoder, unet, vae)
+    pipe.compile(args.text_encoder_device, args.unet_device, args.vae_decoder_device, config=properties)
 
     #
-    # Step 3: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
+    # Step 4: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
     #
 
     for imagei in range(0, number_of_images_to_generate):
         image_tensor = pipe.generate(
             args.prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
             num_inference_steps=number_of_inference_steps_per_image,
-            num_images_per_prompt=1
         )
 
         image = Image.fromarray(image_tensor.data[0])

src/cpp/include/openvino/genai/image_generation/clip_text_model.hpp

+2
@@ -95,6 +95,8 @@ class OPENVINO_GENAI_EXPORTS CLIPTextModel {
     std::shared_ptr<ov::Model> m_model;
 
     Tokenizer m_clip_tokenizer;
+
+    bool m_slice_batch1_output = false;
 };
 
 } // namespace genai

src/cpp/include/openvino/genai/image_generation/clip_text_model_with_projection.hpp

+2
@@ -95,6 +95,8 @@ class OPENVINO_GENAI_EXPORTS CLIPTextModelWithProjection {
     std::shared_ptr<ov::Model> m_model;
 
     Tokenizer m_clip_tokenizer;
+
+    bool m_slice_batch1_output = false;
 };
 
 } // namespace genai

src/cpp/include/openvino/genai/image_generation/text2image_pipeline.hpp

+23
@@ -192,13 +192,36 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {
      */
     void compile(const std::string& device, const ov::AnyMap& properties = {});
 
+    /**
+     * Compiles image generation pipeline for given devices for text encoding, denoising, and vae decoding.
+     * @param text_encode_device A device to compile text encoder(s) with
+     * @param denoise_device A device to compile denoiser (e.g. UNet, SD3 Transformer, etc.) with
+     * @param vae_device A device to compile VAE decoder(s) with
+     * @param properties A map of properties which affect models compilation
+     * @note If pipeline was compiled before, an exception is thrown.
+     */
+    void compile(const std::string& text_encode_device,
+                 const std::string& denoise_device,
+                 const std::string& vae_device,
+                 const ov::AnyMap& properties = {});
+
     template <typename... Properties>
     ov::util::EnableIfAllStringAny<void, Properties...> compile(
             const std::string& device,
             Properties&&... properties) {
         return compile(device, ov::AnyMap{std::forward<Properties>(properties)...});
     }
 
+    template <typename... Properties>
+    ov::util::EnableIfAllStringAny<void, Properties...> compile(const std::string& text_encode_device,
+                                                                const std::string& denoise_device,
+                                                                const std::string& vae_device,
+                                                                Properties&&... properties) {
+        return compile(text_encode_device,
+                       denoise_device,
+                       vae_device, ov::AnyMap{std::forward<Properties>(properties)...});
+    }
+
     /**
      * Generates image(s) based on prompt and other image generation parameters
      * @param positive_prompt Prompt to generate image(s) from

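The header above adds both an ov::AnyMap overload and a variadic-properties overload for three-device compilation. A short hedged sketch of the two call forms follows; the model directory and device names are placeholders, and per the @note a pipeline may only be compiled once, so a real program would pick one form:

#include "openvino/genai/image_generation/text2image_pipeline.hpp"

int main() {
    ov::genai::Text2ImagePipeline pipe("./stable_diffusion_ov");  // placeholder model dir

    // Form 1: explicit ov::AnyMap shared by all three compiles.
    pipe.compile("CPU", "GPU", "NPU", ov::AnyMap{ov::cache_dir("./cache")});

    // Form 2: variadic properties, forwarded into an ov::AnyMap by the template overload.
    // (Commented out because the pipeline above is already compiled.)
    // pipe.compile("CPU", "GPU", "NPU", ov::cache_dir("./cache"));
    return 0;
}
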
src/cpp/src/icontinuous_batching.cpp

+5 −12
@@ -147,17 +147,6 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
     OPENVINO_ASSERT(prompts.size() == sampling_params.size(), "Number of prompts should be equal to the number of generation configs.");
     OPENVINO_ASSERT(prompts.size() == rgbs_vector.size(), "Number of prompts should be equal to the number of images vectors.");
 
-    for (auto config: sampling_params) {
-        // If eos_token_id was not provided, take value from default m_generation_config
-        if (config.eos_token_id == -1) {
-            config.set_eos_token_id(m_generation_config.eos_token_id);
-        }
-        if (config.stop_token_ids.empty()) {
-            config.stop_token_ids = m_generation_config.stop_token_ids;
-        }
-        config.validate();
-    }
-
     std::vector<ov::Tensor> input_embeds_list;
     for (size_t i = 0; i < prompts.size(); i++) {
         auto prompt = prompts[i];
@@ -187,7 +176,11 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::add_request(uint64_t re
                                                                      GenerationConfig sampling_params) {
     OPENVINO_ASSERT(m_model_input_type == ModelInputType::EMBEDDINGS, "Model doesn't support embeddings.");
     ov::genai::VLMPerfMetrics metrics;
-    auto inputs = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, metrics);
+    ov::Tensor inputs;
+    {
+        const std::lock_guard<std::mutex> lock(m_inputs_embedder_mutex);
+        inputs = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, metrics);
+    }
     return add_request(request_id, inputs, sampling_params);
 }

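The add_request change above serializes calls into the shared InputsEmbedder while keeping the critical section as small as possible. A simplified, hypothetical sketch of the same scoped-locking pattern, with stand-in types and names rather than the real classes:

#include <iostream>
#include <mutex>

struct Embedder {                   // stand-in for InputsEmbedder (assumed not thread-safe)
    int embed(int prompt) { return prompt * 2; }
};

static std::mutex embedder_mutex;   // counterpart of m_inputs_embedder_mutex
static Embedder embedder;           // counterpart of m_inputs_embedder

int add_request(int prompt) {
    int inputs = 0;
    {
        // Inner scope: hold the lock only for the embedder call, so concurrent
        // callers are serialized here but not for the rest of the work.
        const std::lock_guard<std::mutex> lock(embedder_mutex);
        inputs = embedder.embed(prompt);
    }
    return inputs;                  // downstream scheduling happens outside the lock
}

int main() {
    std::cout << add_request(21) << std::endl;  // prints 42
    return 0;
}
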
src/cpp/src/icontinuous_batching.hpp

+1
@@ -56,6 +56,7 @@ class ContinuousBatchingPipeline::IContinuousBatchingPipeline {
 
     ModelInputType m_model_input_type = ModelInputType::TOKENS;
     std::shared_ptr<InputsEmbedder> m_inputs_embedder;
+    std::mutex m_inputs_embedder_mutex;
 
     void stream_tokens(const std::shared_ptr<ThreadedStreamerWrapper>& streamer_ptr, const GenerationHandle& handle);
 public:

src/cpp/src/image_generation/diffusion_pipeline.hpp

+9 −1
@@ -100,7 +100,15 @@ class DiffusionPipeline {
 
     virtual void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) = 0;
 
-    virtual void compile(const std::string& device, const ov::AnyMap& properties) = 0;
+    virtual void compile(const std::string& device, const ov::AnyMap& properties)
+    {
+        compile(device, device, device, properties);
+    }
+
+    virtual void compile(const std::string& text_encode_device,
+                         const std::string& denoise_device,
+                         const std::string& vae_device,
+                         const ov::AnyMap& properties) = 0;
 
     virtual std::tuple<ov::Tensor, ov::Tensor, ov::Tensor, ov::Tensor> prepare_latents(ov::Tensor initial_image, const ImageGenerationConfig& generation_config) = 0;

src/cpp/src/image_generation/flux_pipeline.hpp

+7
@@ -256,6 +256,13 @@ class FluxPipeline : public DiffusionPipeline {
         m_transformer->compile(device, *updated_properties);
     }
 
+    void compile(const std::string& text_encode_device,
+                 const std::string& denoise_device,
+                 const std::string& vae_device,
+                 const ov::AnyMap& properties) override {
+        OPENVINO_THROW("not supported yet.");
+    }
+
     void compute_hidden_states(const std::string& positive_prompt, const ImageGenerationConfig& generation_config) override {
         // encode_prompt
         std::string prompt_2_str = generation_config.prompt_2 != std::nullopt ? *generation_config.prompt_2 : positive_prompt;

src/cpp/src/image_generation/models/clip_text_model.cpp

+32 −4
@@ -124,12 +124,26 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string
         }
     };
 
+    ov::PartialShape compiled_input_partial_shape = m_request.get_compiled_model().inputs()[0].get_partial_shape();
+
     ov::Tensor input_ids = m_request.get_input_tensor();
-    input_ids.set_shape({text_embedding_batch_size, m_config.max_position_embeddings});
+
+    if (compiled_input_partial_shape.is_dynamic()) {
+        input_ids.set_shape({text_embedding_batch_size, m_config.max_position_embeddings});
+    } else {
+        auto compiled_input_shape = input_ids.get_shape();
+        OPENVINO_ASSERT(compiled_input_shape.size() == 2, "CLIP text encoder model input must have rank of 2");
+        OPENVINO_ASSERT(text_embedding_batch_size <= compiled_input_shape[0],
+                        "text_embedding_batch_size (", text_embedding_batch_size,
+                        ") > CLIP text encoder model batch size (",compiled_input_shape[0], ").");
+        OPENVINO_ASSERT(m_config.max_position_embeddings == compiled_input_shape[1],
+                        "max_position_embeddings (", m_config.max_position_embeddings,
+                        ") != what CLIP text encoder model was compiled for (", compiled_input_shape[1], ").");
+    }
 
     size_t current_batch_idx = 0;
 
-    if (do_classifier_free_guidance) {
+    if (input_ids.get_shape()[0] == 2) {
         perform_tokenization(neg_prompt,
                              ov::Tensor(input_ids, {current_batch_idx , 0},
                                         {current_batch_idx + 1, m_config.max_position_embeddings}));
@@ -145,11 +159,25 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string
     // text embeddings
     m_request.infer();
 
-    return m_request.get_output_tensor(0);
+    // This is true when text_embedding_batch_size is 1, but model was reshaped / compiled as batch size 2.
+    m_slice_batch1_output = (text_embedding_batch_size != input_ids.get_shape()[0]);
+
+    return get_output_tensor(0);
 }
 
 ov::Tensor CLIPTextModel::get_output_tensor(const size_t idx) {
-    return m_request.get_output_tensor(idx);
+    auto infer_out_tensor = m_request.get_output_tensor(idx);
+    if (m_slice_batch1_output) {
+        //Slice and return batch index 1 output.
+        auto out_shape = infer_out_tensor.get_shape();
+        auto begin_coord = ov::Coordinate(out_shape.size(), 0);
+        begin_coord[0] = 1;
+        auto end_coord = ov::Coordinate(out_shape);
+        auto sliced_out_tensor = ov::Tensor(infer_out_tensor, begin_coord, end_coord);
+        return sliced_out_tensor;
+    } else {
+        return infer_out_tensor;
+    }
 }
 
 } // namespace genai

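The get_output_tensor change above returns a zero-copy view over batch index 1 when the encoder was compiled for batch size 2 but only the positive prompt's embedding is needed. A small standalone sketch of the same ROI-slicing idea using the public OpenVINO runtime API; the tensor shape below is illustrative only:

#include <openvino/openvino.hpp>
#include <iostream>

int main() {
    // Pretend this is a CLIP output of shape [batch=2, tokens=77, hidden=768].
    ov::Tensor full(ov::element::f32, ov::Shape{2, 77, 768});

    // Select only batch index 1: begin = {1,0,0}, end = full extent of every axis.
    ov::Coordinate begin(full.get_shape().size(), 0);
    begin[0] = 1;
    ov::Coordinate end(full.get_shape());

    // ROI constructor creates a view over the parent tensor; no data is copied.
    ov::Tensor batch1(full, begin, end);

    std::cout << batch1.get_shape() << std::endl;  // sliced shape: 1 x 77 x 768
    return 0;
}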