Commit a907b5f
Whisper pipeline: add perf metrics (openvinotoolkit#971)
This PR adds:

- [x] support for perf metrics

Common TODOs for Whisper support:

- [ ] Long-form audio support with [parallel chunking](https://huggingface.co/blog/asr-chunking)
- [ ] Update documentation
- [ ] Add cpp, python sample tests
- [ ] Support timestamps streaming
- [ ] Expose only meaningful parameters in `GenerationConfig` (`task`, `language`, `return_timestamps`, etc.)
- [ ] Move all whisper pipeline files to a dedicated subfolder
- [ ] The Whisper pipeline doesn't need a tokenizer, it uses the detokenizer only. Implement detokenizer-only initialization for `ov::genai::Tokenizer`
- [ ] Check discrete GPU. Integrated GPU works as expected
- [ ] Investigate use of `RemoteTensor` for GPU
- [ ] Add batch support
- [ ] Add a sampler, inherit `WhisperGenerationConfig` from `GenerationConfig`
- [ ] Investigate language autodetection with a single decoder call (without past)
- [ ] Update the python bindings cmake to include the whole directory instead of an explicit list of files
- [ ] Add samples with audio preparation examples
- [ ] Add links to audio files so users can download them in samples
- [ ] Move the supported models list from the samples README to the common supported models section
- [ ] Avoid building GenAI in each tests job, as it takes a lot of time
- [ ] Double-check FP32 support
- [ ] Fix sporadic test failures: sometimes the whisper model cannot be downloaded from HF due to network issues
- [ ] Fix the stop criteria: the current approach stops on `eos_token`, which is the no-speech token, but there could be more speech tokens further on that are wrongly skipped now
- [ ] Fix distil-whisper accuracy, match with HF
- [ ] Fix en models' accuracy with timestamps, match with HF
- [ ] Try to trim the input_ids cache between chunks for long-form audio to match HF

Completed:

- [x] Support for different languages, language autodetection
- [x] Support for translation
- [x] Support for timestamps
- [x] Long-form audio support with sequential chunking

Current limitations (a preparation sketch follows below):

- No resampling during preprocessing. Input raw speech should have a 16 kHz sampling rate.
- No normalization during preprocessing. Input raw speech should be normalized to the near [-1, 1] range.

Tickets: CVS-147994, CVS-146010, CVS-152523
1 parent 1fdf96e commit a907b5f
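Since the pipeline currently performs no resampling or normalization (see the limitations above), callers must hand it 16 kHz float samples in roughly [-1, 1]. The sketch below is a hypothetical end-to-end usage example, not code from this commit: it assumes the `openvino_genai` Python package with a `WhisperPipeline(model_dir, device)` constructor, a placeholder model directory `whisper-tiny/`, and uses `librosa` as one convenient way to decode and resample audio. The perf-metrics getters themselves match those exercised by the new `test_perf_metrics` below.

```python
# Hypothetical end-to-end sketch (not code from this commit).
# Assumptions: the `openvino_genai` package with a WhisperPipeline(model_dir, device)
# constructor, an exported Whisper model in "whisper-tiny/" (placeholder path), and
# librosa as one way to satisfy the 16 kHz / [-1, 1] input requirements above.
import librosa
import openvino_genai

# librosa.load resamples to the requested rate and returns float32 samples that,
# for typical PCM sources, already lie in roughly [-1, 1].
raw_speech, _sr = librosa.load("sample.wav", sr=16000)

pipe = openvino_genai.WhisperPipeline("whisper-tiny/", "CPU")
result = pipe.generate(raw_speech.tolist())
print(result.texts[0])

# Perf metrics added by this PR: aggregated statistics over the raw timings.
pm = result.perf_metrics
print(f"load time:         {pm.get_load_time():.2f} ms")
print(f"generate duration: {pm.get_generate_duration().mean:.2f} ms")
print(f"TTFT:              {pm.get_ttft().mean:.2f} ms")
print(f"TPOT:              {pm.get_tpot().mean:.2f} ms/token")
print(f"throughput:        {pm.get_throughput().mean:.2f} tokens/s")
```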

File tree: 8 files changed (+151 −53 lines)


samples/cpp/whisper_speech_recognition/whisper_speech_recognition.cpp (+1)

```diff
@@ -35,6 +35,7 @@ int main(int argc, char* argv[]) try {
     for (auto& chunk : *result.chunks) {
         std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n";
     }
+
 } catch (const std::exception& error) {
     try {
         std::cerr << error.what() << '\n';
```

src/cpp/src/perf_metrics.cpp (+1 −1)

```diff
@@ -97,7 +97,7 @@ void PerfMetrics::evaluate_statistics(std::optional<TimePoint> start_time) {
     if (m_evaluated){
         return;
     }
-    // If start_tiem is specified then recalcualte durations according to start times and calculate statistics only after that.
+    // If start_time is specified then recalculate durations according to start times and calculate statistics only after that.
     if (start_time.has_value()) {
         auto start_time_val = *start_time;
         auto& tok_times = raw_metrics.m_new_token_times;
```

src/cpp/src/whisper/timestamps.cpp (+3 −2)

```diff
@@ -72,8 +72,9 @@ ov::genai::ExtractedSegments extract_segments(const std::vector<int64_t>& tokens
                       tokens.end());
     }
 
-    // last timestamps generated in pairs <ts><ts><eos> -> speech segment continuation to the next chunk -> token_start will have value
-    // single ending timestamp <ts><eos> -> no more speech till the end of current chunk -> set offset to the end of frame
+    // last timestamps generated in pairs <ts><ts><eos> -> speech segment continuation to the next chunk -> token_start
+    // will have value single ending timestamp <ts><eos> -> no more speech till the end of current chunk -> set offset
+    // to the end of frame
     if (!token_start.has_value()) {
         extracted_segments.last_offset = nb_max_frames;
     }
```

src/cpp/src/whisper/whisper.cpp (+61 −17)

```diff
@@ -10,6 +10,7 @@
 
 #include "../utils.hpp"
 #include "logit_processor.hpp"
+#include "openvino/genai/perf_metrics.hpp"
 #include "openvino/genai/streamer_base.hpp"
 #include "openvino/genai/whisper_generation_config.hpp"
 #include "openvino/genai/whisper_pipeline.hpp"
@@ -18,12 +19,15 @@
 #include "whisper_feature_extractor.hpp"
 #include "whisper_models.hpp"
 
+using ov::genai::MicroSeconds;
+
 namespace {
 
 ov::Tensor encode(ov::InferRequest& request,
                   std::vector<float>& mel_data,
                   const size_t feature_size,
-                  const size_t nb_max_frames) {
+                  const size_t nb_max_frames,
+                  ov::genai::RawPerfMetrics& raw_metrics) {
     OPENVINO_ASSERT(mel_data.size() == feature_size * nb_max_frames,
                     "Mel spectrogram required size: ",
                     feature_size,
@@ -37,7 +41,10 @@ ov::Tensor encode(ov::InferRequest& request,
 
     request.set_tensor("input_features", input_tensor);
 
+    const auto infer_start = std::chrono::steady_clock::now();
     request.infer();
+    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);
+    raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms);
 
     // reset input tensor
     request.set_tensor("input_features", ov::Tensor(ov::element::f32, {0, feature_size, nb_max_frames}));
@@ -72,18 +79,30 @@ void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
     }
 }
 
+void infer_with_perf_metrics(ov::InferRequest& request, ov::genai::RawPerfMetrics& raw_metrics) {
+    const auto infer_start = std::chrono::steady_clock::now();
+    request.infer();
+    const auto infer_end = std::chrono::steady_clock::now();
+    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(infer_end - infer_start);
+    raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms);
+    raw_metrics.m_token_infer_durations.emplace_back(infer_ms);
+    raw_metrics.m_new_token_times.emplace_back(infer_end);
+    raw_metrics.m_batch_sizes.emplace_back(1);
+}
+
 int64_t decode(ov::Tensor& encoder_hidden_state,
                ov::InferRequest& decoder,
                std::vector<int64_t>& input_ids,
                const ov::genai::WhisperGenerationConfig& config,
+               ov::genai::RawPerfMetrics& raw_metrics,
                const bool apply_logit_processors = true,
                const bool return_timestamps = false) {
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
     ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, input_ids.data());
     decoder.set_tensor("input_ids", input_ids_tensor);
 
-    decoder.infer();
+    infer_with_perf_metrics(decoder, raw_metrics);
 
     auto output_tensor = decoder.get_tensor("logits");
 
@@ -106,6 +125,7 @@ int64_t decode_with_past(ov::Tensor& encoder_hidden_state,
                          int64_t input_id,
                          const size_t cache_position,
                          const ov::genai::WhisperGenerationConfig& config,
+                         ov::genai::RawPerfMetrics& raw_metrics,
                          const bool return_timestamps,
                          const std::vector<int64_t>& generated_tokens) {
     decoder_with_past.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
@@ -118,7 +138,7 @@ int64_t decode_with_past(ov::Tensor& encoder_hidden_state,
     cache_position_tensor.set_shape({1});
     cache_position_tensor.data<int64_t>()[0] = cache_position;
 
-    decoder_with_past.infer();
+    infer_with_perf_metrics(decoder_with_past, raw_metrics);
 
     auto output_tensor = decoder_with_past.get_tensor("logits");
 
@@ -137,7 +157,17 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
                         ov::InferRequest decoder,
                         const ov::genai::WhisperGenerationConfig& config) {
     std::vector<int64_t> input_ids{config.decoder_start_token_id};
-    int64_t output_token = decode(encoder_hidden_state, decoder, input_ids, config, false, false);
+
+    decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
+
+    ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, input_ids.data());
+    decoder.set_tensor("input_ids", input_ids_tensor);
+
+    decoder.infer();
+
+    auto output_tensor = decoder.get_tensor("logits");
+
+    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
 
     return output_token;
 }
@@ -181,8 +211,10 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
                                                   std::vector<int64_t> init_ids,
                                                   const size_t max_new_tokens,
                                                   const bool return_timestamps,
+                                                  ov::genai::RawPerfMetrics& raw_metrics,
                                                   const std::shared_ptr<ov::genai::StreamerBase> streamer) {
-    int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, true, return_timestamps);
+    int64_t output_token =
+        decode(encoder_hidden_state, models.decoder, init_ids, config, raw_metrics, true, return_timestamps);
 
     std::vector<int64_t> output_tokens{output_token};
 
@@ -203,6 +235,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
                                              output_tokens.back(),
                                              init_ids.size() + output_tokens.size() - 1,
                                              config,
+                                             raw_metrics,
                                              return_timestamps,
                                              output_tokens);
 
@@ -230,23 +263,30 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
 namespace ov {
 namespace genai {
 
-std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_generate(
-    const ov::genai::WhisperGenerationConfig& config,
-    const ov::genai::WhisperConfig& model_config,
-    const RawSpeechInput& raw_speech,
-    ov::genai::WhisperInitializedModels& models,
-    WhisperFeatureExtractor& feature_extractor,
-    const std::shared_ptr<StreamerBase> streamer) {
+WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& config,
+                                       const ov::genai::WhisperConfig& model_config,
+                                       const RawSpeechInput& raw_speech,
+                                       ov::genai::WhisperInitializedModels& models,
+                                       WhisperFeatureExtractor& feature_extractor,
+                                       const std::shared_ptr<StreamerBase> streamer) {
     auto input_features = feature_extractor.extract(raw_speech);
 
     const bool is_shortform = input_features.n_frames <= feature_extractor.nb_max_frames;
     // long-form audio processing requires timestamps to be enabled
     const bool return_timestamps = config.return_timestamps || !is_shortform;
 
-    std::vector<int64_t> init_ids;
-    std::vector<int64_t> output_tokens;
     size_t max_new_tokens = config.get_max_new_tokens();
 
+    WhisperGenerateResult result;
+    RawPerfMetrics& raw_metrics = result.perf_metrics.raw_metrics;
+    result.perf_metrics.num_input_tokens = 0;
+    raw_metrics.m_new_token_times.reserve(max_new_tokens);
+    raw_metrics.m_batch_sizes.reserve(max_new_tokens);
+    raw_metrics.m_token_infer_durations.reserve(max_new_tokens);
+    raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}};
+
+    std::vector<int64_t> init_ids;
+    std::vector<int64_t>& output_tokens = result.output_tokens;
     std::vector<Segment> segments;
 
     // 0.02 by default
@@ -263,7 +303,8 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen
         ov::Tensor hidden_state_tensor = encode(models.encoder,
                                                 input_features_chunk,
                                                 feature_extractor.feature_size,
-                                                feature_extractor.nb_max_frames);
+                                                feature_extractor.nb_max_frames,
+                                                raw_metrics);
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
@@ -276,6 +317,7 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen
                                                  init_ids,
                                                  max_new_tokens - output_tokens.size(),
                                                  return_timestamps,
+                                                 raw_metrics,
                                                  streamer);
 
         if (return_timestamps) {
@@ -310,10 +352,12 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen
 
     // if return_timestamps wasn't enabled by user
    if (!config.return_timestamps) {
-        return {output_tokens, std::nullopt};
+        return result;
     }
 
-    return {output_tokens, segments};
+    result.segments = segments;
+
+    return result;
 }
 }  // namespace genai
 }  // namespace ov
```
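For orientation, here is an illustrative Python sketch (all names hypothetical, not code from this repo) of the bookkeeping pattern `infer_with_perf_metrics` follows above, and of roughly how such raw per-token values reduce to TTFT/TPOT/throughput; the real reduction lives in `PerfMetrics::evaluate_statistics` in perf_metrics.cpp.

```python
# Illustrative sketch of the raw-metrics pattern above; names are hypothetical.
import time
from dataclasses import dataclass, field
from statistics import mean

@dataclass
class RawMetrics:
    inference_total_us: float = 0.0                       # ~ m_inference_durations[0]
    token_infer_us: list = field(default_factory=list)    # ~ m_token_infer_durations
    new_token_times: list = field(default_factory=list)   # ~ m_new_token_times
    batch_sizes: list = field(default_factory=list)       # ~ m_batch_sizes

def infer_with_perf_metrics(infer, raw: RawMetrics) -> None:
    """Time one decoder call and append per-token raw metrics, as the C++ helper does."""
    start = time.perf_counter()
    infer()                                  # stands in for request.infer()
    end = time.perf_counter()
    us = (end - start) * 1e6
    raw.inference_total_us += us
    raw.token_infer_us.append(us)
    raw.new_token_times.append(end)
    raw.batch_sizes.append(1)

def evaluate_statistics(raw: RawMetrics, start_time: float):
    """Reduce the raw timings, roughly in the spirit of PerfMetrics::evaluate_statistics."""
    ttft_ms = (raw.new_token_times[0] - start_time) * 1e3  # first-token latency
    tpot_ms = mean(raw.token_infer_us) / 1e3               # mean time per output token
    elapsed_s = raw.new_token_times[-1] - start_time
    throughput = sum(raw.batch_sizes) / elapsed_s          # tokens per second
    return ttft_ms, tpot_ms, throughput
```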

src/cpp/src/whisper/whisper.hpp (+12 −7)

```diff
@@ -20,13 +20,18 @@ struct Segment {
     std::vector<int64_t> m_tokens;
 };
 
-std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_generate(
-    const ov::genai::WhisperGenerationConfig& config,
-    const ov::genai::WhisperConfig& model_config,
-    const ov::genai::RawSpeechInput& raw_speech,
-    ov::genai::WhisperInitializedModels& models,
-    ov::genai::WhisperFeatureExtractor& feature_extractor,
-    const std::shared_ptr<StreamerBase> streamer);
+struct WhisperGenerateResult {
+    std::vector<int64_t> output_tokens;
+    std::optional<std::vector<Segment>> segments = std::nullopt;
+    PerfMetrics perf_metrics;
+};
+
+WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& config,
+                                       const ov::genai::WhisperConfig& model_config,
+                                       const ov::genai::RawSpeechInput& raw_speech,
+                                       ov::genai::WhisperInitializedModels& models,
+                                       ov::genai::WhisperFeatureExtractor& feature_extractor,
+                                       const std::shared_ptr<StreamerBase> streamer);
 
 }  // namespace genai
 }  // namespace ov
```

src/cpp/src/whisper_pipeline.cpp (+34 −19)

```diff
@@ -93,28 +93,43 @@ class WhisperPipeline::Impl {
             streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
         }
 
-        auto [output_tokens, segments] = ov::genai::whisper_generate(config,
-                                                                     m_model_config,
-                                                                     raw_speech_input,
-                                                                     m_models,
-                                                                     m_feature_extractor,
-                                                                     streamer_ptr);
-
-        WhisperDecodedResults decoded_results{std::vector{m_tokenizer.decode(output_tokens)}, std::vector{1.f}};
-        if (!segments.has_value()) {
-            return decoded_results;
+        auto generate_result = ov::genai::whisper_generate(config,
+                                                           m_model_config,
+                                                           raw_speech_input,
+                                                           m_models,
+                                                           m_feature_extractor,
+                                                           streamer_ptr);
+        auto decode_start_time = std::chrono::steady_clock::now();
+        WhisperDecodedResults result{std::vector{m_tokenizer.decode(generate_result.output_tokens)}, std::vector{1.f}};
+        generate_result.perf_metrics.raw_metrics.detokenization_durations.emplace_back(
+            PerfMetrics::get_microsec(std::chrono::steady_clock::now() - decode_start_time));
+
+        result.perf_metrics = generate_result.perf_metrics;
+        auto& segments = generate_result.segments;
+
+        if (segments.has_value()) {
+            std::vector<WhisperDecodedResultChunk> chunks;
+            chunks.reserve((*segments).size());
+
+            for (auto& segment : *segments) {
+                decode_start_time = std::chrono::steady_clock::now();
+                chunks.push_back(
+                    WhisperDecodedResultChunk{segment.m_start, segment.m_end, m_tokenizer.decode(segment.m_tokens)});
+                result.perf_metrics.raw_metrics.detokenization_durations.emplace_back(
+                    PerfMetrics::get_microsec(std::chrono::steady_clock::now() - decode_start_time));
+            }
+
+            result.chunks = chunks;
         }
 
-        std::vector<WhisperDecodedResultChunk> chunks;
-        chunks.reserve((*segments).size());
+        auto& metrics = result.perf_metrics;
+        metrics.load_time = this->m_load_time_ms;
+        auto stop_time = std::chrono::steady_clock::now();
+        metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time));
+        result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(MicroSeconds(0.0f));
+        metrics.evaluate_statistics(start_time);
 
-        for (auto& segment : *segments) {
-            chunks.push_back(
-                WhisperDecodedResultChunk{segment.m_start, segment.m_end, m_tokenizer.decode(segment.m_tokens)});
-        }
-
-        decoded_results.chunks = chunks;
-        return decoded_results;
+        return result;
     }
 };
```

src/python/py_generate_pipeline.cpp (+2)

```diff
@@ -637,8 +637,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
         .def("get_num_input_tokens", &PerfMetrics::get_num_input_tokens)
         .def("get_ttft", &PerfMetrics::get_ttft)
         .def("get_tpot", &PerfMetrics::get_tpot)
+        .def("get_ipot", &PerfMetrics::get_ipot)
         .def("get_throughput", &PerfMetrics::get_throughput)
         .def("get_generate_duration", &PerfMetrics::get_generate_duration)
+        .def("get_inference_duration", &PerfMetrics::get_inference_duration)
         .def("get_tokenization_duration", &PerfMetrics::get_tokenization_duration)
         .def("get_detokenization_duration", &PerfMetrics::get_detokenization_duration)
         .def("__add__", &PerfMetrics::operator+)
```

tests/python_tests/test_whisper_generate_api.py (+37 −7)

```diff
@@ -131,6 +131,25 @@ def test_whisper_on_hf_dataset(model_descr, dataset_id):
     compare_genai_and_opt_pipelines(opt_pipe, genai_pipe, dataset_id)
 
 
+@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
+@pytest.mark.parametrize(
+    "test_sample",
+    get_samples_from_dataset(language="en", length=1),
+)
+@pytest.mark.precommit
+def test_smoke(model_descr, test_sample):
+    model_id, path, opt_pipe, pipe = read_whisper_model(model_descr)
+
+    expected = opt_pipe(test_sample)
+
+    genai_result = pipe.generate(test_sample)
+
+    assert genai_result.texts[0] == expected["text"]
+
+    assert "chunks" not in expected
+    assert genai_result.chunks == None
+
+
 @pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
 @pytest.mark.precommit
 def test_whisper_config_constructor(model_descr):
@@ -509,17 +528,28 @@ def test_longform_audio(model_descr, test_sample):
 @pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
 @pytest.mark.parametrize(
     "test_sample",
-    get_samples_from_dataset(language="en", length=1),
+    [
+        *get_samples_from_dataset(language="en", length=1),
+    ],
 )
 @pytest.mark.precommit
-def test_smoke(model_descr, test_sample):
+def test_perf_metrics(model_descr, test_sample):
     model_id, path, opt_pipe, pipe = read_whisper_model(model_descr)
 
-    expected = opt_pipe(test_sample)
+    result = pipe.generate(test_sample)
 
-    genai_result = pipe.generate(test_sample)
+    perf_metrics = result.perf_metrics
 
-    assert genai_result.texts[0] == expected["text"]
+    assert perf_metrics is not None
 
-    assert "chunks" not in expected
-    assert genai_result.chunks == None
+    assert perf_metrics.get_load_time() > 0
+    assert perf_metrics.get_num_generated_tokens() > 0
+    assert perf_metrics.get_num_input_tokens() == 0
+    assert perf_metrics.get_ttft().mean > 0
+    assert perf_metrics.get_tpot().mean > 0
+    assert perf_metrics.get_ipot().mean > 0
+    assert perf_metrics.get_throughput().mean > 0
+    assert perf_metrics.get_inference_duration().mean > 0
+    assert perf_metrics.get_generate_duration().mean > 0
+    assert perf_metrics.get_tokenization_duration().mean == 0
+    assert perf_metrics.get_detokenization_duration().mean > 0
```