1
1
import time
2
+ import copy
3
+ import llm_bench_utils .hook_greedy_search
2
4
3
5
4
6
class WhisperHook:
    """Instrumentation hook for a whisper benchmark pipeline.

    Records per-generation-loop timing by monkey-patching the pipeline model:
    encoder inference time via ``new_text_encoder``, decoder token/inference
    times via a ``GreedySearchHook`` installed by ``new_text_sample``, and a
    ``generate`` wrapper (``new_generate``) that folds the decoder timings
    into the matching loop record after each generation.
    """

    def __init__(self):
        # Number of encoder inference calls seen (one per generation loop).
        self.enc_infer_count = 0
        # One dict per loop: 'enc_infer_time' (seconds); after
        # set_decoder_time_data() also 'dec_token_time' / 'dec_infer_time'
        # (lists of per-token / per-infer seconds).
        self.time_data = []
        # Greedy-search forward hook; None until new_text_sample() installs it.
        self.greedy_hook = None

    def get_time_list(self):
        """Return first loop's token times in seconds, encoder time prepended."""
        token_times = []
        if self.time_data:
            first_loop = self.time_data[0]
            token_times = copy.deepcopy(first_loop['dec_token_time'])
            token_times.insert(0, first_loop['enc_infer_time'])
        return token_times

    def get_time_infer_list(self):
        """Return first loop's inference times in seconds, encoder time prepended."""
        infer_times = []
        if self.time_data:
            first_loop = self.time_data[0]
            infer_times = copy.deepcopy(first_loop['dec_infer_time'])
            infer_times.insert(0, first_loop['enc_infer_time'])
        return infer_times

    def get_whisper_latency(self, iter):
        """Format one report line per generation loop, joined by newlines.

        ``iter`` is the benchmark iteration index embedded in each line
        (name kept for caller compatibility although it shadows the builtin).
        Returns '' when no timing data has been collected.
        """
        report_lines = []
        for idx, data in enumerate(self.time_data):
            enc_infer_time = data['enc_infer_time'] * 1000
            dec_token_count = len(data['dec_token_time'])
            dec_infer_count = len(data['dec_infer_time'])
            # NOTE(review): averages report 0 unless there are at least two
            # samples (`> 1`, not `> 0`) — presumably deliberate to skip
            # degenerate loops; confirm before "fixing".
            dec_token_time = sum(data['dec_token_time']) / dec_token_count * 1000 if dec_token_count > 1 else 0
            dec_infer_time = sum(data['dec_infer_time']) / dec_infer_count * 1000 if dec_infer_count > 1 else 0
            report_lines.append(
                f"[{iter}][{idx}] encoder token latency: {enc_infer_time:.2f} ms/token, "
                f"decoder tokens latency: {dec_token_time:.2f} ms/token, "
                f"decoder infers latency: {dec_infer_time:.2f} ms/infer, "
                f"decoder tokens count: {dec_token_count}, "
                f"decoder infers count: {dec_infer_count}"
            )
        return '\n'.join(report_lines)

    def clear_statistics(self):
        """Reset all collected timing state for the next measurement round."""
        self.enc_infer_count = 0
        self.time_data.clear()
        # Guard: greedy_hook stays None until new_text_sample() is called;
        # calling clear_statistics() before that must not raise.
        if self.greedy_hook is not None:
            self.greedy_hook.clear_time_list()
            self.greedy_hook.clear_time_infer_list()

    def new_text_encoder(self, pipe):
        """Wrap pipe.model.encoder.request to time each encoder inference."""
        old_text_encoder = pipe.model.encoder.request

        def my_text_encoder(inputs, share_inputs=True, share_outputs=True):
            loop_data = {}
            start = time.time()
            result = old_text_encoder(inputs, share_inputs, share_outputs)
            loop_data['enc_infer_time'] = time.time() - start
            # Each encoder call starts a new generation-loop record.
            self.time_data.append(loop_data)
            self.enc_infer_count += 1
            return result

        pipe.model.encoder.request = my_text_encoder

    def new_text_sample(self, pipe):
        """Install the greedy-search hook on pipe.model to capture decoder times."""
        self.greedy_hook = llm_bench_utils.hook_greedy_search.GreedySearchHook()
        self.greedy_hook.new_forward(pipe.model)

    def new_generate(self, pipe):
        """Wrap pipe.model.generate so decoder timings are attached after each call."""
        old_generate = pipe.model.generate

        def my_generate(attention_mask, **kwargs):
            result = old_generate(attention_mask, **kwargs)
            self.set_decoder_time_data()
            return result

        pipe.model.generate = my_generate

    def set_decoder_time_data(self):
        """Move the greedy hook's decoder timings into the latest loop record.

        Clears the hook's accumulators afterwards so the next loop starts fresh.
        No-op until at least one encoder inference has been recorded.
        """
        if self.enc_infer_count > 0:
            latest = self.time_data[self.enc_infer_count - 1]
            latest['dec_token_time'] = copy.deepcopy(self.greedy_hook.get_time_list())
            latest['dec_infer_time'] = copy.deepcopy(self.greedy_hook.get_time_infer_list())
            self.greedy_hook.clear_time_list()
            self.greedy_hook.clear_time_infer_list()
0 commit comments