Commit 9c98963

Merge remote-tracking branch 'origin/master'
2 parents: 5f7b0ff + dbb8568

File tree

  • demos/paint_your_dreams_demo

1 file changed (+11, -24 lines)
demos/paint_your_dreams_demo/main.py (+11, -24)

@@ -75,42 +75,29 @@ def download_models(model_name, safety_checker_model: str) -> None:
                               image_processor=AutoProcessor.from_pretrained(safety_checker_dir))
 
 
-async def load_static_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline:
-    # NPU requires model static input shape for now
+async def create_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
     ov_config = {"CACHE_DIR": "cache"}
 
-    scheduler = genai.Scheduler.from_config(model_dir / "scheduler" / "scheduler_config.json")
-
-    text_encoder = genai.CLIPTextModel(model_dir / "text_encoder")
-    text_encoder.reshape(1)
-    text_encoder.compile(device, **ov_config)
-
-    unet = genai.UNet2DConditionModel(model_dir / "unet")
-    max_position_embeddings = text_encoder.get_config().max_position_embeddings
-    unet.reshape(1, size, size, max_position_embeddings)
-    unet.compile(device, **ov_config)
-
-    vae = genai.AutoencoderKL(model_dir / "vae_encoder", model_dir / "vae_decoder")
-    vae.reshape(1, size, size)
-    vae.compile(device, **ov_config)
-
     if pipeline == "text2image":
-        ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.Text2ImagePipeline(model_dir)
     elif pipeline == "image2image":
-        ov_pipeline = genai.Image2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.Image2ImagePipeline(model_dir)
     elif pipeline == "inpainting":
-        ov_pipeline = genai.InpaintingPipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.InpaintingPipeline(model_dir)
     else:
         raise ValueError(f"Unknown pipeline: {pipeline}")
 
+    ov_pipeline.reshape(1, size, size, ov_pipeline.get_generation_config().guidance_scale)
+    ov_pipeline.compile(device, config=ov_config)
+
     return ov_pipeline
 
 
-async def load_pipeline(model_name: str, device: str, size: int, pipeline: str):
+async def load_pipeline(model_name: str, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
     model_dir = MODEL_DIR / model_name
 
     if (device, pipeline) not in ov_pipelines:
-        ov_pipelines[(device, pipeline)] = await load_static_pipeline(model_dir, device, size, pipeline)
+        ov_pipelines[(device, pipeline)] = await create_pipeline(model_dir, device, size, pipeline)
 
     return ov_pipelines[(device, pipeline)]
 
@@ -121,7 +108,7 @@ async def stop():
 
 
 progress_bar = None
-def progress(step, num_steps, latent):
+def progress(step, num_steps, latent) -> bool:
     global progress_bar
     if progress_bar is None:
         progress_bar = tqdm.tqdm(total=num_steps)
@@ -215,7 +202,7 @@ def build_ui(image_size: int) -> gr.Interface:
         )
         with gr.Row():
             with gr.Column():
-                with gr.Row():
+                with gr.Row(equal_height=True):
                     input_image = gr.ImageMask(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
                     result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
             with gr.Row():
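
For readers of this diff: the main change replaces the hand-built static pipeline (scheduler, CLIP text encoder, UNet, and VAE assembled, reshaped, and compiled separately) with OpenVINO GenAI's directory-based pipeline constructors, reshaping and compiling the whole pipeline in one step. Below is a minimal standalone sketch of the new pattern, mirroring the added lines above; the model path, device, and image size are placeholders, and the final generate() call is an assumption about the openvino_genai API rather than something shown in this commit.

# Minimal sketch of the pipeline-creation pattern introduced by this commit.
# Assumes a Latent Consistency Model exported to OpenVINO IR under `model_dir`
# (the path below is hypothetical) and the `openvino_genai` package installed.
from pathlib import Path

import openvino_genai as genai

model_dir = Path("models/LCM_Dreamshaper_v7")   # hypothetical export location
device = "CPU"                                  # the demo also targets "GPU"/"NPU"
size = 512

# One constructor call builds the whole pipeline from the exported directory,
# instead of wiring up scheduler/text encoder/UNet/VAE by hand.
ov_pipeline = genai.Text2ImagePipeline(model_dir)

# Fix a static input shape (batch 1, size x size, plus the configured guidance
# scale) and compile for the target device, as create_pipeline() does above.
ov_pipeline.reshape(1, size, size, ov_pipeline.get_generation_config().guidance_scale)
ov_pipeline.compile(device, config={"CACHE_DIR": "cache"})

# Run generation; generate() returning an image tensor is an assumption about
# the openvino_genai API, not part of this diff.
image_tensor = ov_pipeline.generate("a watercolor lighthouse at dawn")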

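The Gradio change in the last hunk is purely a layout tweak: equal_height=True stretches the components in that row to the same height. A tiny self-contained sketch of the effect, using placeholder labels rather than the demo's full build_ui():

# Minimal sketch of the gr.Row(equal_height=True) layout change; the labels
# here are placeholders, not the demo's full UI.
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row(equal_height=True):  # children of this row share the same height
        input_image = gr.ImageMask(label="Input image")
        result_img = gr.Image(label="Generated image")

if __name__ == "__main__":
    demo.launch()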