@@ -59,42 +59,29 @@ def download_models(model_name, safety_checker_model: str) -> None:
     image_processor = AutoProcessor.from_pretrained(safety_checker_dir)


-async def load_static_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline:
-    # NPU requires model static input shape for now
+async def create_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
     ov_config = {"CACHE_DIR": "cache"}

-    scheduler = genai.Scheduler.from_config(model_dir / "scheduler" / "scheduler_config.json")
-
-    text_encoder = genai.CLIPTextModel(model_dir / "text_encoder")
-    text_encoder.reshape(1)
-    text_encoder.compile(device, **ov_config)
-
-    unet = genai.UNet2DConditionModel(model_dir / "unet")
-    max_position_embeddings = text_encoder.get_config().max_position_embeddings
-    unet.reshape(1, size, size, max_position_embeddings)
-    unet.compile(device, **ov_config)
-
-    vae = genai.AutoencoderKL(model_dir / "vae_encoder", model_dir / "vae_decoder")
-    vae.reshape(1, size, size)
-    vae.compile(device, **ov_config)
-
     if pipeline == "text2image":
-        ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.Text2ImagePipeline(model_dir)
     elif pipeline == "image2image":
-        ov_pipeline = genai.Image2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.Image2ImagePipeline(model_dir)
     elif pipeline == "inpainting":
-        ov_pipeline = genai.InpaintingPipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.InpaintingPipeline(model_dir)
     else:
         raise ValueError(f"Unknown pipeline: {pipeline}")

+    ov_pipeline.reshape(1, size, size, ov_pipeline.get_generation_config().guidance_scale)
+    ov_pipeline.compile(device, config=ov_config)
+
     return ov_pipeline


-async def load_pipeline(model_name: str, device: str, size: int, pipeline: str):
+async def load_pipeline(model_name: str, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
     model_dir = MODEL_DIR / model_name

     if (device, pipeline) not in ov_pipelines:
-        ov_pipelines[(device, pipeline)] = await load_static_pipeline(model_dir, device, size, pipeline)
+        ov_pipelines[(device, pipeline)] = await create_pipeline(model_dir, device, size, pipeline)

     return ov_pipelines[(device, pipeline)]

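Not part of the diff: a minimal standalone sketch of how the reworked pipeline creation is meant to be exercised, assuming `openvino_genai` is installed and importable, that an exported model already sits in `model_dir`, and that `generate()` accepts `width`/`height`/`num_inference_steps` keywords as in the upstream OpenVINO GenAI samples; the model path and prompt below are hypothetical.

```python
from pathlib import Path

import openvino_genai as genai
from PIL import Image

model_dir = Path("models/LCM_Dreamshaper_v7")  # hypothetical model location
device = "GPU"
size = 512

# Build the whole pipeline from the exported model directory instead of
# wiring the scheduler, text encoder, UNet and VAE by hand.
pipe = genai.Text2ImagePipeline(model_dir)

# Reshape once to a fixed batch size, resolution and guidance scale,
# then compile with a cache directory so later runs start faster.
pipe.reshape(1, size, size, pipe.get_generation_config().guidance_scale)
pipe.compile(device, config={"CACHE_DIR": "cache"})

image_tensor = pipe.generate(
    "a lighthouse on a cliff at sunset, watercolor",  # hypothetical prompt
    width=size,
    height=size,
    num_inference_steps=4,
)
Image.fromarray(image_tensor.data[0]).save("result.png")
```

Reshaping the pipeline as a whole, rather than each submodel separately, appears to be the intent of the change: the static-shape requirement is still satisfied, while the library picks consistent shapes for every stage.
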
@@ -105,7 +92,7 @@ async def stop():


 progress_bar = None
-def progress(step, num_steps, latent):
+def progress(step, num_steps, latent) -> bool:
     global progress_bar
     if progress_bar is None:
         progress_bar = tqdm.tqdm(total=num_steps)
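Not part of the diff: a hypothetical completion of the annotated callback, on the assumption that the `bool` return value follows the OpenVINO GenAI image-generation callback contract where returning `True` asks the pipeline to stop early; everything past the lines shown in the hunk is illustrative only.

```python
import tqdm

progress_bar = None


def progress(step, num_steps, latent) -> bool:
    # Keep a single module-level tqdm bar across callback invocations.
    global progress_bar
    if progress_bar is None:
        progress_bar = tqdm.tqdm(total=num_steps)
    progress_bar.update(1)
    if step + 1 == num_steps:
        progress_bar.close()
        progress_bar = None
    # False = keep generating (assumed contract: True would interrupt generation).
    return False
```
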
@@ -199,7 +186,7 @@ def build_ui(image_size: int) -> gr.Interface:
         )
         with gr.Row():
             with gr.Column():
-                with gr.Row():
+                with gr.Row(equal_height=True):
                     input_image = gr.ImageMask(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
                     result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
         with gr.Row():
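Not part of the diff: a minimal layout sketch showing the effect of `equal_height=True`, assuming Gradio is installed; the component arguments are copied from the hunk above, while the surrounding `Blocks` scaffold is illustrative.

```python
import gradio as gr

with gr.Blocks() as demo:
    # equal_height=True keeps the mask editor and the result image the same height.
    with gr.Row(equal_height=True):
        input_image = gr.ImageMask(
            label="Input image (leave blank for text2image generation)",
            sources=["webcam", "clipboard", "upload"],
        )
        result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")

demo.launch()
```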