@@ -216,6 +216,112 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,
        bench_hook.clear_time_infer_list()


+def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, model_precision, proc_id):
+    set_seed(args['seed'])
+    input_text_list = [input_text] * args['batch_size']
+    if args["output_dir"] is not None and num == 0:
+        for bs_index, in_text in enumerate(input_text_list):
+            llm_bench_utils.output_file.output_input_text(in_text, args, model_precision, prompt_index, bs_index, proc_id)
+    pt_inputs = tokenizer(input_text_list, return_tensors="pt")
+    input_token_size = pt_inputs.input_ids.shape[1]
+    pipe_tokenizer = model.get_tokenizer()
+    tok_encode_start = time.perf_counter()
+    input_data = pipe_tokenizer.encode(input_text_list)
+    tok_encode_end = time.perf_counter()
+    tok_encode_time = (tok_encode_end - tok_encode_start) * 1000
+    if args['batch_size'] > 1:
+        out_str = '[warm-up]' if num == 0 else '[{}]'.format(num)
+        out_str += " Batch_size={}, ".format(args['batch_size'])
+        out_str += 'all input token size after padding: {} * {}, '.format(input_token_size, args['batch_size'])
+        if args['infer_count'] is not None:
+            out_str += 'all max_output_token_size: {} * {}'.format(args['infer_count'], args['batch_size'])
+        log.info(out_str)
+
+    max_rss_mem_consumption = ''
+    max_uss_mem_consumption = ''
+    max_shared_mem_consumption = ''
+    if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
+        mem_consumption.start_collect_memory_consumption()
+    max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
+    streamer.reset()
+    start = time.perf_counter()
+    generated_tokens = model.generate(input_data, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], streamer=streamer).tokens
+    end = time.perf_counter()
+    if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
+        mem_consumption.end_collect_momory_consumption()
+        max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption()
+        mem_consumption.clear_max_memory_consumption()
+
+    generation_time = end - start
+    tok_decode_start = time.perf_counter()
+    generated_text = pipe_tokenizer.decode(generated_tokens)
+    tok_decode_end = time.perf_counter()
+    tok_decode_time = (tok_decode_end - tok_decode_start) * 1000
+    # Only text_gen needs to subtract the length of input_data, because generated_text may include input_text
+    num_tokens = 0
+    result_md5_list = []
+    for bs_idx in range(args['batch_size']):
+        generated_text_len = len(generated_tokens[bs_idx])
+        num_tokens += generated_text_len
+        if generated_text_len > max_gen_tokens:
+            log.error('Output token size is over max output token size!')
+        result_text = generated_text[bs_idx]
+        if args["output_dir"] is not None:
+            llm_bench_utils.output_file.output_gen_text(result_text, args, model_precision, prompt_index, num, bs_idx, proc_id)
+        result_md5_list.append(hashlib.new("md5", result_text.encode(), usedforsecurity=False).hexdigest())
+    if len(md5_list[num]) == 0:
+        md5_list[num] = {prompt_index: result_md5_list}
+    else:
+        md5_list[num][prompt_index] = result_md5_list
+    per_token_time = generation_time * 1000 / (num_tokens / args['batch_size'])
+    tm_list = streamer.get_time_list()
+    log.debug('latency of all tokens:')
+    [log.debug('[{}]{:.4f}'.format(idx, tm)) for idx, tm in enumerate(tm_list)]
+    iter_data = gen_iterate_data(
+        num,
+        input_token_size * args['batch_size'],
+        len(tm_list),
+        num_tokens,
+        generation_time,
+        per_token_time,
+        result_md5_list,
+        max_rss_mem=max_rss_mem_consumption,
+        max_shared_mem=max_shared_mem_consumption,
+        max_uss_mem=max_uss_mem_consumption,
+        prompt_idx=prompt_index,
+        tokenization_time=(tok_encode_time, tok_decode_time)
+    )
+    iter_data_list.append(iter_data)
+    llm_bench_utils.metrics_print.print_metrics(
+        num,
+        iter_data,
+        tm_list,
+        [],
+        warm_up=(num == 0),
+        max_rss_mem=max_rss_mem_consumption,
+        max_shared_mem=max_shared_mem_consumption,
+        max_uss_mem=max_uss_mem_consumption,
+        tokenization_time=(tok_encode_time, tok_decode_time),
+        batch_size=args['batch_size']
+    )
+    if num > 0:
+        prev_md5 = md5_list[num - 1][prompt_index]
+        if result_md5_list != prev_md5:
+            log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} "
+                        f"is different from md5 of the {num - 1} iteration {prev_md5}")
+            llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
+            if num == 1:
+                # if the device is CPU, throw exception
+                if args['devices'].lower().startswith('cpu') is True:
+                    assert (result_md5_list == prev_md5)
+            else:
+                # throw exception
+                assert (result_md5_list == prev_md5)
+    else:
+        llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
+    streamer.reset()
+
+
def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, model_precision, proc_id):
    set_seed(args['seed'])
    input_text_list = [input_text] * args['batch_size']
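Note: `run_text_generation_genai_with_stream` expects the `streamer` passed in to expose `reset()` and `get_time_list()` on top of the per-token callback that GenAI's `generate()` drives. The class below is only a minimal sketch of such a streamer, assuming the `openvino_genai.StreamerBase` interface (`put()`/`end()`); the names `TokenTimingStreamer` and `token_latency_list` are illustrative, not the actual llm_bench implementation.

```python
import time

import openvino_genai as ov_genai


class TokenTimingStreamer(ov_genai.StreamerBase):
    """Illustrative streamer that records the wall-clock latency of every generated token."""

    def __init__(self):
        super().__init__()
        self.token_latency_list = []
        self.start_time = time.perf_counter()

    def put(self, token_id) -> bool:
        # Called by generate() for each new token: store the time since the previous one.
        now = time.perf_counter()
        self.token_latency_list.append(now - self.start_time)
        self.start_time = now
        return False  # False means "keep generating"

    def end(self):
        pass

    def get_time_list(self):
        return self.token_latency_list

    def reset(self):
        self.token_latency_list = []
        self.start_time = time.perf_counter()
```

With such an object, `streamer.get_time_list()` is what feeds `tm_list` above, so `len(tm_list)` approximates the number of generation steps.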
@@ -341,7 +447,12 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters
             f'prompt nums: {len(text_list)}, prompt idx: {prompt_idx_list}')

    # if num_iters == 0, just output warm-up data
-    text_gen_fn = run_text_generation if not use_genai else run_text_generation_genai
+    if not use_genai:
+        text_gen_fn = run_text_generation
+    elif bench_hook is not None:
+        text_gen_fn = run_text_generation_genai_with_stream
+    else:
+        text_gen_fn = run_text_generation_genai
    proc_id = os.getpid()
    if args['subsequent'] is False:
        for num in range(num_iters + 1):
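For context on the metric reported by the new streaming path: `per_token_time` is computed as `generation_time * 1000 / (num_tokens / batch_size)`, i.e. milliseconds per generated token normalized to a single sequence in the batch. A small worked example with made-up numbers:

```python
# Made-up numbers, only to illustrate the per_token_time formula used in the new function.
generation_time = 2.0   # seconds spent inside model.generate()
num_tokens = 256        # tokens generated, summed over the whole batch
batch_size = 2

per_token_time = generation_time * 1000 / (num_tokens / batch_size)
print(per_token_time)   # 15.625 -> about 15.6 ms per token for each sequence
```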