Skip to content

Commit 5d00273

Browse files
itikhono, andrei-kochin, and CuriousPanCake
authored
Add transformation pipeline to PrePostProcessing (#28852)
### Details: After switching from ModelOptimizer to OVC, the order of applying PrePostProcessing and MOCTransformations has changed: MO path : [fw model conversion -> PrePostProcessing -> MOC] -> nncf OVC path: [fw model conversion -> MOC] -> PrePostProcessing -> nncf Since nncf is applied with a not fully optimized model, extra FQ ops might appear, which can affect both accuracy and performance. e.g. Mul -> Conv fusion is not applied due to extra FQ <img width="165" alt="{C6E93F2C-2CE3-4596-8D7F-ED7BD8013603}" src="https://github.com/user-attachments/assets/3cbe6e07-9c07-4002-8b4c-9fb5bc662421" /> PrePostProcessing is not part of OVC, so we have to insert additional Transformation calls inside PrePostProcessing. ### Tickets: - *CVS-160786* - CVS-161724 --------- Co-authored-by: Andrei Kochin <andrei.kochin@intel.com> Co-authored-by: Andrii Staikov <andrii.staikov@intel.com>
1 parent 74126de commit 5d00273

File tree

8 files changed

+270
-86
lines changed

8 files changed

+270
-86
lines changed

src/bindings/python/tests/test_graph/test_preprocess.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def test_graph_preprocess_scale_vector():
7272
assert list(model.get_output_shape(0)) == [2, 2]
7373
assert model.get_output_element_type(0) == Type.f32
7474
assert "Constant" in model_operators
75-
assert "Divide" in model_operators
75+
# Div will be converted to Mul in the transformations
76+
assert "Multiply" in model_operators
7677

7778

7879
def test_graph_preprocess_mean_scale_convert():
@@ -95,12 +96,13 @@ def custom_preprocess(output: Output):
9596
model = ppp.build()
9697

9798
model_operators = [op.get_name().split("_")[0] for op in model.get_ops()]
99+
# Div will be converted to Mul in the transformations
98100
expected_ops = [
99101
"Parameter",
100102
"Convert",
101103
"Constant",
102104
"Subtract",
103-
"Divide",
105+
"Multiply",
104106
"Result",
105107
"Abs",
106108
]
@@ -137,12 +139,13 @@ def custom_preprocess(output: Output):
137139
model = ppp.build()
138140

139141
model_operators = [op.get_name().split("_")[0] for op in model.get_ops()]
142+
# Div will be converted to Mul in the transformations
140143
expected_ops = [
141144
"Parameter",
142145
"Convert",
143146
"Constant",
144147
"Subtract",
145-
"Divide",
148+
"Multiply",
146149
"Result",
147150
"Abs",
148151
]
@@ -404,7 +407,7 @@ def test_graph_preprocess_steps(algorithm, color_format1, color_format2, is_fail
404407
"Gather",
405408
"Interpolate",
406409
]
407-
assert len(model_operators) == 15
410+
assert len(model_operators) == 12
408411
assert model.get_output_size() == 1
409412
assert list(model.get_output_shape(0)) == [1, 3, 3, 3]
410413
assert model.get_output_element_type(0) == Type.f32
@@ -456,10 +459,9 @@ def test_graph_preprocess_postprocess_layout():
456459
"Constant",
457460
"Result",
458461
"Gather",
459-
"Range",
460462
"Transpose",
461463
]
462-
assert len(model_operators) == 14
464+
assert len(model_operators) == 11
463465
assert model.get_output_size() == 1
464466
assert list(model.get_output_shape(0)) == [1, 1, 3, 3]
465467
assert model.get_output_element_type(0) == Type.f32
@@ -486,9 +488,8 @@ def test_graph_preprocess_reverse_channels():
486488
"Constant",
487489
"Result",
488490
"Gather",
489-
"Range",
490491
]
491-
assert len(model_operators) == 10
492+
assert len(model_operators) == 7
492493
assert model.get_output_size() == 1
493494
assert list(model.get_output_shape(0)) == [1, 2, 2, 2]
494495
assert model.get_output_element_type(0) == Type.f32
@@ -628,6 +629,7 @@ def custom_preprocess(output: Output):
628629
model = ppp.build()
629630

630631
model_operators = [op.get_name().split("_")[0] for op in model.get_ops()]
632+
# Div will be converted to Mul in the transformations
631633
expected_ops = [
632634
"Parameter",
633635
"Constant",
@@ -636,7 +638,7 @@ def custom_preprocess(output: Output):
636638
"Convert",
637639
"Abs",
638640
"Add",
639-
"Divide",
641+
"Multiply",
640642
]
641643
assert len(model_operators) == 13
642644
assert model.get_output_size() == 1

src/common/transformations/include/transformations/rt_info/dequantization_node.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ TRANSFORMATIONS_API void mark_as_dequantization_node(const std::shared_ptr<Node>
1414

1515
TRANSFORMATIONS_API bool is_dequantization_node(const std::shared_ptr<const Node>& node);
1616

17+
TRANSFORMATIONS_API void unmark_dequantization_node(const std::shared_ptr<Node>& node);
18+
1719
/**
1820
* @ingroup ov_runtime_attr_api
1921
* @brief DequantizationNode class represents runtime info attribute that marks operation

src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,15 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
131131
using namespace ov::pass;
132132
REGISTER_PASS(manager, InitNodeInfo)
133133
if (m_low_precision_enabled) {
134-
manager.register_pass<ov::pass::MarkDequantization>(
135-
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
134+
manager.register_pass<ov::pass::MarkDequantization>(element::TypeVector{ov::element::i8,
135+
ov::element::u8,
136+
ov::element::i4,
137+
ov::element::u4,
138+
ov::element::nf4,
139+
ov::element::f4e2m1,
140+
ov::element::f8e4m3,
141+
ov::element::f8e5m2,
142+
ov::element::f8e8m0});
136143
}
137144
if (!m_use_shapes) {
138145
manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();

src/common/transformations/src/transformations/rt_info/dequantization_node.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ void ov::mark_as_dequantization_node(const std::shared_ptr<Node>& node) {
99
rt_info[DequantizationNode::get_type_info_static()] = DequantizationNode();
1010
}
1111

12+
void ov::unmark_dequantization_node(const std::shared_ptr<Node>& node) {
13+
node->get_rt_info().erase(DequantizationNode::get_type_info_static());
14+
}
15+
1216
bool ov::is_dequantization_node(const std::shared_ptr<const Node>& node) {
1317
const auto& rt_info = node->get_rt_info();
1418
return rt_info.find(DequantizationNode::get_type_info_static()) != rt_info.end();

0 commit comments

Comments (0)