
Commit 438871e

Dynamic OV model builder (#3137)
### Changes

- Added `ModelBuilder` class.
- Updated FastBC algorithm to utilize the new approach (see the sketch below).

### Reason for changes

- Algorithm speed-up.

### Related tickets

- 122317

### Tests

- Added `tests/openvino/native/test_model_builder.py`.
- Conversion jobs were run using the DLB scope. Results (develop runs: manual/post_training_quantization_performance/91 and 92; branch run: manual/post_training_quantization_performance/90):

Model | Backend | FBC time (develop, OV) | FBC time (develop, PT) | FBC time (branch, OV) | Diff (develop - branch, OV)
-- | -- | -- | -- | -- | --
hf/bert-base-uncased | OV | 00:00:02 | - | 00:00:01 | **00:00:01**
torchvision/resnet18 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
torchvision/mobilenet_v3_small_BC | OV | 00:00:01 | - | 00:00:01 | 00:00:00
torchvision/vit_b_16 | OV | 00:00:02 | - | 00:00:01 | **00:00:01**
torchvision/swin_v2_s | OV | 00:00:05 | - | 00:00:01 | **00:00:04**
timm/crossvit_9_240 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
timm/darknet53 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
timm/deit3_small_patch16_224 | OV | 00:00:01 | 00:00:00 | 00:00:00 | **00:00:01**
timm/dla34 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/dpn68 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/efficientnet_b0 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/efficientnet_b0_BC | OV | 00:00:04 | - | 00:00:04 | 00:00:00
timm/efficientnet_lite0 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/hrnet_w18 | OV | 00:00:13 | 00:00:03 | **00:00:04** | **00:00:09**
timm/inception_resnet_v2 | OV | 00:00:08 | 00:00:03 | **00:00:04** | **00:00:04**
timm/levit_128 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/mobilenetv2_050 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/mobilenetv2_050_BC | OV | 00:00:03 | - | 00:00:03 | 00:00:00
timm/mobilenetv3_small_050 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/mobilenetv3_small_050_BC | OV | 00:00:01 | - | 00:00:01 | 00:00:00
timm/regnetx_002 | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/resnest14d | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/swin_base_patch4_window7_224 | OV | 00:00:03 | 00:00:00 | **00:00:02** | **00:00:01**
timm/tf_inception_v3 | OV | 00:00:01 | 00:00:01 | 00:00:01 | 00:00:00
timm/vgg11 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
timm/visformer_small | OV | 00:00:00 | 00:00:00 | 00:00:00 | 00:00:00
timm/wide_resnet50_2 | OV | 00:00:01 | 00:00:00 | **00:00:01** | 00:00:00
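A minimal sketch of how the updated FastBC OpenVINO backend is wired after this change. The class and method names come from the diffs below; `model` is an assumed, already-loaded `ov.Model`, and `"Conv/WithBias"` with its port ids is a hypothetical example node:

```python
# Sketch only: `model` is an assumed ov.Model; "Conv/WithBias" is a hypothetical node name.
from nncf.quantization.algorithms.fast_bias_correction.openvino_backend import OVFastBiasCorrectionAlgoBackend

# The backend now caches a {friendly_name: node} mapping once per model in __init__ ...
backend = OVFastBiasCorrectionAlgoBackend(model)

# ... so each per-node extraction builds the small subgraph directly from the cached nodes
# instead of registering a model-extraction command and transforming (cloning) the model.
# The model_transformer argument is kept for interface parity; this override does not use it.
subgraph = backend.extract_submodel(
    model_transformer=None,
    input_id=("Conv/WithBias", 0),   # (node_name, input port) -> becomes a Parameter
    output_id=("Conv/WithBias", 0),  # (node_name, output port) -> becomes a Result
)
```

Avoiding a full-model transformation per corrected node is what the commit cites as the source of the speed-up shown in the table above.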
1 parent: 883c787 · commit: 438871e

10 files changed (+377, -26 lines)

nncf/openvino/graph/model_builder.py

+226
@@ -0,0 +1,226 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import deque
+from typing import Dict, List, Tuple
+
+import openvino.runtime as ov
+from openvino.runtime import opset13 as opset
+from openvino.runtime.utils.node_factory import NodeFactory
+
+from nncf.openvino.graph.model_transformer import OVModelTransformer
+from nncf.openvino.graph.node_utils import get_parameter_node_name
+from nncf.openvino.graph.node_utils import get_result_node_name
+
+
+class OVModelBuilder:
+    """
+    The purpose of the ModelBuilder is to build a new OpenVINO model from input and output points.
+    This builder was created to reduce the amount of model cloning that the ModelTransformer requires.
+    """
+
+    def __init__(self):
+        self._node_factory = NodeFactory()
+
+    @staticmethod
+    def _create_parameter(node_name: str, node_input: ov.Input) -> ov.Node:
+        """
+        A method that contains steps to create a Parameter for a new model using a specific template.
+        """
+        port_id = node_input.get_index()
+        parameter_name = get_parameter_node_name(node_name, port_id)
+        return opset.parameter(
+            shape=node_input.get_partial_shape(),
+            dtype=node_input.get_element_type(),
+            name=parameter_name,
+        )
+
+    @staticmethod
+    def _create_result(node_name: str, node_output: ov.Output) -> ov.Node:
+        """
+        A method that contains steps to create a Result for a new model using a specific template.
+        """
+        port_id = node_output.get_index()
+        result_name = get_result_node_name(node_name, port_id=port_id)
+        result = opset.result(node_output, name=result_name)
+        result.get_output_tensor(0).set_names({result_name})
+        return result
+
+    def _collect_graph_nodes(
+        self,
+        input_ids: List[Tuple[str, int]],
+        output_ids: List[Tuple[str, int]],
+        node_mapping: Dict[str, ov.Node],
+    ) -> List[ov.Node]:
+        """
+        A method for aggregating layers to be cloned further on.
+        Aggregation is designed in such a way that layers are listed from right to left,
+        as they pass from bottom to top. This is done in order to find all constants in the model and
+        to start graph creation from them (as well as from Parameter layers), because
+        the OpenVINO graph is created top-down and cannot be created otherwise.
+
+        Legend: w - weights, c - convert, il/ih - input low/high, ol/oh - output low/high
+                  (w)
+                   |
+                  (c)  (il)  (ih)  (ol)  (oh)
+                    \    |     |    /    /
+                    (fake quantize)    (parameter)
+                           \           /
+                           (convolution)
+                                 |
+                             (result)
+        Based on the above graph, the return value would look like this:
+        [convolution, parameter, fake quantize, oh, ol, ih, il, c, w]
+
+        :param input_ids: List of the points in the special format - (node_name, port_id).
+            This helps to point to the precise part of the model that may be used to define the subgraph inputs.
+        :param output_ids: List of the points in the special format - (node_name, port_id).
+            This helps to point to the precise part of the model that may be used to define the subgraph outputs.
+        :param node_mapping: Original nodes mapping.
+        :return: List of the ov.Nodes to clone.
+        """
+        # Creating the list as a deque for FIFO layer acquisition and retrieval
+        lookup_nodes = deque(node_mapping[n] for n, _ in output_ids)
+        graph_nodes = []
+
+        while lookup_nodes:
+            lookup_node = lookup_nodes.popleft()
+            lookup_name = lookup_node.get_friendly_name()
+            node_inputs = lookup_node.inputs()
+            graph_nodes.append(lookup_node)
+            # Reversing to look up nodes from right to left
+            for node_input in reversed(node_inputs):
+                port_id = node_input.get_index()
+                if (lookup_name, port_id) in input_ids:
+                    # We create Parameters here to avoid double creation later, since a Parameter is not
+                    # an original node, but we need to have it as an input for the next node.
+                    parameter = self._create_parameter(lookup_name, node_input)
+                    lookup_nodes.append(parameter)
+                    continue
+                parent_node = node_input.get_source_output().get_node()
+                lookup_nodes.append(parent_node)
+
+        return graph_nodes
+
+    def build(
+        self,
+        input_ids: List[Tuple[str, int]],
+        output_ids: List[Tuple[str, int]],
+        node_mapping: Dict[str, ov.Node],
+    ) -> ov.Model:
+        """
+        The basic method of the algorithm. This method uses an aggregated list of layers to be recreated.
+        Let us take a graph of this kind as an example:
+
+        Legend: w - weights, c - convert, il/ih - input low/high, ol/oh - output low/high
+                  (w)
+                   |
+                  (c)  (il)  (ih)  (ol)  (oh)
+                    \    |     |    /    /
+                    (fake quantize)    (parameter)
+                           \           /
+                           (convolution)
+                                 |
+                             (result)
+
+        The externally collected list of layers will look like this:
+        [convolution, parameter, fake quantize, oh, ol, ih, il, c, w]
+
+        Next, this list is traversed from right to left. At the same time, the list of already created layers
+        is filled from left to right and is also consumed in the traversal step from left to right,
+        in order to keep the order of the original layer inputs.
+        For example:
+
+        graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il, c, w]
+        clone_nodes = []
+
+        *creating w - weight node.*
+        graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il, c]
+        clone_nodes = [w]
+
+        *creating c - convert node.
+        Based on the .inputs() output, we'll use the already created w - weight node to fill in the convert input.
+        As a result, the weight node is removed from the clone_nodes list and the convert node is placed there.*
+        graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il]
+        clone_nodes = [c]
+
+        *creating il/ih - input low/high and ol/oh - output low/high nodes.
+        Since these nodes are constants and do not require any nodes as inputs, cloned nodes will not be used.*
+        graph_nodes = [convolution, parameter, fake quantize]
+        clone_nodes = [c, il, ih, ol, oh]
+
+        *creating fake quantize node.
+        This node requires its input values to be in a specific order.
+        All previous nodes will be connected/used for fake quantize, from left to right.*
+        graph_nodes = [convolution, parameter]
+        clone_nodes = [f]
+
+        *creating parameter node.
+        In this step, the list of parameters will also be filled out with the new node.*
+        graph_nodes = [convolution]
+        clone_nodes = [f, parameter]
+
+        *creating convolution node.
+        This node also requires its inputs to be in a specific order.
+        All previous nodes will be connected/used for convolution, from left to right. Also,
+        the outputs verification step will show here that one of the convolution outputs is in the output_ids list.
+        This means that a Result node will be created and placed into the results list.*
+        graph_nodes = []
+        clone_nodes = [convolution]
+
+        The last step is to create a subgraph model based on the parameters & results lists.
+
+        :param input_ids: List of the points in the special format - (node_name, port_id).
+            This helps to point to the precise part of the model that may be used to define the subgraph inputs.
+        :param output_ids: List of the points in the special format - (node_name, port_id).
+            This helps to point to the precise part of the model that may be used to define the subgraph outputs.
+        :param node_mapping: Original nodes mapping.
+        :return: Built ov.Model based on the parameters.
+        """
+
+        parameters, results = [], []
+        clone_nodes = deque()
+
+        # Collecting the nodes that declare the graph.
+        graph_nodes = self._collect_graph_nodes(input_ids, output_ids, node_mapping)
+
+        while graph_nodes:
+            graph_node = graph_nodes.pop()
+            node_type = graph_node.get_type_name()
+            node_name = graph_node.get_friendly_name()
+
+            # To create the new OpenVINO nodes, we need to provide all possible layer attributes.
+            attrs = graph_node.get_attributes()
+            attrs["name"] = node_name
+
+            if node_type == "Constant":
+                # Constant creation is handled separately due to its specific behavior.
+                clone_node = OVModelTransformer._create_constant(
+                    graph_node.get_data(), dtype=graph_node.get_element_type(), name=attrs["name"]
+                )
+            elif node_type == "Parameter":
+                # We've created the Parameter nodes in the previous step.
+                clone_node = graph_node
+                parameters.append(clone_node)
+            else:
+                # We pass args as the inputs since all of them are nodes and are required as inputs.
+                args = [clone_nodes.popleft() for _ in graph_node.inputs()]

+                clone_node = self._node_factory.create(node_type, args, attrs)
+
+                for node_output in clone_node.outputs():
+                    port_id = node_output.get_index()
+                    if (node_name, port_id) in output_ids:
+                        result = self._create_result(node_name, node_output)
+                        results.append(result)
+
+            clone_nodes.append(clone_node)
+
+        return ov.Model(results, parameters)
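
Below is a self-contained usage sketch of the builder (my own toy example, not part of the commit). It mirrors the structure of the Parameter -> Convolution -> Result reference graph added under the tests further down this page; the node names ("Input", "Conv", "Conv/Constant", "Result") are placeholders I chose, while the `OVModelBuilder.build()` signature is the one defined above:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

from nncf.openvino.graph.model_builder import OVModelBuilder

# A toy Parameter -> Convolution -> Result model. Shapes follow the reference .dot
# graph below: [1, 3, 4, 2] activations and [3, 3, 1, 1] weights.
inp = opset.parameter([1, 3, 4, 2], dtype=np.float32, name="Input")
weights = opset.constant(np.ones((3, 3, 1, 1), dtype=np.float32), name="Conv/Constant")
conv = opset.convolution(
    inp, weights, strides=[1, 1], pads_begin=[0, 0], pads_end=[0, 0], dilations=[1, 1], name="Conv"
)
model = ov.Model([opset.result(conv, name="Result")], [inp], "toy")

# The friendly-name -> node mapping is computed once and reused for every extraction.
node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}

# Rebuild just the convolution as a standalone subgraph: a Parameter is created at Conv's
# activation input (port 0), a Result at its output (port 0), and the weight Constant is
# recreated from its data -- the original model is never cloned.
subgraph = OVModelBuilder().build(
    input_ids=[("Conv", 0)],
    output_ids=[("Conv", 0)],
    node_mapping=node_mapping,
)
```

The resulting `subgraph` contains a `Parameter_Conv.0` input, the recreated convolution with its weight constant, and a `Result_Conv.0` output, which matches the naming pattern of the reference graph in the .dot diff near the end of this page.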

nncf/openvino/graph/node_utils.py

+7 -3

@@ -121,18 +121,22 @@ def get_const_value(const_node: ov.Node) -> np.ndarray:
     return const_node.data


-def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
+def get_bias_value(
+    node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model, node_mapping: Dict[str, ov.Node] = None
+) -> np.ndarray:
     """
     Returns the bias tensor for the biased node.

     :param node_with_bias: The node that corresponds to the operation with bias.
     :param nncf_graph: NNCFGraph instance.
     :param model: The model that contains this operation.
+    :param node_mapping: Original nodes mapping cache.
     :return: The bias value that is applied to the output tensor of the node's operation.
     """
-    ops_dict = {op.get_friendly_name(): op for op in model.get_ops()}
+    if node_mapping is None:
+        node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
     bias_constant = get_node_with_bias_value(get_add_bias_node(node_with_bias, nncf_graph), nncf_graph)
-    ov_bias_constant = ops_dict[bias_constant.node_name]
+    ov_bias_constant = node_mapping[bias_constant.node_name]
     return get_const_value(ov_bias_constant)

nncf/quantization/algorithms/fast_bias_correction/algorithm.py

+2 -20

@@ -17,7 +17,6 @@
 from nncf.common.factory import EngineFactory
 from nncf.common.factory import ModelTransformerFactory
 from nncf.common.graph.graph import NNCFGraph
-from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetPoint
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
@@ -111,7 +110,7 @@ def _set_backend_entity(self, model: TModel) -> None:
                 OVFastBiasCorrectionAlgoBackend,
             )

-            self._backend_entity = OVFastBiasCorrectionAlgoBackend()
+            self._backend_entity = OVFastBiasCorrectionAlgoBackend(model)
         elif model_backend == BackendType.TORCH:
             from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend

@@ -167,7 +166,7 @@ def apply(
             # Outputs of the subgraphs for the FastBiasCorrection are the same across the backends.
             output_id = (out_node_name, 0)

-            extracted_model = self._extract_submodel(model_transformer, input_id, output_id)
+            extracted_model = self._backend_entity.extract_submodel(model_transformer, input_id, output_id)
             if extracted_model is None:
                 nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel")
                 continue
@@ -287,23 +286,6 @@ def output_filter_func(point):
             output_fp.extend(tensor_collector.get_statistics().mean_values)
         return output_fp

-    def _extract_submodel(
-        self, model_transformer: ModelTransformer, input_id: Tuple[str, int], output_id: Tuple[str, int]
-    ) -> TModel:
-        """
-        Extracts sub-model using backend-specific ModelTransformer.
-
-        :param model_transformer: Backend-specific ModelTransformer.
-        :param input_id: Input ID.
-        :param output_id: Output ID.
-        :return: Backend-specific sub-model.
-        """
-        model_extraction_command = self._backend_entity.model_extraction_command([input_id], [output_id])
-        me_transformation_layout = TransformationLayout()
-        me_transformation_layout.register(model_extraction_command)
-        extracted_model = model_transformer.transform(me_transformation_layout)
-        return extracted_model
-
     def _add_statistic_point(self, container: StatisticPointsContainer, point: TargetPoint, axis: int) -> None:
         """
         Adds specific statistic point.

nncf/quantization/algorithms/fast_bias_correction/backend.py

+19
@@ -15,9 +15,11 @@

 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetPoint
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationCommand
+from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
 from nncf.tensor import Tensor

@@ -194,3 +196,20 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple
         :param input_shape: Shape of the input.
         :return: Channel axis number.
         """
+
+    def extract_submodel(
+        self, model_transformer: ModelTransformer, input_id: Tuple[str, int], output_id: Tuple[str, int]
+    ) -> TModel:
+        """
+        Extracts sub-model using backend-specific ModelTransformer.
+
+        :param model_transformer: Backend-specific ModelTransformer.
+        :param input_id: Input ID.
+        :param output_id: Output ID.
+        :return: Backend-specific sub-model.
+        """
+        model_extraction_command = self.model_extraction_command([input_id], [output_id])
+        me_transformation_layout = TransformationLayout()
+        me_transformation_layout.register(model_extraction_command)
+        extracted_model = model_transformer.transform(me_transformation_layout)
+        return extracted_model

nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py

+17 -3

@@ -20,6 +20,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
 from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS_REDUCED
+from nncf.openvino.graph.model_builder import OVModelBuilder
 from nncf.openvino.graph.node_utils import get_activation_channel_axis
 from nncf.openvino.graph.node_utils import get_bias_value
 from nncf.openvino.graph.node_utils import is_node_with_bias
@@ -33,6 +34,12 @@


 class OVFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend):
+
+    def __init__(self, model):
+        # Node mapping caching to reduce time for calculations
+        self._node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
+        self._model_builder = OVModelBuilder()
+
     @staticmethod
     def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
         return OVTargetPoint(target_type, target_node_name, port_id)
@@ -73,9 +80,8 @@ def create_input_data(
         input_data = {input_name: blob}
         return input_data

-    @staticmethod
-    def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor:
-        return Tensor(get_bias_value(node, nncf_graph, model))
+    def get_bias_value(self, node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor:
+        return Tensor(get_bias_value(node, nncf_graph, model, node_mapping=self._node_mapping))

     @staticmethod
     def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
@@ -113,3 +119,11 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG
     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
         return get_activation_channel_axis(node, port_id, input_shape)
+
+    def extract_submodel(self, model_transformer, input_id, output_id):
+
+        return self._model_builder.build(
+            input_ids=[input_id],
+            output_ids=[output_id],
+            node_mapping=self._node_mapping,
+        )

@@ -0,0 +1,9 @@
+strict digraph {
+"0 Parameter_Conv.0" [id=0, type=Parameter];
+"1 Convolution_57" [id=1, type=Convolution];
+"2 Result_Conv.0" [id=2, type=Result];
+"3 Conv/Constant_4" [id=3, type=Constant];
+"0 Parameter_Conv.0" -> "1 Convolution_57" [label="[1, 3, 4, 2]", style=solid];
+"1 Convolution_57" -> "2 Result_Conv.0" [label="[1, 3, 4, 2]", style=solid];
+"3 Conv/Constant_4" -> "1 Convolution_57" [label="[3, 3, 1, 1]", style=solid];
+}
