Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper timestamp fix #1918

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/cpp/src/whisper/timestamps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace genai {
ov::genai::ExtractedSegments extract_segments(const std::vector<int64_t>& tokens,
const ov::genai::WhisperGenerationConfig& config,
const size_t nb_max_frames,
const float time_precision) {
const float time_precision,
const float time_offset) {
ov::genai::ExtractedSegments extracted_segments;
std::optional<int64_t> token_start = std::nullopt;
const size_t timestamp_begin = config.no_timestamps_token_id + 1;
Expand Down Expand Up @@ -39,8 +40,8 @@ ov::genai::ExtractedSegments extract_segments(const std::vector<int64_t>& tokens

ov::genai::Segment segment;
segment.m_tokens = {tokens.begin() + idx_start + 1, tokens.begin() + i};
segment.m_start = (*token_start - timestamp_begin) * time_precision;
segment.m_end = (token - timestamp_begin) * time_precision;
segment.m_start = (*token_start - timestamp_begin) * time_precision + time_offset;
segment.m_end = (token - timestamp_begin) * time_precision + time_offset;
extracted_segments.segments.push_back(segment);

// each next timestamp token represents .02 time diff
Expand All @@ -62,7 +63,7 @@ ov::genai::ExtractedSegments extract_segments(const std::vector<int64_t>& tokens
if (token_start.has_value() && has_tokens_to_add && !has_previous_segments) {
ov::genai::Segment segment;
segment.m_tokens = {tokens.begin() + idx_start + 1, tokens.end()};
segment.m_start = (*token_start - timestamp_begin) * time_precision;
segment.m_start = (*token_start - timestamp_begin) * time_precision + time_offset;
segment.m_end = -1.0f;
extracted_segments.segments.push_back(segment);

Expand Down
3 changes: 2 additions & 1 deletion src/cpp/src/whisper/timestamps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ struct ExtractedSegments {
ExtractedSegments extract_segments(const std::vector<int64_t>& tokens,
const ov::genai::WhisperGenerationConfig& config,
const size_t nb_max_frames,
const float time_precision);
const float time_precision,
const float time_offset = 0.f);

} // namespace genai
} // namespace ov
10 changes: 9 additions & 1 deletion src/cpp/src/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,14 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
const float time_precision = static_cast<float>(feature_extractor.chunk_length) / model_config.max_source_positions;
size_t segment_offset = 0;

OPENVINO_ASSERT(feature_extractor.sampling_rate != 0, "Sampling Rate for Feature Extractor is 0");
const float frame_length_in_seconds =
static_cast<float>(feature_extractor.hop_length) / feature_extractor.sampling_rate;

for (size_t chunk_offset = 0; chunk_offset < input_features.n_frames; chunk_offset += segment_offset) {

const float chunk_time_offset = chunk_offset * frame_length_in_seconds;

auto input_features_chunk = input_features.get_data_with_offset(chunk_offset, feature_extractor.nb_max_frames);

ov::Tensor hidden_state_tensor = encode(encoder,
Expand Down Expand Up @@ -330,7 +337,8 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
auto extracted_segments = ov::genai::extract_segments(chunk_output_tokens,
config,
feature_extractor.nb_max_frames,
time_precision);
time_precision,
chunk_time_offset);

utils::filter_non_segment_metrics(raw_metrics, output_tokens.size(), extracted_segments.segment_ranges);

Expand Down
12 changes: 8 additions & 4 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,10 +945,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
static_cast<float>(m_feature_extractor.chunk_length) / m_model_config.max_source_positions;
size_t segment_offset = 0;

OPENVINO_ASSERT(m_feature_extractor.sampling_rate != 0, "Sampling Rate for Feature Extractor is 0");
const float frame_length_in_seconds =
static_cast<float>(m_feature_extractor.hop_length) / m_feature_extractor.sampling_rate;

for (size_t chunk_offset = 0; chunk_offset < input_features.n_frames; chunk_offset += segment_offset) {
if (output_tokens.size() >= max_new_tokens) {
break;
}

const float chunk_time_offset = chunk_offset * frame_length_in_seconds;

auto input_features_chunk =
input_features.get_data_with_offset(chunk_offset, m_feature_extractor.nb_max_frames);
Expand Down Expand Up @@ -985,7 +988,8 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
auto extracted_segments = ov::genai::extract_segments(chunk_output_tokens,
config,
m_feature_extractor.nb_max_frames,
time_precision);
time_precision,
chunk_time_offset);

ov::genai::utils::filter_non_segment_metrics(raw_metrics, output_tokens.size(), extracted_segments.segment_ranges);

Expand Down
Loading