
Commit 67d6cd3

Whisper pipeline: support stateful decoder (#1474)
Ticket: 159473. Optimum-intel PR: huggingface/optimum-intel#1078. This PR switches optimum-intel in the tests to the stateful seq2seq branch, so the tests check both the stateful and the with-past decoders. Once the optimum-intel PR is merged, I'll switch the version back to master.
1 parent 505abe8 commit 67d6cd3

15 files changed: +477 −214 lines

.github/workflows/windows.yml (+1 −2)

@@ -311,10 +311,9 @@ jobs:
          python -m pip install . --verbose --find-links ${env:OV_INSTALL_DIR}/wheels
          python -m pip install ./tools/who_what_benchmark --find-links ${env:OV_INSTALL_DIR}/wheels

-         # will install transformers 4.46.3 version
          # transformers 4.46.3 will enable return_timestamps tests
          # this check enabled for windows only. Ticket: 160205.
-         python -m pip install git+https://github.com/huggingface/optimum-intel.git@753f84db6e0966580eb9eaa74a808213be730631
+         python -m pip install transformers==4.46.3

          python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k "not test_smoke"
src/README.md (+1 −1)

@@ -179,7 +179,7 @@ int main(int argc, char* argv[]) {

 Streaming with a custom class:

-C++ template for a stremer.
+C++ template for a streamer.
 ```cpp
 #include "openvino/genai/streamer_base.hpp"
 #include "openvino/genai/llm_pipeline.hpp"
src/cpp/src/logger.hpp (+17, new file)

@@ -0,0 +1,17 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include <iostream>
#include <string>

namespace ov::genai {

class Logger {
public:
    static void warn(std::string message) {
        std::cout << "[WARN] " << message << '\n';
    }
};

} // namespace ov::genai
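For reference, a minimal sketch of how this logger is used later in this commit (the with-past decoder constructor calls it for its deprecation notice). The include path is an assumption; it depends on the consumer's include directories:

```cpp
#include "logger.hpp"  // assumed reachable via the target's include dirs (src/cpp/src/)

int main() {
    // Writes "[WARN] with-past decoder is deprecated" to stdout.
    ov::genai::Logger::warn("with-past decoder is deprecated");
}
```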
src/cpp/src/whisper/models/decoder.cpp (+26, new file)

@@ -0,0 +1,26 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "decoder.hpp"

#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
                                                          const std::string& device,
                                                          const ov::AnyMap& properties) {
    bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

    if (has_decoder_with_past) {
        return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
    }

    return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
}

WhisperDecoder::~WhisperDecoder() = default;
} // namespace ov::genai
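The factory dispatches on the files present in the export directory: if `openvino_decoder_with_past_model.xml` exists, the legacy two-model decoder is built, otherwise the new stateful one. A hedged caller sketch (the export directory name and `CPU` device are assumptions for illustration):

```cpp
#include "decoder.hpp"  // WhisperDecoder::from_path, as added in this commit

// Hypothetical call site: the pipeline only sees the WhisperDecoder interface,
// so stateful and with-past exports are interchangeable here.
void load_decoder_example() {
    const std::filesystem::path models_path = "whisper-tiny";  // assumed export dir
    auto decoder = ov::genai::WhisperDecoder::from_path(models_path, "CPU", {});
    // decoder is a WhisperStatefullDecoder unless the export contains
    // openvino_decoder_with_past_model.xml, in which case it is the
    // deprecated WhisperWithPastDecoder.
}
```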
src/cpp/src/whisper/models/decoder.hpp (+29, new file)

@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>

#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {
class WhisperDecoder {
public:
    static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
                                                     const std::string& device,
                                                     const ov::AnyMap& properties);

    virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
                                                      const int64_t decoder_start_token_id) = 0;

    virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
                                                const std::vector<int64_t>& input_ids,
                                                const size_t cache_position) = 0;

    virtual void reset_state() = 0;

    virtual ~WhisperDecoder();
};
} // namespace ov::genai
src/cpp/src/whisper/models/statefull_decoder.cpp (+60, new file)

@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "statefull_decoder.hpp"

#include <numeric>  // for std::iota

#include "utils.hpp"

namespace ov::genai {
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
                                                 const std::string& device,
                                                 const ov::AnyMap& properties) {
    ov::Core core = utils::singleton_core();

    auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);

    utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
    m_request = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
                                                                   const int64_t decoder_start_token_id) {
    auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

    reset_state();

    return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state,
                                                             const std::vector<int64_t>& input_ids,
                                                             const size_t cache_position) {
    m_request.set_tensor("encoder_hidden_states", encoder_hidden_state);

    ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
    m_request.set_tensor("input_ids", input_ids_tensor);

    // fill cache_position with [cache_position, cache_position + input_ids.size())
    ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");
    cache_position_tensor.set_shape({input_ids.size()});

    auto cache_data = cache_position_tensor.data<int64_t>();
    std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position);

    m_request.get_tensor("beam_idx").set_shape({1});
    m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0;

    const auto infer_start = std::chrono::steady_clock::now();
    m_request.infer();
    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

    auto output_tensor = m_request.get_tensor("logits");

    return {output_tensor, infer_ms};
}

void WhisperStatefullDecoder::reset_state() {
    m_request.reset_state();
}
} // namespace ov::genai
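The `cache_position` handling is the crux of the stateful decoder: the KV cache lives inside the single infer request, so the caller only reports which positions the current `input_ids` occupy. A self-contained sketch of that position arithmetic (the token counts are invented for illustration):

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    // First decode() call: the whole prompt is passed with cache_position = 0,
    // so positions 0..N-1 are generated, exactly as the std::iota call above does.
    std::vector<int64_t> prompt_positions(4);  // hypothetical 4-token prompt
    std::iota(prompt_positions.begin(), prompt_positions.end(), int64_t{0});  // {0, 1, 2, 3}

    // Each later call passes a single new token whose position continues the sequence.
    std::vector<int64_t> step_positions(1);
    std::iota(step_positions.begin(), step_positions.end(), int64_t{4});  // {4}

    for (int64_t p : prompt_positions) std::cout << p << ' ';
    std::cout << "| ";
    for (int64_t p : step_positions) std::cout << p << ' ';
    std::cout << '\n';  // prints: 0 1 2 3 | 4
}
```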
src/cpp/src/whisper/models/statefull_decoder.hpp (+29, new file)

@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperStatefullDecoder : public WhisperDecoder {
public:
    WhisperStatefullDecoder(const std::filesystem::path& models_path,
                            const std::string& device,
                            const ov::AnyMap& properties);

    std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
                                              const int64_t decoder_start_token_id) override;

    std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
                                        const std::vector<int64_t>& input_ids,
                                        const size_t cache_position) override;

    void reset_state() override;

private:
    ov::InferRequest m_request;
};
} // namespace ov::genai
src/cpp/src/whisper/models/with_past_decoder.cpp (+107, new file)

@@ -0,0 +1,107 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "with_past_decoder.hpp"

#include <regex>

#include "logger.hpp"
#include "utils.hpp"

namespace {
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
    // source outputs:
    // present.0.decoder.key
    // present.0.decoder.value
    // present.0.encoder.key
    // present.0.encoder.value

    // dest inputs:
    // past_key_values.0.decoder.key
    // past_key_values.0.decoder.value
    // past_key_values.0.encoder.key
    // past_key_values.0.encoder.value

    for (auto& source_output : source.get_compiled_model().outputs()) {
        std::string source_output_name = source_output.get_any_name();
        if (source_output_name.find("logits") != std::string::npos) {
            continue;
        }

        std::string with_past_input_name =
            std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

        auto kv_tensor = source.get_tensor(source_output_name);
        dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor});
    }
}
} // namespace

namespace ov::genai {
WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path,
                                               const std::string& device,
                                               const ov::AnyMap& properties) {
    Logger::warn("Whisper decoder models with past are deprecated. Support will be removed in the 2026.0.0 release.\n"
                 "To obtain a stateful decoder model, use the latest `optimum-intel` package:\n"
                 "pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git\n"
                 "optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
    ov::Core core = utils::singleton_core();

    auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
    utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
    m_request_decoder = compiled_model.create_infer_request();

    compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
    utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
    m_request_decoder_with_past = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
                                                                  const int64_t decoder_start_token_id) {
    auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

    reset_state();

    return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state,
                                                            const std::vector<int64_t>& input_ids,
                                                            const size_t cache_position) {
    const bool initial_step = cache_position == 0;
    ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past;

    request.set_tensor("encoder_hidden_states", encoder_hidden_state);

    const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
    request.set_tensor("input_ids", input_ids_tensor);

    if (!initial_step) {
        ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
        cache_position_tensor.set_shape({1});
        cache_position_tensor.data<int64_t>()[0] = cache_position;
    }

    const auto infer_start = std::chrono::steady_clock::now();
    request.infer();
    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

    auto output_tensor = request.get_tensor("logits");

    if (initial_step) {
        set_past_key_value(m_request_decoder, m_request_decoder_with_past);
    } else if (!m_decoder_with_past_kv_value_set) {
        set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past);
        m_decoder_with_past_kv_value_set = true;
    }

    return {output_tensor, infer_ms};
}

void WhisperWithPastDecoder::reset_state() {
    m_request_decoder_with_past.reset_state();
    m_decoder_with_past_kv_value_set = false;
}
} // namespace ov::genai
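To make the tensor wiring concrete, here is a self-contained sketch of the name mapping that `set_past_key_value` performs. The tensor names come from the comment in the diff; only the renaming is reproduced, not the actual tensor sharing between infer requests:

```cpp
#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
    // The first-step decoder produces "present.*" KV outputs; the with-past
    // decoder consumes them under "past_key_values.*" input names.
    const std::vector<std::string> source_outputs = {
        "logits",
        "present.0.decoder.key",
        "present.0.decoder.value",
        "present.0.encoder.key",
        "present.0.encoder.value",
    };

    for (const auto& name : source_outputs) {
        if (name.find("logits") != std::string::npos) {
            continue;  // logits are not part of the KV cache
        }
        std::cout << name << " -> "
                  << std::regex_replace(name, std::regex("present"), "past_key_values") << '\n';
    }
}
```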
src/cpp/src/whisper/models/with_past_decoder.hpp (+32, new file)

@@ -0,0 +1,32 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperWithPastDecoder : public WhisperDecoder {
public:
    WhisperWithPastDecoder(const std::filesystem::path& models_path,
                           const std::string& device,
                           const ov::AnyMap& properties);

    std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
                                              const int64_t decoder_start_token_id) override;

    std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
                                        const std::vector<int64_t>& input_ids,
                                        const size_t cache_position) override;

    void reset_state() override;

private:
    ov::InferRequest m_request_decoder;
    ov::InferRequest m_request_decoder_with_past;
    bool m_decoder_with_past_kv_value_set = false;
};

} // namespace ov::genai
