@@ -319,6 +319,110 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data
         llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])


+def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, model_precision, proc_id):
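+    """Run a single GenAI text-generation iteration with a streamer attached.
+
+    The streamer is expected to expose reset() and get_time_list() and to record
+    a timestamp for each token emitted during model.generate(), so per-token
+    latency can be reported alongside the aggregate metrics.
+    """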
+    set_seed(args['seed'])
+    input_text_list = [input_text] * args['batch_size']
+    if args["output_dir"] is not None and num == 0:
+        for bs_index, in_text in enumerate(input_text_list):
+            llm_bench_utils.output_file.output_input_text(in_text, args, model_precision, prompt_index, bs_index, proc_id)
+    pt_inputs = tokenizer(input_text_list, return_tensors="pt")
+    input_token_size = pt_inputs.input_ids.shape[1]
+    pipe_tokenizer = model.get_tokenizer()
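+    # Encode with the pipeline's own tokenizer and time it separately, so tokenization
+    # overhead is reported apart from generation time.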
+    tok_encode_start = time.perf_counter()
+    input_data = pipe_tokenizer.encode(input_text_list)
+    tok_encode_end = time.perf_counter()
+    tok_encode_time = (tok_encode_end - tok_encode_start) * 1000
+    if args['batch_size'] > 1:
+        out_str = '[warm-up]' if num == 0 else '[{}]'.format(num)
+        out_str += " Batch_size={}, ".format(args['batch_size'])
+        out_str += 'all input token size after padding: {} * {}, '.format(input_token_size, args['batch_size'])
+        if args['infer_count'] is not None:
+            out_str += 'all max_output_token_size: {} * {}'.format(args['infer_count'], args['batch_size'])
+        log.info(out_str)
+    max_rss_mem_consumption = ''
+    max_uss_mem_consumption = ''
+    max_shared_mem_consumption = ''
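+    # Track peak memory (RSS/shared/USS) for the warm-up iteration only (mem_consumption == 1)
+    # or for every iteration (mem_consumption == 2).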
+    if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
+        mem_consumption.start_collect_memory_consumption()
+    max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
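+    # Drop any timings left from a previous run, then generate; the streamer collects
+    # per-token latencies while generate() runs.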
+    streamer.reset()
+    start = time.perf_counter()
+    generated_tokens = model.generate(input_data, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], streamer=streamer).tokens
+    end = time.perf_counter()
+    if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
+        mem_consumption.end_collect_momory_consumption()
+        max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption()
+        mem_consumption.clear_max_memory_consumption()
+    generation_time = end - start
+    tok_decode_start = time.perf_counter()
+    generated_text = pipe_tokenizer.decode(generated_tokens)
+    tok_decode_end = time.perf_counter()
+    tok_decode_time = (tok_decode_end - tok_decode_start) * 1000
+    # Only text generation needs to subtract the input length, because generated_text may include input_text
+    num_tokens = 0
+    result_md5_list = []
+    for bs_idx in range(args['batch_size']):
+        generated_text_len = len(generated_tokens[bs_idx])
+        num_tokens += generated_text_len
+        if generated_text_len > max_gen_tokens:
+            log.error('Output token size is over max output token size!')
+        result_text = generated_text[bs_idx]
+        if args["output_dir"] is not None:
+            llm_bench_utils.output_file.output_gen_text(result_text, args, model_precision, prompt_index, num, bs_idx, proc_id)
+        result_md5_list.append(hashlib.new("md5", result_text.encode(), usedforsecurity=False).hexdigest())
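+    # Store this iteration's md5 per prompt so later iterations can be checked for consistency.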
+    if len(md5_list[num]) == 0:
+        md5_list[num] = {prompt_index: result_md5_list}
+    else:
+        md5_list[num][prompt_index] = result_md5_list
+    per_token_time = generation_time * 1000 / (num_tokens / args['batch_size'])
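+    # Per-token latencies recorded by the streamer during generation.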
+    tm_list = streamer.get_time_list()
+    log.debug('latency of all tokens:')
+    [log.debug('[{}]{:.4f}'.format(idx, tm)) for idx, tm in enumerate(tm_list)]
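+    # Aggregate this iteration's statistics for the benchmark report.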
+    iter_data = gen_iterate_data(
+        num,
+        input_token_size * args['batch_size'],
+        len(tm_list),
+        num_tokens,
+        generation_time,
+        per_token_time,
+        result_md5_list,
+        max_rss_mem=max_rss_mem_consumption,
+        max_shared_mem=max_shared_mem_consumption,
+        max_uss_mem=max_uss_mem_consumption,
+        prompt_idx=prompt_index,
+        tokenization_time=(tok_encode_time, tok_decode_time)
+    )
+    iter_data_list.append(iter_data)
+    llm_bench_utils.metrics_print.print_metrics(
+        num,
+        iter_data,
+        tm_list,
+        [],
+        warm_up=(num == 0),
+        max_rss_mem=max_rss_mem_consumption,
+        max_shared_mem=max_shared_mem_consumption,
+        max_uss_mem=max_uss_mem_consumption,
+        tokenization_time=(tok_encode_time, tok_decode_time),
+        batch_size=args['batch_size']
+    )
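+    # Cross-check output md5 against the previous iteration to flag non-reproducible results.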
+    if num > 0:
+        prev_md5 = md5_list[num - 1][prompt_index]
+        if result_md5_list != prev_md5:
+            log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} "
+                        f"is different from md5 of the {num - 1} iteration {prev_md5}")
+            llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
+            if num == 1:
+                # on CPU, results are expected to be reproducible, so raise on mismatch
+                if args['devices'].lower().startswith('cpu') is True:
+                    assert (result_md5_list == prev_md5)
+            else:
+                # after the first iteration, raise on any mismatch
+                assert (result_md5_list == prev_md5)
+    else:
+        llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
+    streamer.reset()
+
+
 def run_text_generation_benchmark(model_path, framework, device, args, num_iters):
     model, tokenizer, pretrain_time, bench_hook, use_genai = FW_UTILS[framework].create_text_gen_model(model_path, device, **args)
     model_precision = llm_bench_utils.model_utils.get_model_precision(model_path.parts)
@@ -341,7 +445,12 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters
              f'prompt nums: {len(text_list)}, prompt idx: {prompt_idx_list}')

     # if num_iters == 0, just output warm-up data
-    text_gen_fn = run_text_generation if not use_genai else run_text_generation_genai
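+    # Pick the generation routine: the default path, GenAI with a streamer
+    # (when a bench hook exists), or plain GenAI.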
+    if not use_genai:
+        text_gen_fn = run_text_generation
+    elif bench_hook is not None:
+        text_gen_fn = run_text_generation_genai_with_stream
+    else:
+        text_gen_fn = run_text_generation_genai
     proc_id = os.getpid()
     if args['subsequent'] is False:
         for num in range(num_iters + 1):