|
5 | 5 |
|
6 | 6 | #include <fstream>
|
7 | 7 |
|
| 8 | +#include "openvino/op/add.hpp" |
| 9 | +#include "openvino/op/divide.hpp" |
| 10 | +#include "openvino/op/multiply.hpp" |
| 11 | +#include "openvino/op/matmul.hpp" |
| 12 | +#include "openvino/op/slice.hpp" |
| 13 | +#include "openvino/op/tanh.hpp" |
| 14 | +#include "openvino/op/transpose.hpp" |
| 15 | + |
8 | 16 | namespace ov {
|
9 | 17 | namespace genai {
|
10 | 18 | namespace utils {
|
@@ -225,6 +233,32 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token
|
225 | 233 |
|
226 | 234 | return {new_input_ids, new_attention_mask};
|
227 | 235 | }
|
| 236 | + |
| 237 | +void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) { |
| 238 | + ov::Node* matmul = nullptr; |
| 239 | + auto last_node = model->output(0).get_node()->input_value(0).get_node(); |
| 240 | + if (matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node)) { |
| 241 | + } else if(auto add = dynamic_cast<ov::op::v1::Add*>(last_node)) { |
| 242 | + matmul = dynamic_cast<ov::op::v0::MatMul*>(add->input_value(0).get_node()); |
| 243 | + } else if (auto transpose = dynamic_cast<ov::op::v1::Transpose*>(last_node)) { |
| 244 | + matmul = dynamic_cast<ov::op::v0::MatMul*>(transpose->input_value(0).get_node()); |
| 245 | + } else if (auto multiply = dynamic_cast<ov::op::v1::Multiply*>(last_node)) { |
| 246 | + if (auto tanh = dynamic_cast<ov::op::v0::Tanh*>(multiply->input_value(0).get_node())) { |
| 247 | + if (auto divide = dynamic_cast<ov::op::v1::Divide*>(tanh->input_value(0).get_node())) { |
| 248 | + matmul = dynamic_cast<ov::op::v0::MatMul*>(divide->input_value(0).get_node()); |
| 249 | + } |
| 250 | + } |
| 251 | + } |
| 252 | + |
| 253 | + if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) { |
| 254 | + auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1}); |
| 255 | + auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2}); |
| 256 | + auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1}); |
| 257 | + auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}); |
| 258 | + auto slice = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis); |
| 259 | + matmul->input(0).replace_source_output(slice); |
| 260 | + } |
| 261 | +} |
228 | 262 | } // namespace utils
|
229 | 263 | } // namespace genai
|
230 | 264 | } // namespace ov
|
0 commit comments