@@ -75,42 +75,29 @@ def download_models(model_name, safety_checker_model: str) -> None:
     image_processor = AutoProcessor.from_pretrained(safety_checker_dir)
 
 
-async def load_static_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline:
-    # NPU requires model static input shape for now
+async def create_pipeline(model_dir: Path, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
     ov_config = {"CACHE_DIR": "cache"}
 
-    scheduler = genai.Scheduler.from_config(model_dir / "scheduler" / "scheduler_config.json")
-
-    text_encoder = genai.CLIPTextModel(model_dir / "text_encoder")
-    text_encoder.reshape(1)
-    text_encoder.compile(device, **ov_config)
-
-    unet = genai.UNet2DConditionModel(model_dir / "unet")
-    max_position_embeddings = text_encoder.get_config().max_position_embeddings
-    unet.reshape(1, size, size, max_position_embeddings)
-    unet.compile(device, **ov_config)
-
-    vae = genai.AutoencoderKL(model_dir / "vae_encoder", model_dir / "vae_decoder")
-    vae.reshape(1, size, size)
-    vae.compile(device, **ov_config)
-
     if pipeline == "text2image":
-        ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.Text2ImagePipeline(model_dir)
     elif pipeline == "image2image":
-        ov_pipeline = genai.Image2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.Image2ImagePipeline(model_dir)
     elif pipeline == "inpainting":
-        ov_pipeline = genai.InpaintingPipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)
+        ov_pipeline = genai.InpaintingPipeline(model_dir)
     else:
         raise ValueError(f"Unknown pipeline: {pipeline}")
 
+    ov_pipeline.reshape(1, size, size, ov_pipeline.get_generation_config().guidance_scale)
+    ov_pipeline.compile(device, config=ov_config)
+
     return ov_pipeline
 
 
-async def load_pipeline(model_name: str, device: str, size: int, pipeline: str):
+async def load_pipeline(model_name: str, device: str, size: int, pipeline: str) -> genai.Text2ImagePipeline | genai.Image2ImagePipeline | genai.InpaintingPipeline:
     model_dir = MODEL_DIR / model_name
 
     if (device, pipeline) not in ov_pipelines:
-        ov_pipelines[(device, pipeline)] = await load_static_pipeline(model_dir, device, size, pipeline)
+        ov_pipelines[(device, pipeline)] = await create_pipeline(model_dir, device, size, pipeline)
 
     return ov_pipelines[(device, pipeline)]
 
@@ -121,7 +108,7 @@ async def stop():
 
 
 progress_bar = None
-def progress(step, num_steps, latent):
+def progress(step, num_steps, latent) -> bool:
     global progress_bar
     if progress_bar is None:
         progress_bar = tqdm.tqdm(total=num_steps)
@@ -215,7 +202,7 @@ def build_ui(image_size: int) -> gr.Interface:
         )
         with gr.Row():
             with gr.Column():
-                with gr.Row():
+                with gr.Row(equal_height=True):
                     input_image = gr.ImageMask(label="Input image (leave blank for text2image generation)", sources=["webcam", "clipboard", "upload"])
                     result_img = gr.Image(label="Generated image", elem_id="output_image", format="png")
                 with gr.Row():
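For reference, a minimal usage sketch of the refactored flow: the unified pipeline object is constructed from the exported model directory, reshaped to a fixed batch/resolution, and only then compiled, as create_pipeline() does above. This is a sketch under assumptions, not part of the change: it presumes openvino_genai is installed and an LCM-style model has been exported locally; the model path, device, prompt, resolution, and the callback keyword are illustrative placeholders.

import openvino_genai as genai
from PIL import Image

def progress(step, num_steps, latent) -> bool:
    # Same signature as the demo's callback; returning True would interrupt generation.
    print(f"step {step + 1}/{num_steps}")
    return False

pipe = genai.Text2ImagePipeline("models/lcm_dreamshaper_v7")  # hypothetical export path; constructing without a device defers compilation
pipe.reshape(1, 512, 512, pipe.get_generation_config().guidance_scale)  # pin batch/height/width before compiling
pipe.compile("CPU", config={"CACHE_DIR": "cache"})

image_tensor = pipe.generate("a watercolor lighthouse at dawn", width=512, height=512, callback=progress)
Image.fromarray(image_tensor.data[0]).save("result.png")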