|
4 | 4 | #include "make_tokenizer_stateful.hpp"
|
5 | 5 | #include "openvino/op/constant.hpp"
|
6 | 6 | #include "openvino/op/select.hpp"
|
| 7 | +#include "openvino/op/maximum.hpp" |
| 8 | +#include "openvino/op/minimum.hpp" |
| 9 | +#include "openvino/op/add.hpp" |
| 10 | +#include "openvino/op/subtract.hpp" |
7 | 11 | #include "openvino/op/slice.hpp"
|
8 | 12 | #include "openvino/op/multiply.hpp"
|
9 | 13 | #include "openvino/op/read_value.hpp"
|
|
13 | 17 | using namespace ov;
|
14 | 18 | using namespace ov::op;
|
15 | 19 |
|
16 |
| -bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) { |
| 20 | +bool ov::genai::MakeAddSpecialTokensSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) { |
17 | 21 | std::shared_ptr<ov::Node> combine_seg_node;
|
18 | 22 | for (auto node: model->get_ordered_ops()) {
|
19 | 23 | if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
|
@@ -56,6 +60,7 @@ bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr
|
56 | 60 | return true;
|
57 | 61 | }
|
58 | 62 |
|
| 63 | + |
59 | 64 | bool ov::genai::MakeVocabDecoderSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
|
60 | 65 | std::shared_ptr<ov::Node> vocab_decoder_node;
|
61 | 66 | for (auto node: model->get_ordered_ops()) {
|
@@ -97,3 +102,118 @@ bool ov::genai::MakeVocabDecoderSatateful::run_on_model(const std::shared_ptr<ov
|
97 | 102 | model->add_variables({variable});
|
98 | 103 | return true;
|
99 | 104 | }
|
| 105 | + |
| 106 | + |
| 107 | +bool ov::genai::MakePaddingSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) { |
| 108 | + std::shared_ptr<ov::Node> combine_seg_node; |
| 109 | + for (auto node: model->get_ordered_ops()) { |
| 110 | + if (strcmp(node->get_type_info().name, "CombineSegments") == 0) { |
| 111 | + combine_seg_node = node; |
| 112 | + } |
| 113 | + } |
| 114 | + if (!combine_seg_node) { return false; } |
| 115 | + auto num_comb = combine_seg_node->get_input_size(); |
| 116 | + |
| 117 | + size_t num_segments = (combine_seg_node->get_input_size() - 1) / 3; |
| 118 | + size_t number_of_main_tokens_inputs = 0; |
| 119 | + std::shared_ptr<Node> add_or_sub_node; |
| 120 | + for (size_t i = 0; i < num_segments; i++) { |
| 121 | + // Check all ends inputs of CombineSegments node. |
| 122 | + // For special tokens they are Constant/Select, |
| 123 | + // for the ends input with main tokens sequence it's Add/Subtract. |
| 124 | + // If Add then it's a right truncation, if Subtract then it's a left truncation. |
| 125 | + auto tmp_node = combine_seg_node->input_value(3*i + 1).get_node_shared_ptr(); |
| 126 | + if (ov::as_type_ptr<v1::Add>(tmp_node) || ov::as_type_ptr<v1::Subtract>(tmp_node)) { |
| 127 | + number_of_main_tokens_inputs += 1; |
| 128 | + add_or_sub_node = tmp_node; |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + // Exit if couldn't find main input or there are several. |
| 133 | + if (number_of_main_tokens_inputs != 1) { return false; } |
| 134 | + |
| 135 | + // Minimum between max_length and length of token sequence. |
| 136 | + auto min_node = ov::as_type_ptr<v1::Minimum>(add_or_sub_node->get_input_node_shared_ptr(1)); |
| 137 | + if (!min_node) { return false; } |
| 138 | + |
| 139 | + // constant containing final max_length - num_added tokens at the end of pipeline. |
| 140 | + auto const_node = ov::as_type_ptr<v0::Constant>(min_node->get_input_node_shared_ptr(1)); |
| 141 | + if (!const_node) { return false; } |
| 142 | + |
| 143 | + op::util::VariableInfo var_info{const_node->get_output_shape(0), const_node->get_output_element_type(0), MAX_LENGTH_VAR_ID}; |
| 144 | + auto variable_1 = std::make_shared<op::util::Variable>(var_info); |
| 145 | + |
| 146 | + size_t num_added_tokens = num_segments - number_of_main_tokens_inputs; |
| 147 | + // Constant which stores number of added_tokens. |
| 148 | + auto num_added_tokens_const = std::make_shared<v0::Constant>( |
| 149 | + const_node->get_output_element_type(0), const_node->get_output_shape(0), std::vector{num_added_tokens}); |
| 150 | + |
| 151 | + OPENVINO_ASSERT(const_node->get_element_type() == element::i32); |
| 152 | + auto values = const_node->get_vector<int32_t>(); |
| 153 | + OPENVINO_ASSERT(values.size() == 1); |
| 154 | + // Since const_node contain value = max_length - num_added tokens, |
| 155 | + size_t default_max_length = values[0] + num_added_tokens; |
| 156 | + |
| 157 | + auto default_max_length_const = std::make_shared<v0::Constant>( |
| 158 | + const_node->get_output_element_type(0), const_node->get_output_shape(0), std::vector{default_max_length}); |
| 159 | + |
| 160 | + // Save targets before adding new target with ReadValue to avoid recursion. |
| 161 | + auto target_inputs = const_node->output(0).get_target_inputs(); |
| 162 | + auto max_length_rv = std::make_shared<v6::ReadValue>(default_max_length_const, variable_1); |
| 163 | + auto subtract_node = std::make_shared<v1::Subtract>(max_length_rv, num_added_tokens_const); |
| 164 | + |
| 165 | + for (auto target_input : target_inputs) { |
| 166 | + target_input.replace_source_output(subtract_node->output(0)); |
| 167 | + } |
| 168 | + |
| 169 | + // We need to check if user requested to not add special tokens. |
| 170 | + std::shared_ptr<v6::ReadValue> read_value_spec_tokens; |
| 171 | + for (const auto& sink : model->get_sinks()) { |
| 172 | + // Check if sink accepts input from Assign, and if that't the case get the ReadValus node input. |
| 173 | + if (auto read_value = ov::as_type_ptr<v6::ReadValue>(sink->get_input_node_shared_ptr(0))) { |
| 174 | + if (read_value->get_variable()->get_info().variable_id == ADD_SPECIAL_TOKENS_VAR_ID) { |
| 175 | + read_value_spec_tokens = read_value; |
| 176 | + break; |
| 177 | + } |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + // If user requested to not add special tokens in order to correctly calculate |
| 182 | + // truncation we need to enforce num_added_tokens to 0 regardless the hardcoded value of Constant. |
| 183 | + if (read_value_spec_tokens && num_added_tokens_const) { |
| 184 | + auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0}); |
| 185 | + auto select_node = std::make_shared<v1::Select>(read_value_spec_tokens, num_added_tokens_const, zero_constant); |
| 186 | + subtract_node->input(1).replace_source_output(select_node->output(0)); |
| 187 | + } |
| 188 | + |
| 189 | + model->add_sinks({std::make_shared<v6::Assign>(max_length_rv, variable_1)}); |
| 190 | + model->add_variables({variable_1}); |
| 191 | + |
| 192 | + std::shared_ptr<ov::Node> ragged_to_dense_node; |
| 193 | + for (auto node: model->get_ordered_ops()) { |
| 194 | + if (strcmp(node->get_type_info().name, "RaggedToDense") == 0) { |
| 195 | + ragged_to_dense_node = node; |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + if (!ragged_to_dense_node || ragged_to_dense_node->input_value(3).get_element_type() != ov::element::i32) { |
| 200 | + return true; // true since at this point we already have modified the graph.s |
| 201 | + } |
| 202 | + |
| 203 | + auto variable_2 = std::make_shared<op::util::Variable>(op::util::VariableInfo{ov::Shape{1}, ov::element::boolean, PAD_TO_LONGEST_VAR_ID}); |
| 204 | + |
| 205 | + // By default do not pad to max_length |
| 206 | + auto default_false_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{1}, std::vector{false}); |
| 207 | + auto pad_to_max_length_rv = std::make_shared<v6::ReadValue>(default_false_const, variable_2); |
| 208 | + |
| 209 | + auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0}); |
| 210 | + auto select_node = std::make_shared<v1::Select>(pad_to_max_length_rv, max_length_rv, zero_constant); |
| 211 | + |
| 212 | + auto max_op = std::make_shared<v1::Maximum>(ragged_to_dense_node->input_value(3), select_node); |
| 213 | + ragged_to_dense_node->input(3).replace_source_output(max_op->output(0)); |
| 214 | + |
| 215 | + model->add_sinks({std::make_shared<v6::Assign>(pad_to_max_length_rv, variable_2)}); |
| 216 | + model->add_variables({variable_2}); |
| 217 | + |
| 218 | + return true; |
| 219 | +} |
0 commit comments