
Commit ed2baf4

text2image: Add pimpl approach for unet_2d_condition & batch-size 1 implementation of unet for initial NPU support (openvinotoolkit#1101)
For now, NPU can be used like this. The unet model must be reshaped to a static shape before compile is invoked.

```cpp
std::filesystem::path root_dir = models_path;

auto pipe = ov::genai::Text2ImagePipeline::stable_diffusion(
    ov::genai::Text2ImagePipeline::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json"),
    ov::genai::CLIPTextModel(root_dir / "text_encoder", "CPU"),
    ov::genai::UNet2DConditionModel(root_dir / "unet")
        .reshape(2, 512, 512, 77)
        .compile("NPU", ov::cache_dir("./cache")),
    ov::genai::AutoencoderKL(root_dir / "vae_decoder", "GPU", ov::cache_dir("./cache")));

ov::Tensor image = pipe.generate(prompt);
```
2 parents e1430a5 + 91988ce commit ed2baf4

File tree

8 files changed: +439 −37 lines


samples/cpp/text2image/CMakeLists.txt (+19)

```diff
@@ -45,3 +45,22 @@ install(TARGETS lora_stable_diffusion
     RUNTIME DESTINATION samples_bin/
     COMPONENT samples_bin
     EXCLUDE_FROM_ALL)
+
+# create txt2image_from_subcomponent sample executable
+
+add_executable(txt2image_from_subcomponent
+    ${CMAKE_CURRENT_SOURCE_DIR}/txt2image_from_subcomponent.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/imwrite.cpp)
+
+target_include_directories(txt2image_from_subcomponent PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+target_link_libraries(txt2image_from_subcomponent PRIVATE openvino::genai)
+
+set_target_properties(txt2image_from_subcomponent PROPERTIES
+    COMPILE_PDB_NAME txt2image_from_subcomponent
+    # Ensure out of box LC_RPATH on macOS with SIP
+    INSTALL_RPATH_USE_LINK_PATH ON)
+
+install(TARGETS txt2image_from_subcomponent
+    RUNTIME DESTINATION samples_bin/
+    COMPONENT samples_bin
+    EXCLUDE_FROM_ALL)
```

samples/cpp/text2image/README.md (+16 −1)

```diff
@@ -2,9 +2,10 @@
 
 Examples in this folder showcase inference of text to image models like Stable Diffusion 1.5, 2.1, LCM. The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::Text2ImagePipeline` and uses a text prompt as input source.
 
-There are two sample files:
+There are three sample files:
 - [`main.cpp`](./main.cpp) demonstrates basic usage of the text to image pipeline
 - [`lora.cpp`](./lora.cpp) shows how to apply LoRA adapters to the pipeline
+- [`txt2image_from_subcomponent.cpp`](./txt2image_from_subcomponent.cpp) shows how to assemble a txt2image pipeline from individual subcomponents (scheduler, text encoder, unet, vae decoder)
 
 Users can change the sample code and play with the following generation parameters:
 
@@ -67,3 +68,17 @@ With adapter | Without adapter
 - Image generated with HuggingFace / Optimum Intel is not the same generated by this C++ sample:
 
 C++ random generation with MT19937 results differ from `numpy.random.randn()` and `diffusers.utils.randn_tensor`. So, it's expected that image generated by Python and C++ versions provide different images, because latent images are initialize differently. Users can implement their own random generator derived from `ov::genai::Generator` and pass it to `Text2ImagePipeline::generate` method.
+
+## Run with multiple devices
+
+The `txt2image_from_subcomponent` sample demonstrates how a Text2ImagePipeline object can be created from individual subcomponents - scheduler, text encoder, unet, & vae decoder. This approach gives fine-grained control over the devices used to execute each stage of the stable diffusion pipeline.
+
+The usage of this sample is:
+
+`txt2image_from_subcomponent <MODEL_DIR> '<PROMPT>' [ <TXT_ENCODE_DEVICE> <UNET_DEVICE> <VAE_DEVICE> ]`
+
+For example:
+
+`txt2image_from_subcomponent ./dreamlike_anime_1_0_ov/FP16 'cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting' CPU NPU GPU`
+
+The sample will create a stable diffusion pipeline such that the text encoder is executed on the CPU, UNet on the NPU, and VAE decoder on the GPU.
```
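For reference, the condensed C++ below sketches how those three device arguments map onto the subcomponent `reshape()`/`compile()` calls, using the devices from the README example above (CPU / NPU / GPU). It is a fragment distilled from the full sample that follows, not additional API: `root_dir`, `prompt`, and the `./cache` directory come from that sample.

```cpp
std::filesystem::path root_dir = "./dreamlike_anime_1_0_ov/FP16";

// Text encoder on CPU; batch 2 because classifier-free guidance runs a
// conditional and an unconditional pass per step.
auto text_encoder = ov::genai::CLIPTextModel(root_dir / "text_encoder");
text_encoder.reshape(2);
text_encoder.compile("CPU", ov::cache_dir("./cache"));

// UNet on NPU; it must be reshaped to a static shape before compile
// (batch, height, width, max tokens).
auto unet = ov::genai::UNet2DConditionModel(root_dir / "unet");
unet.reshape(2, 512, 512, 77);
unet.compile("NPU", ov::cache_dir("./cache"));

// VAE decoder on GPU; batch 1, since one image is returned per generate() call.
auto vae = ov::genai::AutoencoderKL(root_dir / "vae_decoder");
vae.reshape(1, 512, 512);
vae.compile("GPU", ov::cache_dir("./cache"));

auto pipe = ov::genai::Text2ImagePipeline::stable_diffusion(
    ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json"),
    text_encoder, unet, vae);

ov::Tensor image = pipe.generate(prompt);
```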
samples/cpp/text2image/txt2image_from_subcomponent.cpp (new file, +108)

```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "imwrite.hpp"
#include "openvino/genai/image_generation/text2image_pipeline.hpp"

int32_t main(int32_t argc, char* argv[]) try {
    OPENVINO_ASSERT(argc >= 3 && argc <= 6,
                    "Usage: ",
                    argv[0],
                    " <MODEL_DIR> '<PROMPT>' [ <TXT_ENCODE_DEVICE> <UNET_DEVICE> <VAE_DEVICE> ]");

    const std::string models_path = argv[1], prompt = argv[2];

    std::filesystem::path root_dir = models_path;

    const int width = 512;
    const int height = 512;
    const float guidance_scale = 7.5f;
    const int number_of_images_to_generate = 1;
    const int number_of_inference_steps_per_image = 20;

    // Set devices to command-line args if specified, otherwise default to CPU.
    // Note that these can be set to CPU, GPU, or NPU.
    const std::string text_encoder_device = (argc > 3) ? argv[3] : "CPU";
    const std::string unet_device = (argc > 4) ? argv[4] : "CPU";
    const std::string vae_decoder_device = (argc > 5) ? argv[5] : "CPU";

    std::cout << "text_encoder_device: " << text_encoder_device << std::endl;
    std::cout << "unet_device: " << unet_device << std::endl;
    std::cout << "vae_decoder_device: " << vae_decoder_device << std::endl;

    // This is the path to where compiled models will get cached
    // (so that the 'compile' method runs much faster the 2nd+ time).
    std::string ov_cache_dir = "./cache";

    //
    // Step 1: Prepare each Text2Image subcomponent (scheduler, text encoder, unet, vae) separately.
    //

    // Create the scheduler from the details listed in the json.
    auto scheduler = ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json");

    // Note that we could have created the scheduler by specifying a specific type (for example EULER_DISCRETE), like
    // this: auto scheduler = ov::genai::Scheduler::from_config(root_dir / "scheduler/scheduler_config.json",
    //                                                          ov::genai::Scheduler::Type::EULER_DISCRETE);

    // Create unet object
    auto unet = ov::genai::UNet2DConditionModel(root_dir / "unet");

    // Given the guidance scale, etc., calculate the batch size.
    int unet_batch_size = 1;
    if (guidance_scale > 1.0f && unet.get_config().time_cond_proj_dim < 0) {
        unet_batch_size = 2;
    }

    // Create, reshape, and compile the text encoder.
    auto text_encoder = ov::genai::CLIPTextModel(root_dir / "text_encoder");
    text_encoder.reshape(unet_batch_size);
    text_encoder.compile(text_encoder_device, ov::cache_dir(ov_cache_dir));

    // The max_position_embeddings config from text encoder will be used as a parameter to unet reshape.
    int max_position_embeddings = text_encoder.get_config().max_position_embeddings;

    // Reshape unet to a static shape, and compile it.
    unet.reshape(unet_batch_size, height, width, max_position_embeddings);
    unet.compile(unet_device, ov::cache_dir(ov_cache_dir));

    // Create, reshape, and compile the vae decoder.
    auto vae = ov::genai::AutoencoderKL(root_dir / "vae_decoder");
    vae.reshape(1, height, width);  // We set batch size to '1' here, as we're configuring our pipeline to return
                                    // 1 image per 'generate' call.
    vae.compile(vae_decoder_device, ov::cache_dir(ov_cache_dir));

    //
    // Step 2: Create a Text2ImagePipeline from the individual subcomponents.
    //
    auto pipe = ov::genai::Text2ImagePipeline::stable_diffusion(scheduler, text_encoder, unet, vae);

    //
    // Step 3: Use the Text2ImagePipeline to generate 'number_of_images_to_generate' images.
    //
    for (int imagei = 0; imagei < number_of_images_to_generate; imagei++) {
        std::cout << "Generating image " << imagei << std::endl;

        ov::Tensor image = pipe.generate(prompt,
                                         ov::genai::width(width),
                                         ov::genai::height(height),
                                         ov::genai::guidance_scale(guidance_scale),
                                         ov::genai::num_inference_steps(number_of_inference_steps_per_image));

        imwrite("image_" + std::to_string(imagei) + ".bmp", image, true);
    }

    return EXIT_SUCCESS;
} catch (const std::exception& error) {
    try {
        std::cerr << error.what() << '\n';
    } catch (const std::ios_base::failure&) {
    }
    return EXIT_FAILURE;
} catch (...) {
    try {
        std::cerr << "Non-exception object thrown\n";
    } catch (const std::ios_base::failure&) {
    }
    return EXIT_FAILURE;
}
```

src/cpp/include/openvino/genai/image_generation/unet2d_condition_model.hpp (+6 −1)

```diff
@@ -65,11 +65,16 @@ class OPENVINO_GENAI_EXPORTS UNet2DConditionModel {
     ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep);
 
 private:
+    class UNetInference;
+    std::shared_ptr<UNetInference> m_impl;
+
     Config m_config;
     AdapterController m_adapter_controller;
     std::shared_ptr<ov::Model> m_model;
-    ov::InferRequest m_request;
     size_t m_vae_scale_factor;
+
+    class UNetInferenceDynamic;
+    class UNetInferenceStaticBS1;
 };
 
 } // namespace genai
```

src/cpp/src/image_generation/models/unet2d_condition_model.cpp (+21 −35)

```diff
@@ -2,6 +2,8 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "openvino/genai/image_generation/unet2d_condition_model.hpp"
+#include "image_generation/models/unet_inference_dynamic.hpp"
+#include "image_generation/models/unet_inference_static_bs1.hpp"
 
 #include <fstream>
 
@@ -52,67 +54,51 @@ UNet2DConditionModel& UNet2DConditionModel::reshape(int batch_size, int height,
     height /= m_vae_scale_factor;
     width /= m_vae_scale_factor;
 
-    std::map<std::string, ov::PartialShape> name_to_shape;
-
-    for (auto && input : m_model->inputs()) {
-        std::string input_name = input.get_any_name();
-        name_to_shape[input_name] = input.get_partial_shape();
-        if (input_name == "timestep") {
-            name_to_shape[input_name][0] = 1;
-        } else if (input_name == "sample") {
-            name_to_shape[input_name] = {batch_size, name_to_shape[input_name][1], height, width};
-        } else if (input_name == "time_ids" || input_name == "text_embeds") {
-            name_to_shape[input_name][0] = batch_size;
-        } else if (input_name == "encoder_hidden_states") {
-            name_to_shape[input_name][0] = batch_size;
-            name_to_shape[input_name][1] = tokenizer_model_max_length;
-        }
-    }
-
-    m_model->reshape(name_to_shape);
+    UNetInference::reshape(m_model, batch_size, height, width, tokenizer_model_max_length);
 
     return *this;
 }
 
 UNet2DConditionModel& UNet2DConditionModel::compile(const std::string& device, const ov::AnyMap& properties) {
     OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");
-    ov::Core core = utils::singleton_core();
-    ov::CompiledModel compiled_model;
+
+    if (device == "NPU") {
+        m_impl = std::make_shared<UNet2DConditionModel::UNetInferenceStaticBS1>();
+    } else {
+        m_impl = std::make_shared<UNet2DConditionModel::UNetInferenceDynamic>();
+    }
+
     std::optional<AdapterConfig> adapters;
     if (auto filtered_properties = extract_adapters_from_properties(properties, &adapters)) {
         adapters->set_tensor_name_prefix(adapters->get_tensor_name_prefix().value_or("lora_unet"));
         m_adapter_controller = AdapterController(m_model, *adapters, device);
-        compiled_model = core.compile_model(m_model, device, *filtered_properties);
+        m_impl->compile(m_model, device, *filtered_properties);
     } else {
-        compiled_model = core.compile_model(m_model, device, properties);
+        m_impl->compile(m_model, device, properties);
     }
-    m_request = compiled_model.create_infer_request();
+
     // release the original model
     m_model.reset();
 
     return *this;
 }
 
 void UNet2DConditionModel::set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) {
-    OPENVINO_ASSERT(m_request, "UNet model must be compiled first");
-    m_request.set_tensor(tensor_name, encoder_hidden_states);
+    OPENVINO_ASSERT(m_impl, "UNet model must be compiled first");
+    m_impl->set_hidden_states(tensor_name, encoder_hidden_states);
 }
 
 void UNet2DConditionModel::set_adapters(const std::optional<AdapterConfig>& adapters) {
-    if (adapters) {
-        m_adapter_controller.apply(m_request, *adapters);
+    OPENVINO_ASSERT(m_impl, "UNet model must be compiled first");
+    if(adapters) {
+        OPENVINO_ASSERT(m_impl, "UNet model must be compiled first");
+        m_impl->set_adapters(m_adapter_controller, *adapters);
     }
 }
 
 ov::Tensor UNet2DConditionModel::infer(ov::Tensor sample, ov::Tensor timestep) {
-    OPENVINO_ASSERT(m_request, "UNet model must be compiled first. Cannot infer non-compiled model");
-
-    m_request.set_tensor("sample", sample);
-    m_request.set_tensor("timestep", timestep);
-
-    m_request.infer();
-
-    return m_request.get_output_tensor();
+    OPENVINO_ASSERT(m_impl, "UNet model must be compiled first. Cannot infer non-compiled model");
+    return m_impl->infer(sample, timestep);
 }
 
 } // namespace genai
```
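For non-NPU devices, `compile()` now picks `UNetInferenceDynamic`, which essentially wraps the single-infer-request logic removed above. Its actual definition lives in `unet_inference_dynamic.hpp` (added by this commit but not shown on this page); the following is only a reconstruction from the removed code, so the exact file layout and use of a plain `ov::Core` here are assumptions.

```cpp
// unet_inference_dynamic.hpp -- illustrative sketch only, reconstructed from the removed code above.
#pragma once

#include "image_generation/models/unet_inference.hpp"
#include "openvino/runtime/core.hpp"

namespace ov {
namespace genai {

class UNet2DConditionModel::UNetInferenceDynamic : public UNet2DConditionModel::UNetInference {
public:
    void compile(std::shared_ptr<ov::Model> model, const std::string& device, const ov::AnyMap& properties) override {
        // Compile the (possibly dynamically shaped) model as-is and keep a single infer request.
        ov::Core core;  // the real code presumably reuses utils::singleton_core()
        ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
        m_request = compiled_model.create_infer_request();
    }

    void set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) override {
        m_request.set_tensor(tensor_name, encoder_hidden_states);
    }

    void set_adapters(AdapterController& adapter_controller, const AdapterConfig& adapters) override {
        adapter_controller.apply(m_request, adapters);
    }

    ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep) override {
        m_request.set_tensor("sample", sample);
        m_request.set_tensor("timestep", timestep);
        m_request.infer();
        return m_request.get_output_tensor();
    }

private:
    ov::InferRequest m_request;
};

}  // namespace genai
}  // namespace ov
```

On NPU, `UNetInferenceStaticBS1` is chosen instead; a sketch of that idea follows the `UNetInference` base class below.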
src/cpp/src/image_generation/models/unet_inference.hpp (new file, +68)

```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "openvino/genai/image_generation/unet2d_condition_model.hpp"

namespace ov {
namespace genai {

class UNet2DConditionModel::UNetInference {

public:
    virtual void compile(std::shared_ptr<ov::Model> model, const std::string& device, const ov::AnyMap& properties) = 0;
    virtual void set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) = 0;
    virtual void set_adapters(AdapterController& adapter_controller, const AdapterConfig& adapters) = 0;
    virtual ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep) = 0;

    // utility function to resize model given optional dimensions.
    static void reshape(std::shared_ptr<ov::Model> model,
                        std::optional<int> batch_size = {},
                        std::optional<int> height = {},
                        std::optional<int> width = {},
                        std::optional<int> tokenizer_model_max_length = {})
    {
        std::map<std::string, ov::PartialShape> name_to_shape;
        for (auto&& input : model->inputs()) {
            std::string input_name = input.get_any_name();
            name_to_shape[input_name] = input.get_partial_shape();
            if (input_name == "timestep") {
                name_to_shape[input_name][0] = 1;
            } else if (input_name == "sample") {
                if (batch_size) {
                    name_to_shape[input_name][0] = *batch_size;
                }

                if (height) {
                    name_to_shape[input_name][2] = *height;
                }

                if (width) {
                    name_to_shape[input_name][3] = *width;
                }
            } else if (input_name == "time_ids" || input_name == "text_embeds") {
                if (batch_size) {
                    name_to_shape[input_name][0] = *batch_size;
                }
            } else if (input_name == "encoder_hidden_states") {
                if (batch_size) {
                    name_to_shape[input_name][0] = *batch_size;
                }

                if (tokenizer_model_max_length) {
                    name_to_shape[input_name][1] = *tokenizer_model_max_length;
                }
            } else if (input_name == "timestep_cond") {
                if (batch_size) {
                    name_to_shape[input_name][0] = *batch_size;
                }
            }
        }

        model->reshape(name_to_shape);
    }
};

} // namespace genai
} // namespace ov
```
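The NPU path selects `UNetInferenceStaticBS1`, whose definition (`unet_inference_static_bs1.hpp`) is also not part of this page. Below is a rough sketch of the idea behind a batch-size-1 implementation, assuming it recompiles the UNet at batch 1 and splits the batched inputs into per-sample inferences; the member names, the slicing helper, and the use of a plain `ov::Core` are guesses, not the commit's code.

```cpp
// unet_inference_static_bs1.hpp -- illustrative sketch only, not the commit's actual implementation.
#pragma once

#include <cstdint>
#include <cstring>
#include <vector>

#include "image_generation/models/unet_inference.hpp"
#include "openvino/runtime/core.hpp"

namespace ov {
namespace genai {

class UNet2DConditionModel::UNetInferenceStaticBS1 : public UNet2DConditionModel::UNetInference {
public:
    void compile(std::shared_ptr<ov::Model> model, const std::string& device, const ov::AnyMap& properties) override {
        // The caller has already reshaped the model to a static batch (e.g. 2 for classifier-free
        // guidance). Remember that batch size, then force the compiled model down to batch 1.
        size_t native_batch = model->input("sample").get_partial_shape()[0].get_length();
        UNetInference::reshape(model, 1);

        ov::Core core;  // the real code presumably reuses utils::singleton_core()
        ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
        for (size_t i = 0; i < native_batch; ++i) {
            m_requests.push_back(compiled_model.create_infer_request());
        }
    }

    void set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) override {
        // Give each batch-1 request its own slice of the batched hidden states.
        for (size_t i = 0; i < m_requests.size(); ++i) {
            m_requests[i].set_tensor(tensor_name, batch_slice(encoder_hidden_states, i));
        }
    }

    void set_adapters(AdapterController& adapter_controller, const AdapterConfig& adapters) override {
        for (auto& request : m_requests) {
            adapter_controller.apply(request, adapters);
        }
    }

    ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep) override {
        ov::Shape out_shape;
        ov::Tensor output;

        // Run one batch-1 inference per batch element and copy each result back into a batched output.
        for (size_t i = 0; i < m_requests.size(); ++i) {
            m_requests[i].set_tensor("sample", batch_slice(sample, i));
            m_requests[i].set_tensor("timestep", timestep);
            m_requests[i].infer();

            ov::Tensor out_i = m_requests[i].get_output_tensor();
            if (i == 0) {
                out_shape = out_i.get_shape();
                out_shape[0] = m_requests.size();
                output = ov::Tensor(out_i.get_element_type(), out_shape);
            }
            std::memcpy(static_cast<uint8_t*>(output.data()) + i * out_i.get_byte_size(),
                        out_i.data(),
                        out_i.get_byte_size());
        }
        return output;
    }

private:
    // View over batch element 'i' of a batched tensor; slices along dim 0 are contiguous in memory.
    static ov::Tensor batch_slice(ov::Tensor& batched, size_t i) {
        ov::Shape shape = batched.get_shape();
        size_t slice_bytes = batched.get_byte_size() / shape[0];
        shape[0] = 1;
        return ov::Tensor(batched.get_element_type(), shape,
                          static_cast<uint8_t*>(batched.data()) + i * slice_bytes);
    }

    std::vector<ov::InferRequest> m_requests;
};

}  // namespace genai
}  // namespace ov
```

Running at batch 1 keeps the compiled shapes NPU-friendly while the rest of the pipeline continues to see the batch-2 interface required for classifier-free guidance.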
