@@ -6,7 +6,7 @@

 def print_metrics(
         iter_num, iter_data, tms=None, tms_infer=None, warm_up=False, max_rss_mem=-1, max_shared_mem=-1,
-        max_uss_mem=-1, stable_diffusion=None, tokenization_time=None, batch_size=1, whisper = None
+        max_uss_mem=-1, stable_diffusion=None, tokenization_time=None, batch_size=1, whisper = None, prompt_idx=-1
 ):
     iter_str = str(iter_num)
     if warm_up:
@@ -57,7 +57,7 @@ def print_metrics(
     if stable_diffusion is not None:
         print_stable_diffusion_infer_latency(iter_str, iter_data, stable_diffusion)
     if whisper is not None:
-        print_whisper_infer_latency(iter_str, whisper)
+        print_whisper_infer_latency(iter_str, whisper, prompt_idx)
     output_str = ''
     if max_rss_mem != '' and max_rss_mem > -1:
         output_str += 'Max rss memory cost: {:.2f}MBytes, '.format(max_rss_mem)
@@ -102,8 +102,8 @@ def print_stable_diffusion_infer_latency(iter_str, iter_data, stable_diffusion):
              f"vae decoder step count: {stable_diffusion.get_vae_decoder_step_count()}",)


-def print_whisper_infer_latency(iter_str, whisper):
-    print(f'{whisper.get_whisper_latency(iter_str)}')
+def print_whisper_infer_latency(iter_str, whisper, prompt_idx):
+    print(f'{whisper.get_whisper_latency(iter_str, prompt_idx)}')


 def print_ldm_unet_vqvae_infer_latency(iter_num, iter_data, tms=None, warm_up=False):
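Below is a minimal, self-contained sketch of how the patched helper behaves. The WhisperResultsHook class, its canned latency string, and the driving loop are hypothetical stand-ins for the hook object and benchmark code that call into this module; only the print_whisper_infer_latency / get_whisper_latency signatures and the prompt_idx=-1 default come from the diff above.

# Sketch of the patched helper with a hypothetical hook object.
# WhisperResultsHook and the fixed latency string are illustrative only;
# the function signatures mirror the diff above.

def print_whisper_infer_latency(iter_str, whisper, prompt_idx):
    # Patched version: the prompt index is forwarded so the hook can report
    # per-prompt latency rather than a single aggregate figure.
    print(f'{whisper.get_whisper_latency(iter_str, prompt_idx)}')


class WhisperResultsHook:
    def get_whisper_latency(self, iter_str, prompt_idx):
        # A real hook would look up measured timings; this stand-in returns a fixed string.
        return f'[{iter_str}][P{prompt_idx}] whisper generation latency: 123.45 ms'


if __name__ == '__main__':
    hook = WhisperResultsHook()
    for prompt_idx in range(2):
        # prompt_idx defaults to -1 in print_metrics for pipelines that do not
        # track per-prompt results; here two prompts are reported separately.
        print_whisper_infer_latency('0', hook, prompt_idx)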