// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "with_past_decoder.hpp"

#include <regex>

#include "logger.hpp"
#include "utils.hpp"

namespace {
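// Binds the KV-cache tensors of one infer request to another: each
// `present.*` output of `source` is set as the matching
// `past_key_values.*` input of `dest`.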
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
    // source outputs:
    // present.0.decoder.key
    // present.0.decoder.value
    // present.0.encoder.key
    // present.0.encoder.value

    // dest inputs:
    // past_key_values.0.decoder.key
    // past_key_values.0.decoder.value
    // past_key_values.0.encoder.key
    // past_key_values.0.encoder.value
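    //
    // The mapping only renames the prefix, e.g.
    // present.0.decoder.key -> past_key_values.0.decoder.key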

    for (auto& source_output : source.get_compiled_model().outputs()) {
        std::string source_output_name = source_output.get_any_name();
        // Skip the logits output; only KV-cache tensors are forwarded.
        if (source_output_name.find("logits") != std::string::npos) {
            continue;
        }

        std::string with_past_input_name =
            std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

        // ov::Tensor copies share the underlying buffer, so this links the
        // tensors without copying the cache data.
        auto kv_tensor = source.get_tensor(source_output_name);
        dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor});
    }
}
}  // namespace

namespace ov::genai {
WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path,
                                               const std::string& device,
                                               const ov::AnyMap& properties) {
    Logger::warn("Whisper decoder models with past are deprecated. Support will be removed in the 2026.0.0 release.\n"
                 "To obtain a stateful decoder model, use the latest `optimum-intel` package:\n"
                 "pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git\n"
                 "optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
    ov::Core core = utils::singleton_core();

    // First-step decoder: consumes the prompt tokens and produces present.* KV outputs.
    auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
    utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
    m_request_decoder = compiled_model.create_infer_request();

    // With-past decoder: consumes past_key_values.* inputs on subsequent steps.
    compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
    utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
    m_request_decoder_with_past = compiled_model.create_infer_request();
}

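// Runs a single decode step from the start token and picks the most probable
// next token, which encodes the detected language. The decoder state is reset
// afterwards so language detection does not leak into the real generation.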
std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
                                                                  const int64_t decoder_start_token_id) {
    auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

    reset_state();

    return {output_token, infer_ms};
}

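// Performs one decoding step. The first call (cache_position == 0) runs the
// plain decoder and seeds the with-past decoder's KV cache; subsequent calls
// run the with-past decoder with only the newly generated token(s).
//
// Illustrative call sequence (a sketch; `decoder`, `hidden_state`,
// `start_token` and `next_token` are placeholders, not values from this file):
//   auto [logits, ms] = decoder.decode(hidden_state, {start_token}, 0);  // initial step
//   auto [more, ms2] = decoder.decode(hidden_state, {next_token}, 1);    // reuses KV cache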
std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state,
                                                            const std::vector<int64_t>& input_ids,
                                                            const size_t cache_position) {
    const bool initial_step = cache_position == 0;
    ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past;

    request.set_tensor("encoder_hidden_states", encoder_hidden_state);

    // Wrap the caller's token buffer without copying; the request only reads from it.
    const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
    request.set_tensor("input_ids", input_ids_tensor);

    if (!initial_step) {
        ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
        cache_position_tensor.set_shape({1});
        cache_position_tensor.data<int64_t>()[0] = static_cast<int64_t>(cache_position);
    }

    const auto infer_start = std::chrono::steady_clock::now();
    request.infer();
    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

    auto output_tensor = request.get_tensor("logits");

    if (initial_step) {
        // Seed the with-past request with the KV cache produced by the first step.
        set_past_key_value(m_request_decoder, m_request_decoder_with_past);
    } else if (!m_decoder_with_past_kv_value_set) {
        // After the first with-past step, link the request's own present.* outputs
        // to its past_key_values.* inputs so later steps reuse the same buffers.
        set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past);
        m_decoder_with_past_kv_value_set = true;
    }

    return {output_tensor, infer_ms};
}

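// Resets the with-past request and the KV linkage flag so the next decode()
// call starts a fresh sequence from cache position 0.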
void WhisperWithPastDecoder::reset_state() {
    m_request_decoder_with_past.reset_state();
    m_decoder_with_past_kv_value_set = false;
}
}  // namespace ov::genai