Skip to content

Commit 50c45f0

Browse files
authored
Static Whisper: transformations for dual-model whisper (#1820)
1 parent d19ba91 commit 50c45f0

File tree

2 files changed

+389
-2
lines changed

2 files changed

+389
-2
lines changed

src/cpp/src/whisper_pipeline_static.cpp

+321-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
#include "openvino/op/convert.hpp"
2727
#include "openvino/op/parameter.hpp"
2828

29+
#include "openvino/pass/stateful_to_stateless.hpp"
30+
#include "openvino/op/shape_of.hpp"
31+
#include "openvino/opsets/opset13.hpp"
32+
2933
using ov::genai::MicroSeconds;
3034

3135
namespace {
@@ -498,6 +502,306 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
498502
return model;
499503
}
500504

505+
// Names a node and its first output tensor in one step.
//
// `result`: node to rename. Despite the parameter name, callers in this file
//           pass Parameter nodes as well as Result nodes.
// `name`:   used as the node's friendly name and passed as the single entry
//           of the name set given to the tensor on output port 0.
void set_name(std::shared_ptr<ov::Node> result, const std::string& name) {
    result->set_friendly_name(name);

    auto& port_tensor = result->get_output_tensor(0);
    port_tensor.set_names({name});
}
509+
510+
// Drops every model input whose name contains "past_key_values".
//
// For each such input:
//   * the Concat that consumes it is bypassed: the Result reading the Concat
//     is replaced by a new Result (same friendly name) fed from the source of
//     Concat input port 1, and the ScaledDotProductAttention reading the
//     Concat is rewired to that same source;
//   * when the input has exactly two consumers, a ShapeOf consumer is replaced
//     with an i64 constant {1, shape[1], 0, shape[3]}.
//     NOTE(review): shape[1] and shape[3] are read via get_length(), so they
//     are presumably expected to be static - confirm.
//
// Asserts that each input feeds a Concat, that the Concat has exactly two
// readers, and that at least one ShapeOf was replaced overall. Parameters and
// old Results are removed only after all rewiring is finished; the model is
// re-validated at the end.
void remove_input_kv_tensors(std::shared_ptr<ov::Model>& model) {
    // Concat input port carrying the tensor that is forwarded to Result/SDPA.
    const int CONCAT_CURR_KV_PORT = 1;

    // Substring match on a node's type name.
    const auto type_contains = [](ov::Node* node, const char* fragment) {
        return strstr(node->get_type_name(), fragment) != nullptr;
    };

    ov::ParameterVector stale_params;
    ov::ResultVector stale_results;
    ov::ResultVector fresh_results;
    std::shared_ptr<ov::op::v0::Constant> cst_node;

    for (const auto& model_in : model->inputs()) {
        const bool is_past_kv = model_in.get_any_name().find("past_key_values") != std::string::npos;
        if (!is_past_kv) {
            continue;
        }

        // Actual removal is deferred until every reconnection has been made.
        stale_params.push_back(ov::as_type_ptr<ov::op::v0::Parameter>(model_in.get_node_shared_ptr()));

        // Consumers of the KV input: a Concat and, optionally, a ShapeOf.
        const auto kv_consumers = model_in.get_node()->output(0).get_target_inputs();
        std::shared_ptr<ov::Node> concat_node;
        for (const auto& consumer : kv_consumers) {
            if (type_contains(consumer.get_node(), "Concat")) {
                concat_node = consumer.get_node()->shared_from_this();
            }
        }
        OPENVINO_ASSERT(concat_node);

        // The Concat output is expected to feed exactly a Result and an SDPA.
        const auto cat_readers = concat_node->output(0).get_target_inputs();
        OPENVINO_ASSERT(cat_readers.size() == 2);

        // Tensor that takes over the Concat's role for both readers.
        const auto curr_kv = concat_node->inputs()[CONCAT_CURR_KV_PORT].get_source_output();

        for (const auto& cat_reader : cat_readers) {
            if (type_contains(cat_reader.get_node(), "Result")) {
                // Swap the Result for one fed directly by the current KV tensor.
                auto old_result = ov::as_type_ptr<ov::op::v0::Result>(cat_reader.get_node()->shared_from_this());
                auto new_result = std::make_shared<ov::op::v0::Result>(curr_kv);
                set_name(new_result, old_result->get_friendly_name());

                stale_results.push_back(old_result);
                fresh_results.push_back(new_result);
            }
            if (type_contains(cat_reader.get_node(), "ScaledDotProductAttention")) {
                // SDPA now reads the current KV tensor instead of the Concat.
                cat_reader.replace_source_output(curr_kv);
            }
        }

        // A ShapeOf is only looked for when the KV input has exactly two consumers.
        if (kv_consumers.size() != 2) {
            continue;
        }
        for (const auto& consumer : kv_consumers) {
            if (!type_contains(consumer.get_node(), "ShapeOf")) {
                continue;
            }
            auto shapeof_node = ov::as_type_ptr<ov::op::v3::ShapeOf>(consumer.get_node()->shared_from_this());
            const auto& kv_shape = model_in.get_partial_shape();
            const auto dims = std::vector<size_t>{1,
                                                  size_t(kv_shape[1].get_length()),
                                                  0,
                                                  size_t(kv_shape[3].get_length())};
            cst_node = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, dims);

            ov::replace_node(shapeof_node, cst_node);
        }
    }

    OPENVINO_ASSERT(cst_node);

    for (const auto& res : stale_results) {
        model->remove_result(res);
    }
    for (const auto& param : stale_params) {
        model->remove_parameter(param);
    }
    model->add_results(fresh_results);
    model->validate_nodes_and_infer_types();
}
588+
589+
// Detaches one SDPA input from a ReadValue node and feeds it from `kv_out`.
//
// `rv_node`: node whose single output must have exactly two readers, one of
//            which has "Assign" in its type name (both are asserted).
// `kv_out`:  output that replaces the ReadValue as the SDPA's source.
// `sdpa_in`: input port to rewire.
//
// Returns a pair {new Result wrapping `kv_out`, the reader cast to v6::Assign}.
// Neither is registered with or removed from any model here; that is left to
// the caller.
auto remove_encoder_attn_read_value(const std::shared_ptr<ov::Node>& rv_node,
                                    const ov::Output<ov::Node>& kv_out,
                                    const ov::Input<ov::Node>& sdpa_in) {
    OPENVINO_ASSERT(rv_node->outputs().size() == 1);

    // Gather the nodes that consume the ReadValue output.
    const auto rv_consumers = rv_node->outputs()[0].get_target_inputs();
    ov::NodeVector rv_readers;
    for (const auto& consumer : rv_consumers) {
        rv_readers.push_back(consumer.get_node()->shared_from_this());
    }
    OPENVINO_ASSERT(rv_readers.size() == 2);

    // Take the first reader if it is the Assign, otherwise the second one;
    // the assert below rejects the case where neither is.
    std::shared_ptr<ov::Node> assign_node = rv_readers[1];
    if (strstr(rv_readers[0]->get_type_name(), "Assign") != nullptr) {
        assign_node = rv_readers[0];
    }
    OPENVINO_ASSERT(strstr(assign_node->get_type_name(), "Assign") != nullptr);

    // From now on the SDPA input reads the KV tensor directly.
    sdpa_in.replace_source_output(kv_out);

    auto kv_result = std::make_shared<ov::op::v0::Result>(kv_out);
    auto assign = ov::as_type_ptr<ov::op::v6::Assign>(assign_node);
    return std::make_pair(kv_result, assign);
}
607+
608+
std::string transform_key_value_name(std::string input_string, std::string prefix, std::string enc_or_dec, std::string key_or_value) {
609+
std::regex pattern("[0-9]+");
610+
std::smatch match;
611+
std::regex_search(input_string, match, pattern);
612+
613+
if (match.empty())
614+
OPENVINO_THROW("Input string does not match the expected pattern");
615+
616+
auto number = std::string(match[0]);
617+
return prefix + "." + number + enc_or_dec + key_or_value;
618+
}
619+
620+
// Replaces every ReadValue -> SDPA connection with a direct connection from
// the ReadValue's source and publishes that source as a model output.
//
// For each node whose type name contains "ReadValue" (asserted to have one
// input, one output and exactly two readers):
//   * the ScaledDotProductAttention reader is rewired to the ReadValue's
//     source output;
//   * a new Result on that source is added, named
//     "present.<N>.encoder.key" when the SDPA input index is 1 and
//     "present.<N>.encoder.value" otherwise, where <N> is the first number in
//     the source node's friendly name;
//   * the Assign reader returned by remove_encoder_attn_read_value is removed
//     from the model's sinks.
// The model is re-validated at the end.
void expose_runtime_states_as_outputs(std::shared_ptr<ov::Model>& model) {
    // Find all ReadValue nodes up front; the graph is modified below.
    ov::NodeVector read_value_nodes;
    for (const auto& op : model->get_ops()) {
        if (strstr(op->get_type_name(), "ReadValue") != nullptr) {
            read_value_nodes.push_back(op);
        }
    }

    // Results for the cross-attention KV tensors and the sinks to drop.
    ov::ResultVector results;
    ov::SinkVector assigns;

    for (const auto& rv_node : read_value_nodes) {
        OPENVINO_ASSERT(rv_node->inputs().size() == 1);
        OPENVINO_ASSERT(rv_node->outputs().size() == 1);

        // Tensor currently feeding the ReadValue; it is not modified by the
        // rewiring below, so it is computed once.
        const auto kv_source = rv_node->inputs()[0].get_source_output();

        // Snapshot of the readers: there must be an SDPA and an Assign.
        const auto rv_readers = rv_node->outputs()[0].get_target_inputs();
        OPENVINO_ASSERT(rv_readers.size() == 2);

        for (const auto& reader : rv_readers) {
            if (strstr(reader.get_node()->get_type_name(), "ScaledDotProductAttention") == nullptr) {
                continue;
            }

            // SDPA input port 1 is treated as the key, any other port as the value.
            const auto key_or_value = (reader.get_index() == 1) ? "key" : "value";

            // Bypass the ReadValue; get the new Result and the Assign to drop.
            auto [result, assign] = remove_encoder_attn_read_value(rv_node, kv_source, reader);

            const auto normalized_name = transform_key_value_name(kv_source.get_node()->get_friendly_name(),
                                                                  "present",
                                                                  ".encoder.",
                                                                  key_or_value);
            set_name(result, normalized_name);

            results.push_back(result);
            assigns.push_back(assign);
        }
    }

    // Add, remove, validate
    model->add_results(results);
    for (const auto& assign : assigns) {
        model->remove_sink(assign);
    }
    model->validate_nodes_and_infer_types();
}
670+
671+
// Removes the "cache_position" input and substitutes it with a subgraph that
// is computed from "input_ids":
//
//     Range(start = 0, stop = Gather(ShapeOf(input_ids), indices = 1, axis = 0), step = 1)
//
// with i64 output type. Every consumer of "cache_position" is rewired to the
// Range output, then the parameter is removed and the model re-validated.
void remove_cache_position(std::shared_ptr<ov::Model>& model) {
    // All helper constants are i64 scalars built from a one-element int vector.
    const auto make_i64_scalar = [](int value) {
        return std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, std::vector<int>{value});
    };

    // stop = ShapeOf(input_ids)[1]
    auto ids_node = model->input("input_ids").get_node();
    auto ids_shape = std::make_shared<ov::op::v3::ShapeOf>(ids_node->output(0));

    auto indices = make_i64_scalar(1);
    indices->set_friendly_name("indices");
    auto axis = make_i64_scalar(0);
    axis->set_friendly_name("axis");

    auto stop = std::make_shared<ov::op::v8::Gather>(ids_shape->output(0), indices, axis);

    // Range(0, stop, 1)
    auto start = make_i64_scalar(0);
    auto step = make_i64_scalar(1);
    step->set_friendly_name("step");
    auto positions = std::make_shared<ov::op::v4::Range>(start->output(0),
                                                         stop->output(0),
                                                         step->output(0),
                                                         ov::element::i64);

    // Redirect every reader of cache_position to the generated range.
    auto cache_pos = ov::as_type_ptr<ov::op::v0::Parameter>(model->input("cache_position").get_node_shared_ptr());
    const auto cache_pos_readers = cache_pos->output(0).get_target_inputs();
    for (const auto& reader : cache_pos_readers) {
        reader.replace_source_output(positions->output(0));
    }

    model->remove_parameter(cache_pos);
    model->validate_nodes_and_infer_types();
}
699+
700+
// Renames every model input whose name contains "decoder" to
// "past_key_values.<N>.decoder.key" or "past_key_values.<N>.decoder.value".
//
// <N> is the first number in the current name; "key" is chosen when the
// current name contains ".key", otherwise "value". Inputs without "decoder"
// in their name are left untouched. The model is re-validated at the end.
void normalize_input_key_value_names(std::shared_ptr<ov::Model>& model) {
    for (const auto& in : model->inputs()) {
        // Copy the name once: set_name() below replaces the tensor's names,
        // so the current name must not be held by reference across that call.
        const std::string current_name = in.get_any_name();
        if (current_name.find("decoder") == std::string::npos) {
            continue;
        }

        const auto key_or_value = (current_name.find(".key") != std::string::npos) ? "key" : "value";
        const auto normalized_name =
            transform_key_value_name(current_name, "past_key_values", ".decoder.", key_or_value);
        set_name(in.get_node_shared_ptr(), normalized_name);
    }

    model->validate_nodes_and_infer_types();
}
714+
715+
// Renames every model output whose name contains "decoder" to
// "present.<N>.decoder.key" or "present.<N>.decoder.value".
//
// <N> is the first number in the current name; "key" is chosen when the
// current name contains ".key", otherwise "value". Outputs without "decoder"
// in their name are left untouched. The model is re-validated at the end.
void normalize_output_key_value_names(std::shared_ptr<ov::Model>& model) {
    for (const auto& out : model->outputs()) {
        // Copy the name once: set_name() below replaces the tensor's names,
        // so the current name must not be held by reference across that call.
        const std::string current_name = out.get_any_name();
        if (current_name.find("decoder") == std::string::npos) {
            continue;
        }

        const auto key_or_value = (current_name.find(".key") != std::string::npos) ? "key" : "value";
        const auto normalized_name =
            transform_key_value_name(current_name, "present", ".decoder.", key_or_value);
        set_name(out.get_node_shared_ptr(), normalized_name);
    }

    model->validate_nodes_and_infer_types();
}
729+
730+
// Turns every ReadValue -> SDPA connection into a new model input.
//
// For each node whose type name contains "ReadValue":
//   * every reader with "Assign" in its type name is collected and later
//     removed from the model's sinks;
//   * every ScaledDotProductAttention reader gets a fresh Parameter with the
//     ReadValue's output element type and partial shape as its source. The
//     Parameter is named "past_key_values.<N>.encoder.key" when the SDPA
//     input index is 1 and "past_key_values.<N>.encoder.value" otherwise,
//     where <N> is the first number in the SDPA node's friendly name.
//
// This function does not call validate_nodes_and_infer_types() itself.
void expose_runtime_states_as_inputs(std::shared_ptr<ov::Model>& model) {
    // Find all ReadValue nodes up front; the graph is modified below.
    ov::NodeVector read_value_nodes;
    for (const auto& op : model->get_ops()) {
        if (strstr(op->get_type_name(), "ReadValue") != nullptr) {
            read_value_nodes.push_back(op);
        }
    }

    // Assign nodes to drop from the sinks, and the Parameters to register.
    ov::SinkVector assigns;
    ov::ParameterVector params;

    for (const auto& rv_node : read_value_nodes) {
        // Snapshot of the readers, so rewiring below does not disturb iteration.
        const auto rv_readers = rv_node->outputs()[0].get_target_inputs();

        for (const auto& rv_reader : rv_readers) {
            const char* reader_type = rv_reader.get_node()->get_type_name();

            if (strstr(reader_type, "Assign") != nullptr) {
                assigns.push_back(ov::as_type_ptr<ov::op::v6::Assign>(rv_reader.get_node()->shared_from_this()));
            } else if (strstr(reader_type, "ScaledDotProductAttention") != nullptr) {
                // The new input mirrors the type and shape of the state it replaces.
                auto new_param = std::make_shared<ov::op::v0::Parameter>(rv_node->get_output_element_type(0),
                                                                         rv_node->get_output_partial_shape(0));

                // SDPA input port 1 is treated as the key, any other port as the value.
                const auto key_or_value = (rv_reader.get_index() == 1) ? "key" : "value";
                const auto normalized_name = transform_key_value_name(rv_reader.get_node()->get_friendly_name(),
                                                                      "past_key_values",
                                                                      ".encoder.",
                                                                      key_or_value);
                set_name(new_param, normalized_name);

                params.push_back(new_param);
                rv_reader.replace_source_output(new_param->outputs()[0]);
            }
        }
    }

    // Remove sinks and add new params
    model->add_parameters(params);
    for (const auto& assign : assigns) {
        model->remove_sink(assign);
    }
}
776+
777+
// Derives the "decoder" model from `model`.
//
// Works on a clone, so `model` itself is not modified. The clone is passed
// through the transformations below in order, re-validated and returned.
std::shared_ptr<ov::Model> prepare_decoder_model(std::shared_ptr<ov::Model>& model) {
    std::shared_ptr<ov::Model> first_step_decoder = model->clone();

    // Drop the past_key_values inputs (per the original note, they are empty
    // on the first iteration).
    remove_input_kv_tensors(first_step_decoder);

    // Publish the states that need initialization on the first run as outputs.
    expose_runtime_states_as_outputs(first_step_decoder);

    // Replace the cache_position input with a subgraph computed from input_ids.
    remove_cache_position(first_step_decoder);

    // Normalize output names. Per the original note, this should eventually be
    // done in the stateful-to-stateless transformation.
    normalize_output_key_value_names(first_step_decoder);

    first_step_decoder->validate_nodes_and_infer_types();
    return first_step_decoder;
}
791+
792+
// Derives the "decoder with past" model from `model`.
//
// Works on a clone, so `model` itself is not modified. The clone has its
// decoder KV input/output names normalized, its ReadValue states exposed as
// inputs, and "input_ids" reshaped to {-1, 1}; it is then re-validated and
// returned.
std::shared_ptr<ov::Model> prepare_decoder_with_past_model(std::shared_ptr<ov::Model>& model) {
    std::shared_ptr<ov::Model> with_past = model->clone();

    // FIXME: normalization should be done inside stateful_to_stateless_transformation
    normalize_input_key_value_names(with_past);
    normalize_output_key_value_names(with_past);

    // Replace ReadValue-backed states with explicit model inputs.
    expose_runtime_states_as_inputs(with_past);

    // Fix the second dimension of input_ids to 1; the first stays dynamic.
    with_past->reshape({{"input_ids", ov::PartialShape({-1, 1})}});

    with_past->validate_nodes_and_infer_types();
    return with_past;
}
804+
501805
} // namespace
502806

503807
namespace ov {
@@ -522,8 +826,23 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
522826
ov::Core core = utils::singleton_core();
523827

524828
auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
525-
auto decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
526-
auto decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);
829+
830+
std::shared_ptr<ov::Model> decoder_model;
831+
std::shared_ptr<ov::Model> decoder_with_past_model;
832+
833+
if (std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml") ) {
834+
decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
835+
decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);
836+
} else {
837+
auto model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
838+
ov::pass::StatefulToStateless().run_on_model(model);
839+
840+
decoder_model = prepare_decoder_model(model);
841+
decoder_with_past_model = prepare_decoder_with_past_model(model);
842+
}
843+
844+
if (!decoder_model || !decoder_with_past_model)
845+
OPENVINO_THROW("Decoder/decoder_with_past model is not valid !");
527846

528847
add_attention_mask_input(decoder_with_past_model);
529848

0 commit comments

Comments
 (0)