@@ -105,7 +105,7 @@ def run_image_generation(image_param, num, image_id, pipe, args, iter_data_list,
105
105
for bs_idx , in_text in enumerate (input_text_list ):
106
106
llm_bench_utils .output_file .output_image_input_text (in_text , args , image_id , bs_idx , proc_id )
107
107
start = time .perf_counter ()
108
- res = pipe (input_text_list , ** input_args , num_images_per_prompt = 2 ).images
108
+ res = pipe (input_text_list , ** input_args , num_images_per_prompt = args [ 'batch_size' ] ).images
109
109
end = time .perf_counter ()
110
110
if (args ['mem_consumption' ] == 1 and num == 0 ) or args ['mem_consumption' ] == 2 :
111
111
mem_consumption .end_collect_momory_consumption ()
@@ -152,6 +152,12 @@ def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data
152
152
out_str += f", guidance_scale={ input_args ['guidance_scale' ]} "
153
153
log .info (f"[{ 'warm-up' if num == 0 else num } ][P{ image_id } ] { out_str } " )
154
154
155
+ if args .get ("static_reshape" , False ) and 'guidance_scale' in input_args :
156
+ reshaped_gs = pipe .get_generation_config ().guidance_scale
157
+ new_gs = input_args ['guidance_scale' ]
158
+ if new_gs != reshaped_gs :
159
+ log .warning (f"image generation pipeline was reshaped with guidance_scale={ reshaped_gs } , but is being passed into generate() as { new_gs } " )
160
+
155
161
result_md5_list = []
156
162
max_rss_mem_consumption = ''
157
163
max_uss_mem_consumption = ''
@@ -212,14 +218,8 @@ def run_image_generation_genai(image_param, num, image_id, pipe, args, iter_data
212
218
213
219
214
220
def run_image_generation_benchmark (model_path , framework , device , args , num_iters , mem_consumption ):
215
- pipe , pretrain_time , use_genai , callback = FW_UTILS [framework ].create_image_gen_model (model_path , device , ** args )
216
- iter_data_list = []
217
- input_image_list = get_image_prompt (args )
218
- if framework == "ov" and not use_genai :
219
- stable_diffusion_hook .new_text_encoder (pipe )
220
- stable_diffusion_hook .new_unet (pipe )
221
- stable_diffusion_hook .new_vae_decoder (pipe )
222
221
222
+ input_image_list = get_image_prompt (args )
223
223
if args ['prompt_index' ] is None :
224
224
prompt_idx_list = [image_id for image_id , input_text in enumerate (input_image_list )]
225
225
image_list = input_image_list
@@ -232,6 +232,25 @@ def run_image_generation_benchmark(model_path, framework, device, args, num_iter
232
232
prompt_idx_list .append (i )
233
233
if len (image_list ) == 0 :
234
234
raise RuntimeError ('==Failure prompts is empty ==' )
235
+
236
+ # If --static_reshape is specified, we need to get width, height, and guidance scale to drop into args
237
+ # as genai's create_image_gen_model implementation will need those to reshape the pipeline before compile().
238
+ if args .get ("static_reshape" , False ):
239
+ static_input_args = collects_input_args (image_list [0 ], args ['model_name' ], args ["num_steps" ],
240
+ args .get ("height" ), args .get ("width" ), image_as_ov_tensor = False )
241
+ args ["height" ] = static_input_args ["height" ]
242
+ args ["width" ] = static_input_args ["width" ]
243
+ if "guidance_scale" in static_input_args :
244
+ args ["guidance_scale" ] = static_input_args ["guidance_scale" ]
245
+
246
+ pipe , pretrain_time , use_genai , callback = FW_UTILS [framework ].create_image_gen_model (model_path , device , ** args )
247
+ iter_data_list = []
248
+
249
+ if framework == "ov" and not use_genai :
250
+ stable_diffusion_hook .new_text_encoder (pipe )
251
+ stable_diffusion_hook .new_unet (pipe )
252
+ stable_diffusion_hook .new_vae_decoder (pipe )
253
+
235
254
log .info (f'Benchmarking iter nums(exclude warm-up): { num_iters } , prompt nums: { len (image_list )} , prompt idx: { prompt_idx_list } ' )
236
255
237
256
if use_genai :
@@ -268,7 +287,7 @@ def run_image_generation_benchmark(model_path, framework, device, args, num_iter
268
287
def get_image_prompt (args ):
269
288
input_image_list = []
270
289
271
- input_key = 'prompt'
290
+ input_key = [ 'prompt' ]
272
291
if args .get ("task" ) == TASK ["inpainting" ] or ((args .get ("media" ) or args .get ("images" )) and args .get ("mask_image" )):
273
292
input_key = ['media' , "mask_image" , "prompt" ]
274
293
elif args .get ("task" ) == TASK ["img2img" ] or args .get ("media" ) or args .get ("images" ):
0 commit comments