 safety_checker: Optional[Pipeline] = None
 
-ov_pipelines_t2i = {}
-ov_pipelines_i2i = {}
+ov_pipelines = {}
 
 stop_generating: bool = True
 hf_model_name: Optional[str] = None
@@ -83,6 +82,8 @@ async def load_static_pipeline(model_dir: Path, device: str, size: int, pipelin
         ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
     elif pipeline == "image2image":
         ov_pipeline = genai.Image2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+    elif pipeline == "inpainting":
+        ov_pipeline = genai.InpaintingPipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
     else:
         raise ValueError(f"Unknown pipeline: {pipeline}")
 
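The new "inpainting" branch mirrors the existing two: all three OpenVINO GenAI pipeline classes expose the same latent_consistency_model() factory over the shared scheduler/text-encoder/UNet/VAE components. A minimal sketch of an equivalent table-driven dispatch, assuming openvino_genai is imported as genai and the four components are already built as earlier in this file (build_lcm_pipeline is a hypothetical name, not part of the PR):

import openvino_genai as genai

# Map pipeline kind -> factory; every factory consumes the same component set.
PIPELINE_FACTORIES = {
    "text2image": genai.Text2ImagePipeline.latent_consistency_model,
    "image2image": genai.Image2ImagePipeline.latent_consistency_model,
    "inpainting": genai.InpaintingPipeline.latent_consistency_model,
}

def build_lcm_pipeline(kind, scheduler, text_encoder, unet, vae):
    # Equivalent to the if/elif/else chain above, including the error case.
    if kind not in PIPELINE_FACTORIES:
        raise ValueError(f"Unknown pipeline: {kind}")
    return PIPELINE_FACTORIES[kind](scheduler, text_encoder, unet, vae)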
@@ -92,17 +93,10 @@ async def load_static_pipeline(model_dir: Path, device: str, size: int, pipelin
 async def load_pipeline(model_name: str, device: str, size: int, pipeline: str):
     model_dir = MODEL_DIR / model_name
 
-    if pipeline == "text2image":
-        if device not in ov_pipelines_t2i:
-            ov_pipelines_t2i[device] = await load_static_pipeline(model_dir, device, size, pipeline)
-
-        return ov_pipelines_t2i[device]
+    if (device, pipeline) not in ov_pipelines:
+        ov_pipelines[(device, pipeline)] = await load_static_pipeline(model_dir, device, size, pipeline)
 
-    if pipeline == "image2image":
-        if device not in ov_pipelines_i2i:
-            ov_pipelines_i2i[device] = await load_static_pipeline(model_dir, device, size, pipeline)
-
-        return ov_pipelines_i2i[device]
+    return ov_pipelines[(device, pipeline)]
 
 
 async def stop():
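The two per-task caches collapse into one dict keyed by (device, pipeline), so each combination is compiled at most once and adding a pipeline kind needs no new bookkeeping. A hypothetical usage sketch (the model name and device strings are assumptions for illustration):

# Same key -> same cached object; the second call recompiles nothing.
pipe_a = await load_pipeline("lcm-dreamshaper-v7", "GPU", 512, "inpainting")
pipe_b = await load_pipeline("lcm-dreamshaper-v7", "GPU", 512, "inpainting")
assert pipe_a is pipe_b

# A different device or task is a new key and triggers one more compile.
pipe_c = await load_pipeline("lcm-dreamshaper-v7", "CPU", 512, "image2image")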
@@ -124,27 +118,42 @@ def progress(step, num_steps, latent):
     return False
 
 
-async def generate_images(input_image: np.ndarray, prompt: str, seed: int, guidance_scale: float, num_inference_steps: int,
+async def generate_images(input_image_mask: np.ndarray, prompt: str, seed: int, guidance_scale: float, num_inference_steps: int,
                           strength: float, randomize_seed: bool, device: str, endless_generation: bool, size: int) -> tuple[np.ndarray, float]:
     global stop_generating
     stop_generating = not endless_generation
 
+    input_image = input_image_mask["background"][:, :, :3]
+    image_mask = input_image_mask["layers"][0][:, :, 3:]
+
+    # ensure image is square
+    input_image = utils.crop_center(input_image)
+    input_image = cv2.resize(input_image, (size, size))
+    image_mask = cv2.resize(image_mask, (size, size), interpolation=cv2.INTER_NEAREST)
+    image_mask = cv2.cvtColor(image_mask, cv2.COLOR_GRAY2BGR)
+
     while True:
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
 
         start_time = time.time()
-        if input_image is None:
+        if input_image.any():
+
+            # inpainting pipeline
+            if image_mask.any():
+                ov_pipeline = await load_pipeline(hf_model_name, device, size, "inpainting")
+                result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), mask_image=ov.Tensor(image_mask[None]), num_inference_steps=num_inference_steps,
+                                              width=size, height=size, guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed, callback=progress).data[0]
+            # image2image pipeline
+            else:
+                ov_pipeline = await load_pipeline(hf_model_name, device, size, "image2image")
+                result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), num_inference_steps=num_inference_steps, width=size, height=size,
+                                              guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed, callback=progress).data[0]
+        # text2image pipeline
+        else:
             ov_pipeline = await load_pipeline(hf_model_name, device, size, "text2image")
             result = ov_pipeline.generate(prompt=prompt, num_inference_steps=num_inference_steps, width=size, height=size,
                                           guidance_scale=guidance_scale, rng_seed=seed, callback=progress).data[0]
-        else:
-            ov_pipeline = await load_pipeline(hf_model_name, device, size, "image2image")
-            # ensure image is square
-            input_image = utils.crop_center(input_image)
-            input_image = cv2.resize(input_image, (size, size))
-            result = ov_pipeline.generate(prompt=prompt, image=ov.Tensor(input_image[None]), num_inference_steps=num_inference_steps, width=size, height=size,
-                                          guidance_scale=guidance_scale, strength=1.0 - strength, rng_seed=seed, callback=progress).data[0]
 
         end_time = time.time()
 
         label = safety_checker(Image.fromarray(result), top_k=1)
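generate_images() now receives the gr.ImageMask editor payload rather than a bare array (the np.ndarray annotation is carried over from the old signature, but the value is a dict): the uploaded picture sits under "background" and the brush strokes under "layers", with the mask carried in the layer's alpha channel. A sketch of the expected structure and of the slicing above, assuming numpy-typed images and a single RGBA brush layer:

import numpy as np

h = w = 512
input_image_mask = {
    "background": np.zeros((h, w, 4), dtype=np.uint8),  # uploaded image (RGBA)
    "layers": [np.zeros((h, w, 4), dtype=np.uint8)],    # one RGBA array per brush layer
    "composite": np.zeros((h, w, 4), dtype=np.uint8),   # background with layers applied
}

input_image = input_image_mask["background"][:, :, :3]  # keep RGB, drop alpha
image_mask = input_image_mask["layers"][0][:, :, 3:]    # strokes live in the alpha channel
assert input_image.shape == (h, w, 3) and image_mask.shape == (h, w, 1)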
@@ -191,7 +200,7 @@ def build_ui(image_size: int) -> gr.Interface:
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                input_image = gr.Image(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
+                input_image = gr.ImageMask(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
                 result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
             with gr.Row():
                 result_time_label = gr.Text("", label="Inference time", type="text")
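With the ImageMask input in place, the handler's routing reduces to: strokes present -> inpainting, image but no strokes -> image2image, empty canvas -> text2image. Restated as a small helper (hypothetical name, not part of the PR):

import numpy as np

def pick_pipeline(input_image: np.ndarray, image_mask: np.ndarray) -> str:
    # Mirrors the branch order in generate_images(): any nonzero pixel counts.
    if input_image.any():
        return "inpainting" if image_mask.any() else "image2image"
    return "text2image"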