Commit 678471a

Added inpainting pipeline for paint your dreams demo (#229)
1 parent 72dd3a1 commit 678471a

File tree

2 files changed (+33 −24 lines)

demos/paint_your_dreams_demo/main.py (+31 −22)
@@ -30,8 +30,7 @@
 
 safety_checker: Optional[Pipeline] = None
 
-ov_pipelines_t2i = {}
-ov_pipelines_i2i = {}
+ov_pipelines = {}
 
 stop_generating: bool = True
 hf_model_name: Optional[str] = None
@@ -83,6 +82,8 @@ async def load_static_pipeline(model_dir: Path, device: str, size: int, pipeline: str):
         ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
     elif pipeline == "image2image":
         ov_pipeline = genai.Image2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+    elif pipeline == "inpainting":
+        ov_pipeline = genai.InpaintingPipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
     else:
         raise ValueError(f"Unknown pipeline: {pipeline}")
 
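The new "inpainting" branch is the only construction-side change: all three wrappers are assembled from the same scheduler, text encoder, UNet, and VAE, so only the pipeline class differs per mode. For illustration, the same selection can be written as a lookup table; this is a hypothetical sketch (build_pipeline and PIPELINE_FACTORIES are invented names, not code from the commit):

import openvino_genai as genai

# Map each mode name to the matching latent-consistency factory.
PIPELINE_FACTORIES = {
    "text2image": genai.Text2ImagePipeline.latent_consistency_model,
    "image2image": genai.Image2ImagePipeline.latent_consistency_model,
    "inpainting": genai.InpaintingPipeline.latent_consistency_model,
}

def build_pipeline(pipeline: str, scheduler, text_encoder, unet, vae):
    if pipeline not in PIPELINE_FACTORIES:
        raise ValueError(f"Unknown pipeline: {pipeline}")
    # All factories take the same four components; only the wrapper class differs.
    return PIPELINE_FACTORIES[pipeline](scheduler, text_encoder, unet, vae)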
@@ -92,17 +93,10 @@
 async def load_pipeline(model_name: str, device: str, size: int, pipeline: str):
     model_dir = MODEL_DIR / model_name
 
-    if pipeline == "text2image":
-        if device not in ov_pipelines_t2i:
-            ov_pipelines_t2i[device] = await load_static_pipeline(model_dir, device, size, pipeline)
-
-        return ov_pipelines_t2i[device]
+    if (device, pipeline) not in ov_pipelines:
+        ov_pipelines[(device, pipeline)] = await load_static_pipeline(model_dir, device, size, pipeline)
 
-    if pipeline == "image2image":
-        if device not in ov_pipelines_i2i:
-            ov_pipelines_i2i[device] = await load_static_pipeline(model_dir, device, size, pipeline)
-
-        return ov_pipelines_i2i[device]
+    return ov_pipelines[(device, pipeline)]
 
 
 async def stop():
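Collapsing the two per-mode caches into one dict keyed by (device, pipeline) means the third mode needs no extra module-level state, and one device can hold several pipelines at once. A self-contained toy showing the keying behavior (get_pipeline and its string values are stand-ins for the real loader):

ov_pipelines: dict[tuple[str, str], str] = {}

def get_pipeline(device: str, pipeline: str) -> str:
    # A real call would build an OpenVINO GenAI pipeline; a string stands in here.
    if (device, pipeline) not in ov_pipelines:
        ov_pipelines[(device, pipeline)] = f"{pipeline} on {device}"
    return ov_pipelines[(device, pipeline)]

assert get_pipeline("GPU", "inpainting") is get_pipeline("GPU", "inpainting")   # cached
assert get_pipeline("GPU", "inpainting") != get_pipeline("GPU", "image2image")  # distinct keys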
@@ -124,27 +118,42 @@ def progress(step, num_steps, latent):
     return False
 
 
-async def generate_images(input_image: np.ndarray, prompt: str, seed: int, guidance_scale: float, num_inference_steps: int,
+async def generate_images(input_image_mask: np.ndarray, prompt: str, seed: int, guidance_scale: float, num_inference_steps: int,
                           strength: float, randomize_seed: bool, device: str, endless_generation: bool, size: int) -> tuple[np.ndarray, float]:
     global stop_generating
     stop_generating = not endless_generation
 
+    input_image = input_image_mask["background"][:, :, :3]
+    image_mask = input_image_mask["layers"][0][:, :, 3:]
+
+    # ensure image is square
+    input_image = utils.crop_center(input_image)
+    input_image = cv2.resize(input_image, (size, size))
+    image_mask = cv2.resize(image_mask, (size, size), interpolation=cv2.INTER_NEAREST)
+    image_mask = cv2.cvtColor(image_mask, cv2.COLOR_GRAY2BGR)
+
     while True:
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
 
         start_time = time.time()
-        if input_image is None:
+        if input_image.any():
+
+            # inpainting pipeline
+            if image_mask.any():
+                ov_pipeline = await load_pipeline(hf_model_name, device, size, "inpainting")
+                result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), mask_image=ov.Tensor(image_mask[None]), num_inference_steps=num_inference_steps,
+                                              width=size, height=size, guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed, callback=progress).data[0]
+            # image2image pipeline
+            else:
+                ov_pipeline = await load_pipeline(hf_model_name, device, size, "image2image")
+                result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), num_inference_steps=num_inference_steps, width=size, height=size,
+                                              guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed, callback=progress).data[0]
+        # text2image pipeline
+        else:
             ov_pipeline = await load_pipeline(hf_model_name, device, size, "text2image")
             result = ov_pipeline.generate(prompt=prompt, num_inference_steps=num_inference_steps, width=size, height=size,
                                           guidance_scale=guidance_scale, rng_seed=seed, callback=progress).data[0]
-        else:
-            ov_pipeline = await load_pipeline(hf_model_name, device, size, "image2image")
-            # ensure image is square
-            input_image = utils.crop_center(input_image)
-            input_image = cv2.resize(input_image, (size, size))
-            result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), num_inference_steps=num_inference_steps, width=size, height=size,
-                                          guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed, callback=progress).data[0]
         end_time = time.time()
 
         label = safety_checker(Image.fromarray(result), top_k=1)
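generate_images now receives a Gradio ImageMask value: a dict whose "background" entry holds the canvas (RGB taken from the first three channels) and whose first "layers" entry holds the painted mask in its alpha channel. The mask is resized with INTER_NEAREST so it stays hard-edged and is expanded to three channels before being wrapped in ov.Tensor. The branch order gives the dispatch rule: painted mask → inpainting, bare image → image2image, empty canvas → text2image. A minimal runnable sketch of that rule, assuming the same dict shape (pick_pipeline is illustrative only):

import numpy as np

def pick_pipeline(input_image_mask: dict) -> str:
    input_image = input_image_mask["background"][:, :, :3]  # RGB canvas
    image_mask = input_image_mask["layers"][0][:, :, 3:]    # alpha of the drawn layer
    if input_image.any():          # user supplied an image...
        if image_mask.any():       # ...and painted on it -> inpainting
            return "inpainting"
        return "image2image"       # image but no mask
    return "text2image"            # blank canvas

blank = {"background": np.zeros((4, 4, 4), np.uint8),
         "layers": [np.zeros((4, 4, 4), np.uint8)]}
print(pick_pipeline(blank))  # -> text2image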
@@ -191,7 +200,7 @@ def build_ui(image_size: int) -> gr.Interface:
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                input_image = gr.Image(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
+                input_image = gr.ImageMask(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
             result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
         with gr.Row():
             result_time_label = gr.Text("", label="Inference time", type="text")
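Swapping gr.Image for gr.ImageMask is what turns the handler's input from a bare array into the editor dict unpacked above; in current Gradio the component's value carries "background", "layers", and "composite" entries. A hypothetical minimal wiring that just reports the delivered shape (describe is an invented handler, not part of the demo):

import gradio as gr

def describe(editor_value: dict) -> str:
    # gr.ImageMask delivers an image-editor dict rather than a plain image.
    if editor_value is None or editor_value["background"] is None:
        return "no input"
    background = editor_value["background"]
    layers = editor_value["layers"]
    return f"background {background.shape}, {len(layers)} mask layer(s)"

demo = gr.Interface(fn=describe, inputs=gr.ImageMask(), outputs="text")
# demo.launch()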

demos/paint_your_dreams_demo/requirements.txt (+2 −2)
@@ -7,13 +7,13 @@ openvino-genai==2025.1.0.0.dev20250305
 optimum-intel==1.22.0
 optimum==1.24.0
 onnx==1.17.0
-huggingface-hub==0.27.0
+huggingface-hub==0.29.3
 diffusers==0.32.1
 transformers==4.48.3
 torch==2.5.1
 accelerate==1.2.1
 pillow==11.1.0
 opencv-python==4.10.0.84
 numpy==2.1.3
-gradio==5.12.0
+gradio==5.22.0
 tqdm==4.67.1
