Skip to content

Commit 8512fc1

Browse files
Add PushConstantToSubgraph transformation (openvinotoolkit#15250)
* Add PushConstantToSubgraph transformation Transformation detects constfoldable inputs to MultiSubGraphOp, constantfold them and then pushes them to inner subgraphs. Ticket: 98155 * cast to int * comments, split to functions * remove op::util
1 parent 2c64c3a commit 8512fc1

File tree

8 files changed

+464
-0
lines changed

8 files changed

+464
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <openvino/pass/pass.hpp>
8+
#include <transformations_visibility.hpp>
9+
10+
namespace ov {
11+
namespace pass {
12+
13+
/**
14+
* @ingroup ie_transformation_common_api
15+
* @brief PushConstantToSubgraph transformation detects MultiSubGraphOp inputs
16+
* that can be constfoldable pushes that inputs to subgraphs.
17+
*/
18+
class TRANSFORMATIONS_API PushConstantToSubgraph : public ov::pass::ModelPass {
19+
public:
20+
OPENVINO_RTTI("PushConstantToSubgraph", "0");
21+
bool run_on_model(const std::shared_ptr<Model>& model) override;
22+
};
23+
24+
} // namespace pass
25+
} // namespace ov

src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include <transformations/common_optimizations/prelu_fusion.hpp>
4646
#include <transformations/common_optimizations/pull_through_reduce.hpp>
4747
#include <transformations/common_optimizations/pull_transpose_through_fq.hpp>
48+
#include <transformations/common_optimizations/push_constant_to_subgraph.hpp>
4849
#include <transformations/common_optimizations/random_uniform_fusion.hpp>
4950
#include <transformations/common_optimizations/reduce_reshape_fusion.hpp>
5051
#include <transformations/common_optimizations/relu_fake_quantize_fusion.hpp>
@@ -121,6 +122,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
121122
REGISTER_PASS(manager, RemoveMultiSubGraphOpDanglingParams)
122123
REGISTER_PASS(manager, FoldSubgraphEmptyInputs)
123124
REGISTER_PASS(manager, DisableRandomUniformConstantFolding)
125+
REGISTER_PASS(manager, PushConstantToSubgraph)
124126
REGISTER_PASS(manager, ConstantFolding)
125127
REGISTER_PASS(manager, Validate)
126128

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/common_optimizations/push_constant_to_subgraph.hpp"
6+
7+
#include <openvino/core/validation_util.hpp>
8+
#include <openvino/op/util/multi_subgraph_base.hpp>
9+
10+
#include "itt.hpp"
11+
12+
using MultiSubGraphOp = ov::op::util::MultiSubGraphOp;
13+
14+
static std::shared_ptr<ov::op::v0::Constant> try_constantfold_input(
15+
const std::shared_ptr<MultiSubGraphOp>& op,
16+
const MultiSubGraphOp::InputDescription::Ptr& input_desc,
17+
std::unordered_map<size_t, std::shared_ptr<ov::op::v0::Constant>>& cache) {
18+
if (!std::dynamic_pointer_cast<MultiSubGraphOp::InvariantInputDescription>(input_desc)) {
19+
return nullptr;
20+
}
21+
const auto input_index = input_desc->m_input_index;
22+
auto it = cache.find(input_index);
23+
if (it == cache.end()) {
24+
auto constant = constantfold_subgraph(op->input_value(input_index));
25+
if (constant) {
26+
cache.insert({input_index, constant});
27+
}
28+
return constant;
29+
}
30+
return it->second;
31+
}
32+
33+
static void replace_body_parameter(const std::shared_ptr<ov::Model>& body,
34+
const std::shared_ptr<ov::op::v0::Parameter>& body_param,
35+
size_t body_parameter_index,
36+
const std::shared_ptr<ov::op::v0::Constant>& constant,
37+
MultiSubGraphOp::MultiSubgraphInputDescriptionVector& descriptions) {
38+
body_param->output(0).replace(constant);
39+
body->remove_parameter(body_param);
40+
// update all input descriptions to reflect that body parameter was removed
41+
for (auto& desc : descriptions) {
42+
if (desc->m_body_parameter_index > body_parameter_index) {
43+
desc->m_body_parameter_index--;
44+
}
45+
}
46+
}
47+
48+
static void update_multi_sub_graph_op_inputs(const std::shared_ptr<MultiSubGraphOp>& multi_sub_graph_op,
49+
int remove_inputs_mask) {
50+
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
51+
auto inputs = multi_sub_graph_op->input_values();
52+
for (size_t i = multi_sub_graph_op->get_input_size(); i > 0; i--) {
53+
const auto input_index = i - 1;
54+
if ((remove_inputs_mask & (1 << input_index)) != 0) {
55+
// remove MultiSubGraphOp's input if it was marked to be removed
56+
// (meaning it was constfolded and pushed to inner subgraph)
57+
inputs.erase(inputs.begin() + input_index);
58+
59+
// update input descriptions to reflect that the input was removed
60+
for (int body_idx = 0; body_idx < num_subgraphs; body_idx++) {
61+
auto& descriptions = multi_sub_graph_op->get_input_descriptions(body_idx);
62+
for (auto& desc : descriptions) {
63+
if (desc->m_input_index > input_index) {
64+
desc->m_input_index--;
65+
}
66+
}
67+
}
68+
}
69+
}
70+
multi_sub_graph_op->set_arguments(inputs);
71+
}
72+
73+
bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>& model) {
74+
RUN_ON_FUNCTION_SCOPE(PushConstantToSubgraph);
75+
76+
bool result = false;
77+
for (const auto& op : model->get_ordered_ops()) {
78+
const auto multi_sub_graph_op = as_type_ptr<op::util::MultiSubGraphOp>(op);
79+
if (!multi_sub_graph_op) {
80+
continue;
81+
}
82+
83+
// cache for already constant folded inputs
84+
std::unordered_map<size_t, std::shared_ptr<op::v0::Constant>> cache;
85+
// bitmask describing which MultiSubGraphOp's input to remove
86+
int remove_inputs_mask = 0;
87+
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
88+
89+
for (int body_idx = 0; body_idx < num_subgraphs; body_idx++) {
90+
const auto& body = multi_sub_graph_op->get_function(body_idx);
91+
auto& body_params = body->get_parameters();
92+
auto& descriptions = multi_sub_graph_op->get_input_descriptions(body_idx);
93+
for (auto desc_it = descriptions.begin(); desc_it < descriptions.end();) {
94+
const auto& desc = *desc_it;
95+
const auto input_index = desc->m_input_index;
96+
const auto constant = try_constantfold_input(multi_sub_graph_op, desc, cache);
97+
if (!constant) {
98+
remove_inputs_mask &= ~(1 << input_index);
99+
desc_it++;
100+
continue;
101+
}
102+
const auto body_parameter_index = desc->m_body_parameter_index;
103+
desc_it = descriptions.erase(desc_it);
104+
auto& body_param = body_params[body_parameter_index];
105+
replace_body_parameter(body, body_param, body_parameter_index, constant, descriptions);
106+
remove_inputs_mask |= 1 << input_index;
107+
result = true;
108+
}
109+
}
110+
111+
if (remove_inputs_mask > 0) {
112+
update_multi_sub_graph_op_inputs(multi_sub_graph_op, remove_inputs_mask);
113+
}
114+
115+
for (int body_idx = 0; body_idx < num_subgraphs; body_idx++) {
116+
bool model_changed = run_on_model(multi_sub_graph_op->get_function(body_idx));
117+
result = result || model_changed;
118+
}
119+
}
120+
121+
return result;
122+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include <openvino/core/model.hpp>
6+
#include <openvino/opsets/opset10.hpp>
7+
#include <transformations/common_optimizations/push_constant_to_subgraph.hpp>
8+
9+
#include "common_test_utils/ngraph_test_utils.hpp"
10+
11+
using namespace testing;
12+
using namespace ov;
13+
14+
TEST_F(TransformationTestsF, PushConstantToSubgraphLoop) {
15+
{
16+
auto trip_count = opset10::Constant::create(element::i32, Shape{}, {2});
17+
auto term_cond = opset10::Constant::create(element::boolean, Shape{}, {true});
18+
std::shared_ptr<Model> loop_body;
19+
{
20+
auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
21+
auto Y = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
22+
auto Z = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
23+
auto mul = std::make_shared<opset10::Multiply>(X, Y);
24+
auto add = std::make_shared<opset10::Add>(mul, Z);
25+
auto cond = opset10::Constant::create(element::boolean, Shape{}, {true});
26+
loop_body = std::make_shared<Model>(NodeVector{add, cond}, ParameterVector{X, Y, Z});
27+
}
28+
auto loop = std::make_shared<opset10::Loop>(trip_count, term_cond);
29+
loop->set_function(loop_body);
30+
31+
auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});
32+
auto constant_1 = opset10::Constant::create(element::i32, Shape{2, 2}, {11});
33+
auto convert_1 = std::make_shared<opset10::Convert>(constant_1, element::f32);
34+
auto constant_2 = opset10::Constant::create(element::i32, Shape{1, 2}, {22});
35+
auto convert_2 = std::make_shared<opset10::Convert>(constant_2, element::f32);
36+
const auto& loop_params = loop_body->get_parameters();
37+
loop->set_special_body_ports({-1, 1});
38+
loop->set_sliced_input(loop_params[0], X, 0, 1, 1, -1, 0);
39+
loop->set_sliced_input(loop_params[1], convert_1, 0, 1, 1, -1, 0);
40+
loop->set_invariant_input(loop_params[2], convert_2);
41+
auto out = loop->get_concatenated_slices(loop_body->get_results()[0], 0, 1, 1, -1, 0);
42+
function = std::make_shared<Model>(OutputVector{out}, ParameterVector{X});
43+
44+
manager.register_pass<pass::PushConstantToSubgraph>();
45+
}
46+
47+
{
48+
auto trip_count = opset10::Constant::create(element::i32, Shape{}, {2});
49+
auto term_cond = opset10::Constant::create(element::boolean, Shape{}, {true});
50+
std::shared_ptr<Model> loop_body;
51+
{
52+
auto constant = opset10::Constant::create(element::f32, Shape{1, 2}, {22});
53+
auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
54+
auto Y = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
55+
auto mul = std::make_shared<opset10::Multiply>(X, Y);
56+
auto add = std::make_shared<opset10::Add>(mul, constant);
57+
auto cond = opset10::Constant::create(element::boolean, Shape{}, {true});
58+
loop_body = std::make_shared<Model>(NodeVector{add, cond}, ParameterVector{X, Y});
59+
}
60+
auto loop = std::make_shared<opset10::Loop>(trip_count, term_cond);
61+
loop->set_function(loop_body);
62+
63+
auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});
64+
auto constant_1 = opset10::Constant::create(element::i32, Shape{2, 2}, {11});
65+
auto convert_1 = std::make_shared<opset10::Convert>(constant_1, element::f32);
66+
const auto& loop_params = loop_body->get_parameters();
67+
loop->set_special_body_ports({-1, 1});
68+
loop->set_sliced_input(loop_params[0], X, 0, 1, 1, -1, 0);
69+
loop->set_sliced_input(loop_params[1], convert_1, 0, 1, 1, -1, 0);
70+
auto out = loop->get_concatenated_slices(loop_body->get_results()[0], 0, 1, 1, -1, 0);
71+
function_ref = std::make_shared<Model>(OutputVector{out}, ParameterVector{X});
72+
}
73+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
74+
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
75+
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
76+
}
77+
78+
TEST_F(TransformationTestsF, PushConstantToSubgraphIf) {
79+
{
80+
auto cond = opset10::Constant::create(element::boolean, Shape{}, {false});
81+
auto if_op = std::make_shared<ov::opset10::If>(cond);
82+
std::shared_ptr<ov::Model> then_body;
83+
{
84+
auto A = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
85+
auto B = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
86+
auto C = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
87+
auto D = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
88+
auto add = std::make_shared<ov::opset10::Add>(A, B);
89+
auto mul = std::make_shared<ov::opset10::Multiply>(add, C);
90+
auto sub = std::make_shared<ov::opset10::Subtract>(mul, D);
91+
then_body = std::make_shared<ov::Model>(add, ov::ParameterVector{A, B, C, D});
92+
}
93+
std::shared_ptr<ov::Model> else_body;
94+
{
95+
auto A = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
96+
auto B = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
97+
auto C = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
98+
auto D = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
99+
auto mul = std::make_shared<ov::opset10::Multiply>(A, B);
100+
auto add = std::make_shared<ov::opset10::Add>(mul, C);
101+
auto div = std::make_shared<ov::opset10::Divide>(add, D);
102+
else_body = std::make_shared<ov::Model>(div, ov::ParameterVector{A, B, C, D});
103+
}
104+
105+
if_op->set_then_body(then_body);
106+
if_op->set_else_body(else_body);
107+
108+
const auto& then_params = then_body->get_parameters();
109+
const auto& else_params = else_body->get_parameters();
110+
111+
auto A = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
112+
auto B = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
113+
auto C = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
114+
auto const_1 = ov::opset10::Constant::create(ov::element::i32, ov::Shape{3}, {1});
115+
auto convert_1 = std::make_shared<ov::opset10::Convert>(const_1, ov::element::f32);
116+
auto const_2 = ov::opset10::Constant::create(ov::element::i32, ov::Shape{3}, {2});
117+
auto convert_2 = std::make_shared<ov::opset10::Convert>(const_2, ov::element::f32);
118+
auto const_3 = ov::opset10::Constant::create(ov::element::i32, ov::Shape{3}, {3});
119+
auto convert_3 = std::make_shared<ov::opset10::Convert>(const_3, ov::element::f32);
120+
121+
if_op->set_input(A, then_params[0], nullptr);
122+
if_op->set_input(convert_1, then_params[1], nullptr);
123+
if_op->set_input(B, then_params[2], else_params[0]);
124+
if_op->set_input(convert_2, then_params[3], else_params[1]);
125+
126+
if_op->set_input(C, nullptr, else_params[2]);
127+
if_op->set_input(convert_3, nullptr, else_params[3]);
128+
if_op->set_output(then_body->get_results()[0], else_body->get_results()[0]);
129+
130+
function = std::make_shared<ov::Model>(if_op, ov::ParameterVector{A, B, C});
131+
132+
manager.register_pass<pass::PushConstantToSubgraph>();
133+
}
134+
135+
{
136+
auto cond = opset10::Constant::create(element::boolean, Shape{}, {false});
137+
auto const_1 = ov::opset10::Constant::create(ov::element::f32, ov::Shape{3}, {1});
138+
auto const_2 = ov::opset10::Constant::create(ov::element::f32, ov::Shape{3}, {2});
139+
auto const_3 = ov::opset10::Constant::create(ov::element::f32, ov::Shape{3}, {3});
140+
auto if_op = std::make_shared<ov::opset10::If>(cond);
141+
std::shared_ptr<ov::Model> then_body;
142+
{
143+
auto A = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
144+
auto B = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
145+
auto add = std::make_shared<ov::opset10::Add>(A, const_1);
146+
auto mul = std::make_shared<ov::opset10::Multiply>(add, B);
147+
auto sub = std::make_shared<ov::opset10::Subtract>(mul, const_2);
148+
then_body = std::make_shared<ov::Model>(add, ov::ParameterVector{A, B});
149+
}
150+
std::shared_ptr<ov::Model> else_body;
151+
{
152+
auto A = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
153+
auto B = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
154+
auto mul = std::make_shared<ov::opset10::Multiply>(A, const_2);
155+
auto add = std::make_shared<ov::opset10::Add>(mul, B);
156+
auto div = std::make_shared<ov::opset10::Divide>(add, const_3);
157+
else_body = std::make_shared<ov::Model>(div, ov::ParameterVector{A, B});
158+
}
159+
160+
if_op->set_then_body(then_body);
161+
if_op->set_else_body(else_body);
162+
163+
const auto& then_params = then_body->get_parameters();
164+
const auto& else_params = else_body->get_parameters();
165+
166+
auto A = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
167+
auto B = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
168+
auto C = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3});
169+
170+
if_op->set_input(A, then_params[0], nullptr);
171+
if_op->set_input(B, then_params[1], else_params[0]);
172+
if_op->set_input(C, nullptr, else_params[1]);
173+
if_op->set_output(then_body->get_results()[0], else_body->get_results()[0]);
174+
175+
function_ref = std::make_shared<ov::Model>(if_op, ov::ParameterVector{A, B, C});
176+
}
177+
178+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
179+
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
180+
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
181+
}

0 commit comments

Comments
 (0)