26
26
#include " openvino/op/convert.hpp"
27
27
#include " openvino/op/parameter.hpp"
28
28
29
+ #include " openvino/pass/stateful_to_stateless.hpp"
30
+ #include " openvino/op/shape_of.hpp"
31
+ #include " openvino/opsets/opset13.hpp"
32
+
29
33
using ov::genai::MicroSeconds;
30
34
31
35
namespace {
@@ -498,6 +502,306 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
498
502
return model;
499
503
}
500
504
505
+ void set_name (std::shared_ptr<ov::Node> result, const std::string& name) {
506
+ result->set_friendly_name (name);
507
+ result->get_output_tensor (0 ).set_names ({name});
508
+ }
509
+
510
+ void remove_input_kv_tensors (std::shared_ptr<ov::Model>& model) {
511
+ const int CONCAT_CURR_KV_PORT = 1 ;
512
+
513
+ ov::ParameterVector params_to_remove;
514
+ ov::ResultVector results_to_add;
515
+ ov::ResultVector results_to_remove;
516
+
517
+ std::shared_ptr<ov::op::v0::Constant> cst_node;
518
+ for (const auto & input: model->inputs ()) {
519
+ if (input.get_any_name ().find (" past_key_values" ) == std::string::npos) {
520
+ continue ;
521
+ }
522
+
523
+ // Remember this to remove later on when all reconnections performed
524
+ params_to_remove.push_back (ov::as_type_ptr<ov::op::v0::Parameter>(input.get_node_shared_ptr ()));
525
+ // KV-cache input tensor is connected with Concat and additionally can be connected with ShapeOf
526
+ std::shared_ptr<ov::Node> concat_node;
527
+ auto target_inputs = input.get_node ()->output (0 ).get_target_inputs ();
528
+ for (const auto & target_input : target_inputs) {
529
+ auto target_node = target_input.get_node ();
530
+ // Get Concat node
531
+ if (strstr (target_node->get_type_name (), " Concat" ) != nullptr ) {
532
+ concat_node = target_node->shared_from_this ();
533
+ }
534
+ }
535
+
536
+ // Remove concat node
537
+ OPENVINO_ASSERT (concat_node);
538
+ auto cat_readers = concat_node->outputs ()[0 ].get_target_inputs ();
539
+
540
+ // Result and SDPA
541
+ OPENVINO_ASSERT (cat_readers.size () == 2 );
542
+ for (const auto & cat_reader : cat_readers) {
543
+ if (strstr (cat_reader.get_node ()->get_type_name (), " Result" ) != nullptr ) {
544
+ auto result_in = cat_reader;
545
+
546
+ // Re-assign Result
547
+ auto result_to_remove = ov::as_type_ptr<ov::op::v0::Result>(result_in.get_node ()->shared_from_this ());
548
+ auto result_to_add = std::make_shared<ov::op::v0::Result>(concat_node->inputs ()[CONCAT_CURR_KV_PORT].get_source_output ());
549
+ set_name (result_to_add, result_to_remove->get_friendly_name ());
550
+
551
+ results_to_remove.push_back (result_to_remove);
552
+ results_to_add.push_back (result_to_add);
553
+ }
554
+ if (strstr (cat_reader.get_node ()->get_type_name (), " ScaledDotProductAttention" ) != nullptr ) {
555
+ auto sdpa_in = cat_reader;
556
+
557
+ // Redirect KV from concat to SDPA
558
+ auto curr_kv = concat_node->inputs ()[CONCAT_CURR_KV_PORT].get_source_output ();
559
+ sdpa_in.replace_source_output (curr_kv);
560
+ }
561
+ }
562
+
563
+ // In case KV-cache also connected with ShapeOf
564
+ if (target_inputs.size () == 2 ) {
565
+ for (const auto & target_in : target_inputs) {
566
+ if (strstr (target_in.get_node ()->get_type_name (), " ShapeOf" ) != nullptr ) {
567
+ auto shapeof_node = ov::as_type_ptr<ov::op::v3::ShapeOf>(target_in.get_node ()->shared_from_this ());
568
+ auto shape = std::vector<size_t >{1 , size_t (input.get_partial_shape ()[1 ].get_length ()), 0 , size_t (input.get_partial_shape ()[3 ].get_length ())};
569
+ cst_node = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4 }, shape);
570
+
571
+ ov::replace_node (shapeof_node, cst_node);
572
+ }
573
+ }
574
+ }
575
+ }
576
+
577
+ OPENVINO_ASSERT (cst_node);
578
+
579
+ for (const auto & r: results_to_remove) {
580
+ model->remove_result (r);
581
+ }
582
+ for (const auto & p: params_to_remove) {
583
+ model->remove_parameter (p);
584
+ }
585
+ model->add_results (results_to_add);
586
+ model->validate_nodes_and_infer_types ();
587
+ }
588
+
589
+ auto remove_encoder_attn_read_value (const std::shared_ptr<ov::Node>& rv_node,
590
+ const ov::Output<ov::Node>& kv_out,
591
+ const ov::Input<ov::Node>& sdpa_in) {
592
+ // Find Assign node
593
+ OPENVINO_ASSERT (rv_node->outputs ().size () == 1 );
594
+ auto rv_out = rv_node->outputs ()[0 ];
595
+ ov::NodeVector rv_readers;
596
+ for (const auto & target_in: rv_out.get_target_inputs ()) {
597
+ rv_readers.push_back (target_in.get_node ()->shared_from_this ());
598
+ }
599
+ // Assign and SDPA
600
+ OPENVINO_ASSERT (rv_readers.size () == 2 );
601
+ auto assign_node = (strstr (rv_readers[0 ]->get_type_name (), " Assign" ) != nullptr ) ? rv_readers[0 ] : rv_readers[1 ];
602
+ OPENVINO_ASSERT (strstr (assign_node->get_type_name (), " Assign" ) != nullptr );
603
+ // Redirect KV-cache tensor to SDPA
604
+ sdpa_in.replace_source_output (kv_out);
605
+ return std::make_pair (std::make_shared<ov::op::v0::Result>(kv_out), ov::as_type_ptr<ov::op::v6::Assign>(assign_node));
606
+ }
607
+
608
+ std::string transform_key_value_name (std::string input_string, std::string prefix, std::string enc_or_dec, std::string key_or_value) {
609
+ std::regex pattern (" [0-9]+" );
610
+ std::smatch match;
611
+ std::regex_search (input_string, match, pattern);
612
+
613
+ if (match.empty ())
614
+ OPENVINO_THROW (" Input string does not match the expected pattern" );
615
+
616
+ auto number = std::string (match[0 ]);
617
+ return prefix + " ." + number + enc_or_dec + key_or_value;
618
+ }
619
+
620
+ void expose_runtime_states_as_outputs (std::shared_ptr<ov::Model>& model) {
621
+ // Find all ReadValue nodes
622
+ ov::NodeVector read_value_nodes;
623
+ for (const auto & op : model->get_ops ()) {
624
+ if (strstr (op->get_type_name (), " ReadValue" ) != nullptr ) {
625
+ read_value_nodes.push_back (op);
626
+ }
627
+ }
628
+
629
+ // Holds result layers for cross-attn KV-cache tensors
630
+ ov::ResultVector results;
631
+ ov::SinkVector assigns;
632
+
633
+ // Go through all ReadValue nodes and remove them
634
+ for (const auto & rv_node : read_value_nodes) {
635
+ OPENVINO_ASSERT (rv_node->inputs ().size () == 1 );
636
+ OPENVINO_ASSERT (rv_node->outputs ().size () == 1 );
637
+ auto rv_in = rv_node->inputs ()[0 ];
638
+ auto x = rv_in.get_source_output ();
639
+ auto rv_out = rv_node->outputs ()[0 ];
640
+ // Gather all nodes that read from ReadValue, there must be SDPA and Assign
641
+ auto rv_readers = rv_out.get_target_inputs ();
642
+ OPENVINO_ASSERT (rv_readers.size () == 2 );
643
+ // Input port for SDPA node
644
+ for (const auto & reader : rv_readers) {
645
+ if (strstr (reader.get_node ()->get_type_name (), " ScaledDotProductAttention" ) != nullptr ) {
646
+ auto sdpa_in = reader;
647
+ // Remove ReadValue, store new Result and Assign
648
+ auto key_or_value = (sdpa_in.get_index () == 1 ) ? " key" : " value" ;
649
+ auto [result, assign] = remove_encoder_attn_read_value (rv_node, rv_in.get_source_output (), sdpa_in);
650
+ auto normalized_name = transform_key_value_name (
651
+ rv_node->inputs ()[0 ].get_source_output ().get_node ()->get_friendly_name (),
652
+ " present" ,
653
+ " .encoder." ,
654
+ key_or_value
655
+ );
656
+ set_name (result, normalized_name);
657
+ results.push_back (result);
658
+ assigns.push_back (assign);
659
+ }
660
+ }
661
+ }
662
+
663
+ // Add, remove, validate
664
+ model->add_results (results);
665
+ for (const auto & assign : assigns) {
666
+ model->remove_sink (assign);
667
+ }
668
+ model->validate_nodes_and_infer_types ();
669
+ }
670
+
671
+ void remove_cache_position (std::shared_ptr<ov::Model>& model) {
672
+ // Build subgraph that will replace cache_pos
673
+ auto input_ids = model->input (" input_ids" ).get_node ();
674
+ auto shape_of_node = std::make_shared<ov::op::v3::ShapeOf>(input_ids->outputs ()[0 ]);
675
+
676
+ std::vector<int > v_0{0 };
677
+ std::vector<int > v_1{1 };
678
+
679
+ auto indices = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, v_1);
680
+ indices->set_friendly_name (" indices" );
681
+ auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, v_0);
682
+ axis->set_friendly_name (" axis" );
683
+
684
+ auto gather_node = std::make_shared<ov::op::v8::Gather>(shape_of_node->outputs ()[0 ], indices, axis);
685
+
686
+ auto cst_node = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, v_0);
687
+ auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, v_1);
688
+ step->set_friendly_name (" step" );
689
+ auto range_node = std::make_shared<ov::op::v4::Range>(cst_node->outputs ()[0 ], gather_node->outputs ()[0 ], step->outputs ()[0 ], ov::element::i64);
690
+ // Replace cache_position
691
+ auto cache_pos = ov::as_type_ptr<ov::op::v0::Parameter>(model->input (" cache_position" ).get_node ()->shared_from_this ());
692
+ for (const auto & target_input : cache_pos->outputs ()[0 ].get_target_inputs ()) {
693
+ target_input.replace_source_output (range_node->outputs ()[0 ]);
694
+ }
695
+
696
+ model->remove_parameter (cache_pos);
697
+ model->validate_nodes_and_infer_types ();
698
+ }
699
+
700
+ void normalize_input_key_value_names (std::shared_ptr<ov::Model>& model) {
701
+ ov::ResultVector new_results, old_results;
702
+ for (const auto & in : model->inputs ()) {
703
+ if (in.get_any_name ().find (" decoder" ) == std::string::npos) {
704
+ continue ;
705
+ }
706
+
707
+ auto key_or_value = (in.get_any_name ().find (" .key" ) != std::string::npos) ? " key" : " value" ;
708
+ auto normalized_name = transform_key_value_name (in.get_any_name (), " past_key_values" , " .decoder." , key_or_value);
709
+ set_name (in.get_node_shared_ptr (), normalized_name);
710
+ }
711
+
712
+ model->validate_nodes_and_infer_types ();
713
+ }
714
+
715
+ void normalize_output_key_value_names (std::shared_ptr<ov::Model>& model) {
716
+ ov::ResultVector new_results, old_results;
717
+ for (const auto & out : model->outputs ()) {
718
+ if (out.get_any_name ().find (" decoder" ) == std::string::npos) {
719
+ continue ;
720
+ }
721
+
722
+ auto key_or_value = (out.get_any_name ().find (" .key" ) != std::string::npos) ? " key" : " value" ;
723
+ auto normalized_name = transform_key_value_name (out.get_any_name (), " present" , " .decoder." , key_or_value);
724
+ set_name (out.get_node_shared_ptr (), normalized_name);
725
+ }
726
+
727
+ model->validate_nodes_and_infer_types ();
728
+ }
729
+
730
+ void expose_runtime_states_as_inputs (std::shared_ptr<ov::Model>& model) {
731
+ // Store Assign nodes to perform remove_sink later on
732
+ ov::SinkVector assigns;
733
+ // To add new Params to the model
734
+ ov::ParameterVector params;
735
+
736
+ ov::NodeVector read_value_nodes;
737
+ for (const auto & op : model->get_ops ()) {
738
+ if (strstr (op->get_type_name (), " ReadValue" ) != nullptr ) {
739
+ read_value_nodes.push_back (op);
740
+ }
741
+ }
742
+
743
+ for (const auto & rv_node : read_value_nodes) {
744
+ auto rv_out = rv_node->outputs ()[0 ];
745
+ auto rv_readers = rv_out.get_target_inputs ();
746
+ for (auto rv_reader: rv_readers) {
747
+ if (strstr (rv_reader.get_node ()->get_type_name (), " Assign" ) != nullptr ) {
748
+ auto assign_node = ov::as_type_ptr<ov::op::v6::Assign>(rv_reader.get_node ()->shared_from_this ());
749
+ assigns.push_back (assign_node);
750
+ } else if (strstr (rv_reader.get_node ()->get_type_name (), " ScaledDotProductAttention" ) != nullptr ) {
751
+ auto sdpa_in = rv_reader;
752
+ auto sdpa_node = rv_reader.get_node ();
753
+
754
+ auto shape = rv_node->get_output_partial_shape (0 );
755
+ auto new_param = std::make_shared<ov::op::v0::Parameter>(rv_node->get_output_element_type (0 ), shape);
756
+
757
+ auto key_or_value = (sdpa_in.get_index () == 1 ) ? " key" : " value" ;
758
+ auto normalized_name = transform_key_value_name (sdpa_in.get_node ()->get_friendly_name (),
759
+ " past_key_values" ,
760
+ " .encoder." ,
761
+ key_or_value);
762
+ set_name (new_param, normalized_name);
763
+
764
+ params.push_back (new_param);
765
+ sdpa_in.replace_source_output (new_param->outputs ()[0 ]);
766
+ }
767
+ }
768
+ }
769
+
770
+ // Remove sinks and add new params
771
+ model->add_parameters (params);
772
+ for (const auto & assign: assigns) {
773
+ model->remove_sink (assign);
774
+ }
775
+ }
776
+
777
+ std::shared_ptr<ov::Model> prepare_decoder_model (std::shared_ptr<ov::Model>& model) {
778
+ auto decoder_model = model->clone ();
779
+ // 2) Remove all non-runtime states from inputs (they empty on first iteration)
780
+ remove_input_kv_tensors (decoder_model);
781
+ // 3) Expose all states that requires initialization on the first run as outputs
782
+ expose_runtime_states_as_outputs (decoder_model);
783
+ // 4) Remove cache_position input
784
+ remove_cache_position (decoder_model);
785
+ // 5) Normalize output names - should be done in stateful_to_stateless_transformation
786
+ normalize_output_key_value_names (decoder_model);
787
+
788
+ decoder_model->validate_nodes_and_infer_types ();
789
+ return decoder_model;
790
+ }
791
+
792
+ std::shared_ptr<ov::Model> prepare_decoder_with_past_model (std::shared_ptr<ov::Model>& model) {
793
+ auto decoder_with_past_model = model->clone ();
794
+ // FIXME: normalization should be done inside stateful_to_stateless_transformation
795
+ normalize_input_key_value_names (decoder_with_past_model);
796
+ normalize_output_key_value_names (decoder_with_past_model);
797
+ expose_runtime_states_as_inputs (decoder_with_past_model);
798
+
799
+ decoder_with_past_model->reshape ({{" input_ids" , ov::PartialShape ({-1 , 1 })}});
800
+
801
+ decoder_with_past_model->validate_nodes_and_infer_types ();
802
+ return decoder_with_past_model;
803
+ }
804
+
501
805
} // namespace
502
806
503
807
namespace ov {
@@ -522,8 +826,23 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
522
826
ov::Core core = utils::singleton_core ();
523
827
524
828
auto encoder_model = core.read_model (models_path / " openvino_encoder_model.xml" , {}, properties);
525
- auto decoder_model = core.read_model (models_path / " openvino_decoder_model.xml" , {}, properties);
526
- auto decoder_with_past_model = core.read_model (models_path / " openvino_decoder_with_past_model.xml" , {}, properties);
829
+
830
+ std::shared_ptr<ov::Model> decoder_model;
831
+ std::shared_ptr<ov::Model> decoder_with_past_model;
832
+
833
+ if (std::filesystem::exists (models_path / " openvino_decoder_with_past_model.xml" ) ) {
834
+ decoder_model = core.read_model (models_path / " openvino_decoder_model.xml" , {}, properties);
835
+ decoder_with_past_model = core.read_model (models_path / " openvino_decoder_with_past_model.xml" , {}, properties);
836
+ } else {
837
+ auto model = core.read_model (models_path / " openvino_decoder_model.xml" , {}, properties);
838
+ ov::pass::StatefulToStateless ().run_on_model (model);
839
+
840
+ decoder_model = prepare_decoder_model (model);
841
+ decoder_with_past_model = prepare_decoder_with_past_model (model);
842
+ }
843
+
844
+ if (!decoder_model || !decoder_with_past_model)
845
+ OPENVINO_THROW (" Decoder/decoder_with_past model is not valid !" );
527
846
528
847
add_attention_mask_input (decoder_with_past_model);
529
848
0 commit comments