
Commit 41f1e7b

LoRA in Text2ImagePipeline (openvinotoolkit#911)
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
1 parent b11f0d9 commit 41f1e7b

28 files changed: +318 -111 lines

.github/workflows/lcm_dreamshaper_cpp.yml (+2 -2)

@@ -67,7 +67,7 @@ jobs:
       - name: Run app
         run: |
           source ${{ env.OV_INSTALL_DIR }}/setupvars.sh
-          ./build/samples/cpp/stable_diffusion/stable_diffusion ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+          ./build/samples/cpp/text2image/stable_diffusion ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"

   lcm_dreamshaper_v7_cpp-windows:
     runs-on: windows-latest

@@ -118,7 +118,7 @@ jobs:
       - name: Run app
         run: |
           . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1"
-          ./build/samples/cpp/stable_diffusion/Release/lcm_dreamshaper.exe ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+          ./build/samples/cpp/text2image/Release/lcm_dreamshaper.exe ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"

   Overall_Status:
     name: ci/gha_overall_status_lcm

.github/workflows/stable_diffusion_1_5_cpp.yml (+12 -3)

@@ -64,10 +64,19 @@ jobs:
           source openvino_sd_cpp/bin/activate
           optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --weight-format fp16 --task stable-diffusion models/dreamlike-art-dreamlike-anime-1.0/FP16

-      - name: Run app
+      - name: Run main app
         run: |
           source ${{ env.OV_INSTALL_DIR }}/setupvars.sh
-          ./build/samples/cpp/stable_diffusion/stable_diffusion ./models/dreamlike-art-dreamlike-anime-1.0/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+          ./build/samples/cpp/text2image/stable_diffusion ./models/dreamlike-art-dreamlike-anime-1.0/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+
+      - name: Download LoRA adapter
+        run: |
+          wget -O ./models/soulcard.safetensors https://civitai.com/api/download/models/72591
+
+      - name: Run LoRA app
+        run: |
+          source ${{ env.OV_INSTALL_DIR }}/setupvars.sh
+          ./build/samples/cpp/text2image/lora_stable_diffusion ./models/dreamlike-art-dreamlike-anime-1.0/FP16 "curly-haired unicorn in the forest, anime, line" ./models/soulcard.safetensors 0.7

   stable_diffusion_1_5_cpp-windows:
     runs-on: windows-latest

@@ -118,7 +127,7 @@ jobs:
       - name: Run app
         run: |
           . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1"
-          ./build/samples/cpp/stable_diffusion/Release/stable_diffusion.exe ./models/dreamlike-art-dreamlike-anime-1.0/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+          ./build/samples/cpp/text2image/Release/stable_diffusion.exe ./models/dreamlike-art-dreamlike-anime-1.0/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"

   Overall_Status:
     name: ci/gha_overall_status_stable_diffusion

README.md (+1 -1)

@@ -34,7 +34,7 @@ It includes the following pipelines:
 6. [multinomial_causal_lm](./samples/cpp/multinomial_causal_lm/README.md)
 7. [prompt_lookup_decoding_lm](./samples/cpp/prompt_lookup_decoding_lm/README.md)
 8. [speculative_decoding_lm](./samples/cpp/speculative_decoding_lm/README.md)
-3. [Stable Diffusion and Latent Consistency Model (with LoRA) C++ image generation pipeline](./samples/cpp/stable_diffusion/README.md)
+3. [Stable Diffusion and Latent Consistency Model (with LoRA) C++ image generation pipeline](./samples/cpp/text2image/README.md)

 ### Requirements
samples/CMakeLists.txt (+3 -3)

@@ -13,7 +13,7 @@ add_subdirectory(cpp/prompt_lookup_decoding_lm)
 add_subdirectory(cpp/speculative_decoding_lm)
 add_subdirectory(cpp/benchmark_genai)
 add_subdirectory(cpp/whisper_speech_recognition)
-add_subdirectory(cpp/stable_diffusion)
+add_subdirectory(cpp/text2image)

 install(FILES requirements.txt DESTINATION samples
     COMPONENT cpp_samples_genai)

@@ -26,7 +26,7 @@ install(DIRECTORY
     # Don't install prompt_lookup_decoding_lm and speculative_decoding_lm because they don't use the openvino_genai library and aren't verified yet.
     # Don't install continuous_batching_accuracy and continuous_batching_benchmark because they depend on json.
     cpp/whisper_speech_recognition
-    cpp/stable_diffusion
+    cpp/text2image
     cpp/lora_greedy_causal_lm
     DESTINATION samples/cpp COMPONENT cpp_samples_genai)

@@ -36,6 +36,6 @@ install(DIRECTORY
     python/greedy_causal_lm
     python/multinomial_causal_lm
     python/whisper_speech_recognition
-    # python/stable_diffusion
+    # python/text2image
     DESTINATION samples/python COMPONENT cpp_samples_genai
     USE_SOURCE_PERMISSIONS)

samples/cpp/stable_diffusion/README.md (-48)

This file was deleted.

File renamed without changes.

samples/cpp/stable_diffusion/CMakeLists.txt → samples/cpp/text2image/CMakeLists.txt (+20 -1)
@@ -8,7 +8,7 @@ find_package(OpenVINOGenAI REQUIRED
     NO_CMAKE_FIND_ROOT_PATH
 )

-# create executable
+# create main sample executable

 add_executable(stable_diffusion
     ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp

@@ -26,3 +26,22 @@ install(TARGETS stable_diffusion
     RUNTIME DESTINATION samples_bin/
     COMPONENT samples_bin
     EXCLUDE_FROM_ALL)
+
+# create LoRA sample executable
+
+add_executable(lora_stable_diffusion
+    ${CMAKE_CURRENT_SOURCE_DIR}/lora.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/imwrite.cpp)
+
+target_include_directories(lora_stable_diffusion PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+target_link_libraries(lora_stable_diffusion PRIVATE openvino::genai)
+
+set_target_properties(lora_stable_diffusion PROPERTIES
+    COMPILE_PDB_NAME lora_stable_diffusion
+    # Ensure out of box LC_RPATH on macOS with SIP
+    INSTALL_RPATH_USE_LINK_PATH ON)
+
+install(TARGETS lora_stable_diffusion
+    RUNTIME DESTINATION samples_bin/
+    COMPONENT samples_bin
+    EXCLUDE_FROM_ALL)

samples/cpp/text2image/README.md (new file, +76)

# Text to Image C++ Generation Pipeline

Examples in this folder showcase inference of text-to-image models like Stable Diffusion 1.5, Stable Diffusion 2.1, and LCM. The application deliberately exposes few configuration options, to encourage the reader to explore and modify the source code, for example to change the inference device to GPU. The sample features `ov::genai::Text2ImagePipeline` and uses a text prompt as the input source.

There are two sample files:
- [`main.cpp`](./main.cpp) demonstrates basic usage of the text-to-image pipeline
- [`lora.cpp`](./lora.cpp) shows how to apply LoRA adapters to the pipeline

Users can change the sample code and play with the following generation parameters (see the sketch after this list):

- Change the width or height of the generated image
- Generate multiple images per prompt
- Adjust the number of inference steps
- Play with the [guidance scale](https://huggingface.co/spaces/stabilityai/stable-diffusion/discussions/9) (read [more details](https://arxiv.org/abs/2207.12598))
- (SD 1.x, 2.x only) Add a negative prompt when the guidance scale is > 1
- Apply multiple different LoRA adapters and mix them with different blending coefficients
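A minimal, hypothetical sketch of how these knobs map onto `generate` properties. `width`, `height`, `num_inference_steps`, and the `imwrite` signature are confirmed by `lora.cpp` in this commit; `guidance_scale` and `negative_prompt` are assumed property names not shown in this diff.

```cpp
// Hypothetical tuning sketch (not part of this commit).
#include "openvino/genai/text2image/pipeline.hpp"
#include "imwrite.hpp"

int main() {
    ov::genai::Text2ImagePipeline pipe("./dreamlike_anime_1_0_ov/FP16", "CPU");
    ov::Tensor image = pipe.generate(
        "a watercolor fox in a snowy forest",
        ov::genai::width(768),                      // output image width
        ov::genai::height(512),                     // output image height
        ov::genai::num_inference_steps(25),         // denoising steps
        ov::genai::guidance_scale(7.5f),            // assumed name; SD 1.x/2.x only
        ov::genai::negative_prompt("low quality")); // assumed name
    imwrite("image.bmp", image, true);
    return 0;
}
```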
## Download and convert the models and tokenizers

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.

It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.

```sh
pip install --upgrade-strategy eager -r ../../requirements.txt
optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --task stable-diffusion --weight-format fp16 dreamlike_anime_1_0_ov/FP16
```

## Run

`stable_diffusion ./dreamlike_anime_1_0_ov/FP16 'cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting'`
### Examples

Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting`

![](./512x512.bmp)

## Supported models

Models can be downloaded from [HuggingFace](https://huggingface.co/models). This sample can run the following models, among others:

- [botp/stable-diffusion-v1-5](https://huggingface.co/botp/stable-diffusion-v1-5)
- [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2)
- [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
- [dreamlike-art/dreamlike-anime-1.0](https://huggingface.co/dreamlike-art/dreamlike-anime-1.0)
- [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)
## Run with optional LoRA adapters

LoRA adapters can be connected to the pipeline to modify generated images with a certain style, details, or quality. Adapters are supported in Safetensors format and can be downloaded from public sources like [Civitai](https://civitai.com) or [HuggingFace](https://huggingface.co/models), or trained by the user. Only adapters compatible with the base model should be used. A weighted blend of multiple adapters can be applied by specifying multiple adapter files with corresponding alpha parameters on the command line. Check the `lora.cpp` source code to learn how to enable adapters and specify them in each `generate` call.
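For illustration, here is a minimal sketch of blending two adapters in code, using only the `AdapterConfig` API visible in `lora.cpp`; the adapter file names are placeholders.

```cpp
// Sketch: blend two adapters with different alpha weights (placeholder file names).
#include "openvino/genai/text2image/pipeline.hpp"

int main() {
    ov::genai::AdapterConfig adapter_config;
    adapter_config.add(ov::genai::Adapter("style_a.safetensors"), 0.5f);
    adapter_config.add(ov::genai::Adapter("style_b.safetensors"), 0.3f);
    // Adapters passed to the constructor are active by default in later generate calls.
    ov::genai::Text2ImagePipeline pipe("dreamlike_anime_1_0_ov/FP16", "CPU",
                                       ov::genai::adapters(adapter_config));
    ov::Tensor image = pipe.generate("curly-haired unicorn in the forest, anime, line",
                                     ov::genai::num_inference_steps(20));
    return 0;
}
```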
Here is an example of how to run the sample with a single adapter. First download the adapter file manually from the https://civitai.com/models/67927/soulcard page and save it as `soulcard.safetensors`, or download it from the command line:

`wget -O soulcard.safetensors https://civitai.com/api/download/models/72591`

Then run the `lora_stable_diffusion` executable:

`./lora_stable_diffusion dreamlike_anime_1_0_ov/FP16 'curly-haired unicorn in the forest, anime, line' soulcard.safetensors 0.7`

The sample generates two images, with and without adapters applied, using the same prompt:
- `lora.bmp` with adapters applied
- `baseline.bmp` without adapters applied

Check the difference:

With adapter | Without adapter
:---:|:---:
![](./lora.bmp) | ![](./baseline.bmp)

## Note

- An image generated with HuggingFace / Optimum Intel is not the same as one generated by this C++ sample:

C++ random generation with MT19937 differs from `numpy.random.randn()` and `diffusers.utils.randn_tensor`, so the Python and C++ versions are expected to produce different images, because the initial latents are initialized differently. Users can implement their own random generator derived from `ov::genai::Generator` and pass it to the `Text2ImagePipeline::generate` method.
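For reproducible C++ runs, the seed can be pinned the way `lora.cpp` in this commit does; matching HuggingFace output exactly would additionally require a custom `ov::genai::Generator` subclass, whose virtual interface is not shown in this diff. A fragment, assuming `pipe` and `prompt` are already defined:

```cpp
// Sketch: pin the random source for reproducibility, as lora.cpp does.
ov::Tensor image = pipe.generate(prompt,
    ov::genai::random_generator(std::make_shared<ov::genai::CppStdGenerator>(42)),
    ov::genai::num_inference_steps(20));
```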

samples/cpp/text2image/baseline.bmp (new file, 1.31 MB)

Binary file not shown.

File renamed without changes.
File renamed without changes.

samples/cpp/text2image/lora.bmp (new file, 1.31 MB)

Binary file not shown.

samples/cpp/text2image/lora.cpp (new file, +53)

// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <cstdlib>   // std::atof, EXIT_SUCCESS, EXIT_FAILURE
#include <iostream>  // std::cout, std::cerr

#include "openvino/genai/text2image/pipeline.hpp"

#include "imwrite.hpp"

int main(int argc, char* argv[]) try {
    OPENVINO_ASSERT(argc >= 3 && (argc - 3) % 2 == 0, "Usage: ", argv[0], " <MODEL_DIR> '<PROMPT>' [<LORA_SAFETENSORS> <ALPHA> ...]");

    const std::string models_path = argv[1], prompt = argv[2];
    const std::string device = "CPU";  // GPU, NPU can be used as well

    ov::genai::AdapterConfig adapter_config;
    // Multiple LoRA adapters applied simultaneously are supported; parse them all, with the corresponding alphas, from the command line:
    for(size_t i = 0; i < (argc - 3)/2; ++i) {
        ov::genai::Adapter adapter(argv[3 + 2*i]);
        float alpha = std::atof(argv[3 + 2*i + 1]);
        adapter_config.add(adapter, alpha);
    }

    // LoRA adapters passed to the constructor will be activated by default in subsequent generate calls
    ov::genai::Text2ImagePipeline pipe(models_path, device, ov::genai::adapters(adapter_config));

    std::cout << "Generating image with LoRA adapters applied, resulting image will be in lora.bmp\n";
    ov::Tensor image = pipe.generate(prompt,
        ov::genai::random_generator(std::make_shared<ov::genai::CppStdGenerator>(42)),
        ov::genai::width(512),
        ov::genai::height(896),
        ov::genai::num_inference_steps(20));
    imwrite("lora.bmp", image, true);

    std::cout << "Generating image without LoRA adapters applied, resulting image will be in baseline.bmp\n";
    image = pipe.generate(prompt,
        ov::genai::adapters(),  // passing adapters in generate overrides adapters set in the constructor; adapters() means no adapters
        ov::genai::random_generator(std::make_shared<ov::genai::CppStdGenerator>(42)),
        ov::genai::width(512),
        ov::genai::height(896),
        ov::genai::num_inference_steps(20));
    imwrite("baseline.bmp", image, true);

    return EXIT_SUCCESS;
} catch (const std::exception& error) {
    try {
        std::cerr << error.what() << '\n';
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
} catch (...) {
    try {
        std::cerr << "Non-exception object thrown\n";
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
}
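Note that both `generate` calls seed `CppStdGenerator` with the same value (42), so the LoRA and baseline runs start from identical initial latents; any difference between `lora.bmp` and `baseline.bmp` is therefore attributable to the adapters alone.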
File renamed without changes.

src/cpp/include/openvino/genai/generation_config.hpp (-2)

The `adapters` property instance removed here moves to `lora_adapter.hpp` (next file), next to the `AdaptersProperty` class that defines it.

@@ -161,8 +161,6 @@ static constexpr ov::Property<float> presence_penalty{"presence_penalty"};
 static constexpr ov::Property<float> frequency_penalty{"frequency_penalty"};
 static constexpr ov::Property<size_t> rng_seed{"rng_seed"};

-static constexpr AdaptersProperty adapters;
-
 // Predefined Configs
 OPENVINO_GENAI_EXPORTS GenerationConfig beam_search();
 OPENVINO_GENAI_EXPORTS GenerationConfig greedy();

src/cpp/include/openvino/genai/lora_adapter.hpp (+8 -6)

@@ -92,7 +92,9 @@ struct OPENVINO_GENAI_EXPORTS AdapterConfig {

 class AdaptersProperty : public ov::Property<AdapterConfig> {
 public:
-    constexpr AdaptersProperty() : ov::Property<AdapterConfig>("adapters") {}
+    inline constexpr static const char* name() { return "adapters"; }
+
+    constexpr AdaptersProperty() : ov::Property<AdapterConfig>(name()) {}

     inline std::pair<std::string, ov::Any> operator()(const AdapterConfig& config) const {
         return ov::Property<AdapterConfig>::operator()(config);

@@ -154,6 +156,9 @@ class AdaptersProperty : public ov::Property<AdapterConfig> {
 };

+static constexpr AdaptersProperty adapters;
+
 class OPENVINO_GENAI_EXPORTS AdapterController {

     std::shared_ptr<AdapterControllerImpl> m_pimpl;

@@ -165,15 +170,12 @@ class OPENVINO_GENAI_EXPORTS AdapterController {

     AdapterController(std::shared_ptr<ov::Model> model, const AdapterConfig& config, const std::string& prefix, std::string device = "");

-    // Call it every time when adapter config is changed; if adapter is configured as a static one, this call is not required
-    void apply(ov::InferRequest& request, const AdapterConfig& config);
+    // Apply the adapters from the config set last time, or set and apply a new config passed as the optional `config` argument
+    void apply(ov::InferRequest& request, const std::optional<AdapterConfig>& config = std::nullopt);

     // the next call of apply will set all adapter tensors regardless of config change, use this method if full state.reset is called for the controlled model
     void force_full_apply(bool full_apply = true);

-    // Apply the same config that was used last time (in initialization or in previous call to apply).
-    void apply(ov::InferRequest& request);
-
     operator bool() const {
         return bool(m_pimpl);
     }
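A hedged sketch of the revised `apply` contract, which merges the two previous overloads into one method with an optional argument; the variable names and the prefix value below are illustrative, not from this commit.

```cpp
// Sketch of the merged apply() contract (illustrative names).
ov::genai::AdapterController controller(model, adapter_config, "lora_prefix", "CPU");
controller.apply(request, new_config); // set and apply a new adapter config
controller.apply(request);             // re-apply the config set last time (formerly a separate overload)
```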

src/cpp/include/openvino/genai/text2image/clip_text_model.hpp (+4)

@@ -7,6 +7,7 @@

 #include "openvino/genai/visibility.hpp"
 #include "openvino/genai/tokenizer.hpp"
+#include "openvino/genai/lora_adapter.hpp"

 #include "openvino/core/any.hpp"
 #include "openvino/runtime/tensor.hpp"

@@ -53,10 +54,13 @@ class OPENVINO_GENAI_EXPORTS CLIPTextModel {
         return compile(device, ov::AnyMap{std::forward<Properties>(properties)...});
     }

+    void set_adapters(const AdapterConfig& adapters);
+
     ov::Tensor infer(const std::string& pos_prompt, const std::string& neg_prompt, bool do_classifier_free_guidance);

 private:
     Config m_config;
+    AdapterController m_adapter_controller;
     ov::InferRequest m_request;
     std::shared_ptr<ov::Model> m_model;
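Presumably the pipeline drives the new hook along these lines before encoding the prompts; the actual call site is internal to `Text2ImagePipeline` and not shown in this diff, so this fragment is an assumption built only from the declarations above.

```cpp
// Sketch (assumed call site): switch adapters, then run the text encoder.
clip_text_model.set_adapters(adapter_config);
ov::Tensor hidden_states =
    clip_text_model.infer(pos_prompt, neg_prompt, /*do_classifier_free_guidance=*/true);
```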

src/cpp/include/openvino/genai/text2image/pipeline.hpp (+10)

@@ -13,6 +13,7 @@

 #include "openvino/genai/visibility.hpp"

+#include "openvino/genai/lora_adapter.hpp"
 #include "openvino/genai/text2image/clip_text_model.hpp"
 #include "openvino/genai/text2image/unet2d_condition_model.hpp"
 #include "openvino/genai/text2image/autoencoder_kl.hpp"

@@ -81,6 +82,8 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {
     int64_t width = -1;
     size_t num_inference_steps = 50;

+    AdapterConfig adapters;
+
     void update_generation_config(const ov::AnyMap& config_map);

     // checks whether the config is valid

@@ -96,6 +99,13 @@ class OPENVINO_GENAI_EXPORTS Text2ImagePipeline {

     Text2ImagePipeline(const std::string& root_dir, const std::string& device, const ov::AnyMap& properties = {});

+    template <typename... Properties,
+              typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
+    Text2ImagePipeline(const std::string& root_dir,
+                       const std::string& device,
+                       Properties&&... properties)
+        : Text2ImagePipeline(root_dir, device, ov::AnyMap{std::forward<Properties>(properties)...}) { }
+
     // creates either LCM or SD pipeline from building blocks
     static Text2ImagePipeline stable_diffusion(
         const std::shared_ptr<Scheduler>& scheduler_type,
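This variadic constructor is what lets `lora.cpp` pass properties directly. A small sketch of the two equivalent spellings, assuming `models_path` and `adapter_config` are defined as in that sample:

```cpp
// With the new variadic constructor (as used in lora.cpp):
ov::genai::Text2ImagePipeline pipe(models_path, "CPU", ov::genai::adapters(adapter_config));
// Equivalent explicit AnyMap form, which the template delegates to:
ov::genai::Text2ImagePipeline pipe2(models_path, "CPU",
                                    ov::AnyMap{ov::genai::adapters(adapter_config)});
```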
