Commit ac263eb

Print time of encoder and decoder of each loop
1 parent 7128c03 commit ac263eb
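This change reworks how the Whisper speech-to-text benchmark reports timing. Instead of accumulating encoder and decoder request times across a whole run and printing averages at the end, `WhisperHook` now records one entry per transcription loop (the encoder inference time plus the per-token and per-inference decoder times collected by the greedy-search hook around each `generate()` call) and prints a latency line for every loop of every iteration.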

File tree: 4 files changed, +97 -85 lines

llm_bench/python/benchmark.py (+11 -3)
```diff
@@ -738,9 +738,16 @@ def run_speech_2txt_generation(pipe, args, num, md5_list, prompt_id, audio_promp
             prompt_idx=prompt_id,
         )
         iter_data_list.append(iter_data)
+        tm_list = []
+        tm_infer_list = []
+        if whisper_hook is not None:
+            tm_list = whisper_hook.get_time_list()
+            tm_infer_list = whisper_hook.get_time_infer_list()
         llm_bench_utils.metrics_print.print_metrics(
             num,
             iter_data,
+            tm_list,
+            tm_infer_list,
             warm_up=(num == 0),
             max_rss_mem=max_rss_mem_consumption,
             max_shared_mem=max_shared_mem_consumption,
@@ -773,8 +780,8 @@ def run_speech_2txt_benchmark(model_path, framework, device, args, num_iters):
     for audio_prompt in input_audio_prompt_list:
         if args['prompt'] is None and args['prompt_file'] is None:
             raise RuntimeError('==Failure image is empty ==')
-        elif args['prompt_file'] is not None:
-            audio_prompt['prompt'] = os.path.join(os.path.dirname(args['prompt_file']), audio_prompt['prompt'].replace('./', ''))
+        elif args['prompt_file'] is not None and len(args['prompt_file']) > 0:
+            audio_prompt['prompt'] = os.path.join(os.path.dirname(args['prompt_file'][0]), audio_prompt['prompt'].replace('./', ''))
         audio_prompt['prompt'] = Path(audio_prompt['prompt'])
         audios_prompt_list.append(audio_prompt)
     if args['prompt_index'] is None:
@@ -800,7 +807,8 @@ def run_speech_2txt_benchmark(model_path, framework, device, args, num_iters):
     )
     if framework == "ov":
         whisper_hook.new_text_encoder(pipe)
-        whisper_hook.new_text_decoder(pipe)
+        whisper_hook.new_generate(pipe)
+        whisper_hook.new_text_sample(pipe)
     md5_list = {num : {} for num in range(num_iters + 1)}
     for num in range(num_iters + 1):
         for idx, audio_prompt in enumerate(audio_list):
```
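The new `tm_list` and `tm_infer_list` arguments give `print_metrics` per-token data for the first loop. A minimal sketch of what the two getters return, assuming the `time_data` layout from the `WhisperHook` diff below; the timing values are invented:

```python
import copy

# Assumed shape of WhisperHook.time_data: one dict per transcription loop.
time_data = [{
    'enc_infer_time': 0.120,                   # encoder inference, seconds
    'dec_token_time': [0.050, 0.012, 0.011],   # per generated token, seconds
    'dec_infer_time': [0.048, 0.011, 0.010],   # per decoder inference, seconds
}]

# get_time_list(): first-loop token times with the encoder time prepended,
# so the encoder pass is costed like a first token.
tm_list = copy.deepcopy(time_data[0]['dec_token_time'])
tm_list.insert(0, time_data[0]['enc_infer_time'])
print(tm_list)        # [0.12, 0.05, 0.012, 0.011]

# get_time_infer_list(): same layout, built from the raw inference times.
tm_infer_list = copy.deepcopy(time_data[0]['dec_infer_time'])
tm_infer_list.insert(0, time_data[0]['enc_infer_time'])
print(tm_infer_list)  # [0.12, 0.048, 0.011, 0.01]
```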
WhisperHook implementation (+63 -51)
```diff
@@ -1,74 +1,86 @@
 import time
+import copy
+import llm_bench_utils.hook_greedy_search
 
 
 class WhisperHook:
     def __init__(self):
-        self.text_encoder_time = 0
-        self.text_decoder_time = 0
-        self.text_enc_time_list = []
-        self.text_dec_time_list = []
-        self.text_encoder_infer_count = 0
-        self.text_decoder_infer_count = 0
+        self.enc_infer_count = 0
+        self.time_data = []
+        self.greedy_hook = None
 
-    def get_text_encoder_latency(self):
-        return (self.text_encoder_time / self.text_encoder_infer_count) * 1000 if self.text_encoder_infer_count > 0 else 0
-
-    def get_1st_text_enc_latency(self):
-        return self.text_enc_time_list[0] * 1000 if len(self.text_enc_time_list) > 0 else 0
-
-    def get_2nd_text_enc_latency(self):
-        return sum(self.text_enc_time_list[1:]) / (len(self.text_enc_time_list) - 1) * 1000 if len(self.text_enc_time_list) > 1 else 0
-
-    def get_1st_text_dec_latency(self):
-        return self.text_dec_time_list[0] * 1000 if len(self.text_dec_time_list) > 0 else 0
-
-    def get_2nd_text_dec_latency(self):
-        return sum(self.text_dec_time_list[1:]) / (len(self.text_dec_time_list) - 1) * 1000 if len(self.text_dec_time_list) > 1 else 0
-
-    def get_text_dec_latency(self):
-        return (sum(self.text_dec_time_list) / len(self.text_dec_time_list)) * 1000 if len(self.text_dec_time_list) > 0 else 0
-
-    def get_text_decoder_latency(self):
-        return (self.text_decoder_time / self.text_decoder_infer_count) * 1000 if self.text_decoder_infer_count > 0 else 0
+    def get_time_list(self):
+        """return first loop token time
+        """
+        time_list = []
+        if len(self.time_data) > 0:
+            time_list = copy.deepcopy(self.time_data[0]['dec_token_time'])
+            time_list.insert(0, self.time_data[0]['enc_infer_time'])
+        return time_list
 
-    def get_text_encoder_step_count(self):
-        return self.text_encoder_infer_count
-
-    def get_text_decoder_step_count(self):
-        return self.text_decoder_infer_count
+    def get_time_infer_list(self):
+        """return first loop infer time
+        """
+        time_infer_list = []
+        if len(self.time_data) > 0:
+            time_infer_list = copy.deepcopy(self.time_data[0]['dec_infer_time'])
+            time_infer_list.insert(0, self.time_data[0]['enc_infer_time'])
+        return time_infer_list
+
+    def get_whisper_latency(self, iter):
+        str = ''
+        for idx, data in enumerate(self.time_data):
+            enc_infer_time = data['enc_infer_time'] * 1000
+            dec_token_count = len(data['dec_token_time'])
+            dec_infer_count = len(data['dec_infer_time'])
+            dec_token_time = sum(data['dec_token_time']) / dec_token_count * 1000 if dec_token_count > 1 else 0
+            dec_infer_time = sum(data['dec_infer_time']) / dec_infer_count * 1000 if dec_infer_count > 1 else 0
+            str += f"[{iter}][{idx}] encoder token latency: {enc_infer_time:.2f} ms/token, " \
+                   f"decoder tokens latency: {dec_token_time:.2f} ms/token, " \
+                   f"decoder infers latency: {dec_infer_time:.2f} ms/infer, " \
+                   f"decoder tokens count: {dec_token_count}, " \
+                   f"decoder infers count: {dec_infer_count}"
+            if idx < len(self.time_data) - 1:
+                str += '\n'
+        return str
 
     def clear_statistics(self):
-        self.text_encoder_time = 0
-        self.text_decoder_time = 0
-        self.text_encoder_infer_count = 0
-        self.text_decoder_infer_count = 0
-        self.text_enc_time_list = []
-        self.text_dec_time_list = []
+        self.enc_infer_count = 0
+        self.time_data.clear()
+        self.greedy_hook.clear_time_list()
+        self.greedy_hook.clear_time_infer_list()
 
     def new_text_encoder(self, pipe):
         old_text_encoder = pipe.model.encoder.request
 
         def my_text_encoder(inputs, share_inputs=True, share_outputs=True):
+            loop_data = {}
             t1 = time.time()
             r = old_text_encoder(inputs, share_inputs, share_outputs)
             t2 = time.time()
             text_encoder_time = t2 - t1
-            self.text_enc_time_list.append(text_encoder_time)
-            self.text_encoder_time += text_encoder_time
-            self.text_encoder_infer_count += 1
+            loop_data['enc_infer_time'] = text_encoder_time
+            self.time_data.append(loop_data)
+            self.enc_infer_count += 1
             return r
         pipe.model.encoder.request = my_text_encoder
 
-    def new_text_decoder(self, pipe):
-        old_text_decoder = pipe.model.forward
+    def new_text_sample(self, pipe):
+        self.greedy_hook = llm_bench_utils.hook_greedy_search.GreedySearchHook()
+        self.greedy_hook.new_forward(pipe.model)
 
-        def my_text_decoder(*args, **kwargs):
-            t1 = time.time()
-            r = old_text_decoder(*args, **kwargs)
-            t2 = time.time()
-            text_decoder_time = t2 - t1
-            self.text_dec_time_list.append(text_decoder_time)
-            self.text_decoder_time += text_decoder_time
-            self.text_decoder_infer_count += 1
+    def new_generate(self, pipe):
+        old_generate = pipe.model.generate
+        def my_generate(attention_mask, **kwargs):
+            r = old_generate(attention_mask, **kwargs)
+            self.set_decoder_time_data()
             return r
-        pipe.model.forward = my_text_decoder
+        pipe.model.generate = my_generate
+
+    def set_decoder_time_data(self):
+        if self.enc_infer_count > 0:
+            prev_data = self.time_data[self.enc_infer_count - 1]
+            prev_data['dec_token_time'] = copy.deepcopy(self.greedy_hook.get_time_list())
+            prev_data['dec_infer_time'] = copy.deepcopy(self.greedy_hook.get_time_infer_list())
+            self.greedy_hook.clear_time_list()
+            self.greedy_hook.clear_time_infer_list()
```
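The hook instruments the pipeline entirely by monkey-patching: `new_text_encoder` swaps `pipe.model.encoder.request` for a timing wrapper that opens a new per-loop record, and `new_generate` wraps `pipe.model.generate` so the decoder times gathered by `GreedySearchHook` are snapshotted into that record when a loop finishes. A self-contained sketch of the wrapping pattern, using a stand-in `Model` class rather than the real pipeline:

```python
import time

class Model:
    """Stand-in for pipe.model; not part of the benchmark code."""
    def generate(self, n_tokens):
        time.sleep(0.01)              # pretend to decode
        return list(range(n_tokens))

per_call_times = []
model = Model()
old_generate = model.generate         # keep the original bound method

def my_generate(n_tokens, **kwargs):
    t1 = time.time()
    r = old_generate(n_tokens, **kwargs)
    per_call_times.append(time.time() - t1)   # one entry per generate() call
    return r

model.generate = my_generate          # install the wrapper, as new_generate() does
model.generate(3)
print(f"loop 0: {per_call_times[0] * 1000:.2f} ms")
```

Note that the real `my_generate` does not time the call itself; it only triggers `set_decoder_time_data()`, which copies the per-token times out of the greedy-search hook and clears them for the next loop.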

llm_bench/python/llm_bench_utils/metrics_print.py (+3 -12)
```diff
@@ -57,7 +57,7 @@ def print_metrics(
     if stable_diffusion is not None:
         print_stable_diffusion_infer_latency(iter_str, iter_data, stable_diffusion)
     if whisper is not None:
-        print_whisper_infer_latency(iter_str, iter_data, whisper)
+        print_whisper_infer_latency(iter_str, whisper)
     output_str = ''
     if max_rss_mem != '' and max_rss_mem > -1:
         output_str += 'Max rss memory cost: {:.2f}MBytes, '.format(max_rss_mem)
@@ -102,17 +102,8 @@ def print_stable_diffusion_infer_latency(iter_str, iter_data, stable_diffusion):
             f"vae decoder step count: {stable_diffusion.get_vae_decoder_step_count()}",)
 
 
-def print_whisper_infer_latency(iter_str, iter_data, whisper):
-    iter_data['first_token_latency'] = whisper.get_1st_text_dec_latency()
-    iter_data['other_tokens_avg_latency'] = whisper.get_2nd_text_dec_latency()
-    iter_data['first_token_infer_latency'] = iter_data['first_token_latency']
-    iter_data['other_tokens_infer_avg_latency'] = iter_data['other_tokens_avg_latency']
-    log.info(f"[{iter_str}] First token of encoder latency: {whisper.get_1st_text_enc_latency():.2f} ms/token, "
-             f"other tokens of encoder latency: {whisper.get_2nd_text_enc_latency():.2f} ms/token, "
-             f"First token of decoder latency: {iter_data['first_token_latency']:.2f} ms/token, "
-             f"other tokens of decoder latency: {iter_data['other_tokens_avg_latency']:.2f} ms/token, "
-             f"text encoder infer count: {whisper.get_text_encoder_step_count()}, "
-             f"text decoder infer count: {whisper.get_text_decoder_step_count()}")
+def print_whisper_infer_latency(iter_str, whisper):
+    print(f'{whisper.get_whisper_latency(iter_str)}')
 
 
 def print_ldm_unet_vqvae_infer_latency(iter_num, iter_data, tms=None, warm_up=False):
```
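`print_whisper_infer_latency` now simply emits the string assembled by `get_whisper_latency`: one line per `[iteration][loop]` pair. Given the f-string in the hook, the output looks roughly like this (values invented):

```
[0][0] encoder token latency: 123.45 ms/token, decoder tokens latency: 15.20 ms/token, decoder infers latency: 14.80 ms/infer, decoder tokens count: 42, decoder infers count: 42
```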

llm_bench/python/llm_bench_utils/model_utils.py (+20 -19)
```diff
@@ -116,26 +116,27 @@ def get_audio_param_from_prompt_file(args):
         else:
             raise RuntimeError('== prompt should not be empty string ==')
     else:
-        input_prompt = args['prompt_file']
-        if input_prompt.endswith('.jsonl'):
-            if os.path.exists(input_prompt):
-                log.info(f'Read prompts from {input_prompt}')
-                with open(input_prompt, 'r', encoding='utf-8') as f:
-                    for line in f:
-                        audio_param = {}
-                        data = json.loads(line)
-                        if 'media' in data:
-                            if data['media'] != '':
-                                audio_param['prompt'] = data['media']
+        input_prompt_list = args['prompt_file']
+        for input_prompt in input_prompt_list:
+            if input_prompt.endswith('.jsonl'):
+                if os.path.exists(input_prompt):
+                    log.info(f'Read prompts from {input_prompt}')
+                    with open(input_prompt, 'r', encoding='utf-8') as f:
+                        for line in f:
+                            audio_param = {}
+                            data = json.loads(line)
+                            if 'media' in data:
+                                if data['media'] != '':
+                                    audio_param['prompt'] = data['media']
+                                else:
+                                    raise RuntimeError(f'== prompt should not be empty string in prompt file:{input_prompt} ==')
                             else:
-                                raise RuntimeError(f'== prompt should not be empty string in prompt file:{input_prompt} ==')
-                        else:
-                            raise RuntimeError(f'== key word "media" does not exist in prompt file:{input_prompt} ==')
-                        audio_param_list.append(audio_param)
+                                raise RuntimeError(f'== key word "media" does not exist in prompt file:{input_prompt} ==')
+                            audio_param_list.append(audio_param)
+                else:
+                    raise RuntimeError(f'== The prompt file:{input_prompt} does not exist ==')
             else:
-                raise RuntimeError(f'== The prompt file:{input_prompt} does not exist ==')
-        else:
-            raise RuntimeError(f'== The prompt file:{input_prompt} should be ended with .jsonl ==')
+                raise RuntimeError(f'== The prompt file:{input_prompt} should be ended with .jsonl ==')
     return audio_param_list
```
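For reference, each line of a `.jsonl` prompt file must be a JSON object whose non-empty `media` field points at an audio file; `benchmark.py` strips a leading `./` and resolves the path against the prompt file's directory. A minimal example file (paths invented):

```
{"media": "./audio/sample_0.wav"}
{"media": "./audio/sample_1.wav"}
```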
```diff
@@ -307,4 +308,4 @@ def get_model_precision(model_name_list):
             break
         if model_precision != 'unknown':
             break
-    return model_precision
+    return '' if model_precision == 'unknown' else model_precision
```
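The `get_model_precision` tweak only changes the not-found sentinel: when none of the known precision substrings match the model name, callers now receive an empty string rather than the literal `'unknown'`.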

0 commit comments