Skip to content

Commit dbb8568

Browse files
Simplified pipeline creation (#230)
1 parent 678471a commit dbb8568

File tree

1 file changed

+11
-24
lines changed
  • demos/paint_your_dreams_demo

1 file changed

+11
-24
lines changed

demos/paint_your_dreams_demo/main.py

+11-24
Original file line numberDiff line numberDiff line change
@@ -59,42 +59,29 @@ def download_models(model_name, safety_checker_model: str) -> None:
5959
image_processor=AutoProcessor.from_pretrained(safety_checker_dir))
6060

6161

62-
async def load_static_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline:
63-
# NPU requires model static input shape for now
62+
async def create_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
6463
ov_config = {"CACHE_DIR": "cache"}
6564

66-
scheduler = genai.Scheduler.from_config(model_dir / "scheduler" / "scheduler_config.json")
67-
68-
text_encoder = genai.CLIPTextModel(model_dir / "text_encoder")
69-
text_encoder.reshape(1)
70-
text_encoder.compile(device, **ov_config)
71-
72-
unet = genai.UNet2DConditionModel(model_dir / "unet")
73-
max_position_embeddings = text_encoder.get_config().max_position_embeddings
74-
unet.reshape(1, size, size, max_position_embeddings)
75-
unet.compile(device, **ov_config)
76-
77-
vae = genai.AutoencoderKL(model_dir / "vae_encoder", model_dir / "vae_decoder")
78-
vae.reshape(1, size, size)
79-
vae.compile(device, **ov_config)
80-
8165
if pipeline == "text2image":
82-
ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
66+
ov_pipeline = genai.Text2ImagePipeline(model_dir)
8367
elif pipeline == "image2image":
84-
ov_pipeline = genai.Image2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
68+
ov_pipeline = genai.Image2ImagePipeline(model_dir)
8569
elif pipeline == "inpainting":
86-
ov_pipeline = genai.InpaintingPipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
70+
ov_pipeline = genai.InpaintingPipeline(model_dir)
8771
else:
8872
raise ValueError(f"Unknown pipeline: {pipeline}")
8973

74+
ov_pipeline.reshape(1, size, size, ov_pipeline.get_generation_config().guidance_scale)
75+
ov_pipeline.compile(device, config=ov_config)
76+
9077
return ov_pipeline
9178

9279

93-
async def load_pipeline(model_name: str, device: str, size: int, pipeline: str):
80+
async def load_pipeline(model_name: str, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
9481
model_dir = MODEL_DIR / model_name
9582

9683
if (device, pipeline) not in ov_pipelines:
97-
ov_pipelines[(device, pipeline)] = await load_static_pipeline(model_dir, device, size, pipeline)
84+
ov_pipelines[(device, pipeline)] = await create_pipeline(model_dir, device, size, pipeline)
9885

9986
return ov_pipelines[(device, pipeline)]
10087

@@ -105,7 +92,7 @@ async def stop():
10592

10693

10794
progress_bar = None
108-
def progress(step, num_steps, latent):
95+
def progress(step, num_steps, latent) -> bool:
10996
global progress_bar
11097
if progress_bar is None:
11198
progress_bar = tqdm.tqdm(total=num_steps)
@@ -199,7 +186,7 @@ def build_ui(image_size: int) -> gr.Interface:
199186
)
200187
with gr.Row():
201188
with gr.Column():
202-
with gr.Row():
189+
with gr.Row(equal_height=True):
203190
input_image = gr.ImageMask(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
204191
result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
205192
with gr.Row():

0 commit comments

Comments
 (0)