
Commit 1b6af84

[Experimental][TorchFX] OpenVINOQuantizer (#3203)
### Changes

* torch.ao `OpenVINOQuantizer` as well as `OpenVINOQuantizerAdapter` are introduced
* The `quantize_pt2e` function is updated to work with `OpenVINOQuantizer`

### Reason for changes

* To enable OpenVINO quantization for torch.ao quantization pipelines (`torch.ao.quantization.prepare_pt2e`, `torch.ao.quantization.convert_pt2e`) and the `quantize_pt2e` API function

### Related tickets

#2766

### Tests

`tests/torch/fx/test_quantizer.py` is updated with the following use cases (sketched below):

- `OpenVINOQuantizer` + `quantize_pt2e`
- `OpenVINOQuantizer` + `torch.ao.quantization.prepare_pt2e` -> `torch.ao.quantization.convert_pt2e`
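A minimal sketch of the first flow, `OpenVINOQuantizer` + `quantize_pt2e`. The model and calibration data are illustrative, not taken from this commit; the import paths match the files added here:

```python
import torch

import nncf
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e

# Illustrative model: any torch.export-able module works.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_input = torch.ones(1, 3, 32, 32)
exported_model = torch.export.export(model, (example_input,)).module()

quantizer = OpenVINOQuantizer()
calibration_dataset = nncf.Dataset([example_input])
quantized_model = quantize_pt2e(exported_model, quantizer, calibration_dataset)
```

The second flow, via the stock torch.ao PT2E entry points, is sketched after the `quantize_pt2e.py` diff below.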
1 parent 281149b commit 1b6af84

File tree

19 files changed: +18728 −63 lines
nncf/experimental/quantization/quantizers/openvino_adapter.py (new file)

+32

```python
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.fx

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.quantization.quantizers.quantizer import Quantizer


class OpenVINOQuantizerAdapter(Quantizer):
    """
    Implementation of the NNCF Quantizer interface for the OpenVINOQuantizer.
    """

    def __init__(self, quantizer: OpenVINOQuantizer):
        self._quantizer = quantizer

    def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        return self._quantizer.transform_for_annotation(model)

    def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        return self._quantizer.get_quantization_setup(model, nncf_graph)
```
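The adapter simply delegates to the wrapped quantizer, mapping the torch.ao `transform_for_annotation`/`get_quantization_setup` pair onto the NNCF `Quantizer` interface driven by the post-training pipeline. A minimal sketch of that call order, with an illustrative toy model:

```python
import torch

from nncf.experimental.quantization.quantizers.openvino_adapter import OpenVINOQuantizerAdapter
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

# Illustrative toy model, captured the same way quantize_pt2e expects its input.
toy_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
captured_model = torch.export.export(toy_model, (torch.ones(1, 3, 32, 32),)).module()

adapter = OpenVINOQuantizerAdapter(OpenVINOQuantizer())
model = adapter.transform_prior_quantization(captured_model)  # folds constants, keeps Q/DQ nodes
nncf_graph = GraphConverter.create_nncf_graph(model)
setup = adapter.get_quantization_setup(model, nncf_graph)  # SingleConfigQuantizerSetup
```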
nncf/experimental/quantization/quantizers/openvino_quantizer.py (new file)

+308

```python
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import torch.fx
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec

import nncf
from nncf.common.graph.graph import NNCFGraph
from nncf.common.logging import nncf_logger
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.quantization.structs import QuantizationScheme
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.transformations import fold_constant_except_qdq
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import FP8QuantizationParameters
from nncf.quantization.advanced_parameters import OverflowFix
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.scopes import IgnoredScope
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids

QUANT_ANNOTATION_KEY = "quantization_annotation"


class OpenVINOQuantizer(TorchAOQuantizer):
    """
    Implementation of the Torch AO quantizer which annotates models with quantization annotations
    optimal for inference via OpenVINO.
    """

    def __init__(
        self,
        *,
        mode: Optional[QuantizationMode] = None,
        preset: Optional[QuantizationPreset] = None,
        target_device: TargetDevice = TargetDevice.ANY,
        model_type: Optional[ModelType] = None,
        ignored_scope: Optional[IgnoredScope] = None,
        overflow_fix: Optional[OverflowFix] = None,
        quantize_outputs: bool = False,
        activations_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None,
        weights_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None,
        quantizer_propagation_rule: QuantizerPropagationRule = QuantizerPropagationRule.MERGE_ALL_IN_ONE,
    ):
        """
        :param mode: Defines optimization mode for the algorithm. None by default.
        :param preset: A preset controls the quantization mode (symmetric and asymmetric).
            It can take the following values:
            - `performance`: Symmetric quantization of weights and activations.
            - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations.
            Default value is None. In this case, the `mixed` preset is used for the `transformer`
            model type, otherwise `performance`.
        :param target_device: A target device, the specifics of which will be taken
            into account while compressing in order to obtain the best performance
            for this type of device, defaults to TargetDevice.ANY.
        :param model_type: Model type is needed to specify additional patterns
            in the model. Only `transformer` is supported now.
        :param ignored_scope: An ignored scope that defines the list of model control
            flow graph nodes to be ignored during quantization.
        :param overflow_fix: This option controls whether to apply the overflow issue
            fix for the 8-bit quantization.
        :param quantize_outputs: Whether to insert additional quantizers right before
            each of the model outputs.
        :param activations_quantization_params: Quantization parameters for model
            activations.
        :param weights_quantization_params: Quantization parameters for model weights.
        :param quantizer_propagation_rule: The strategy to be used while propagating and merging quantizers.
            MERGE_ALL_IN_ONE by default.
        """
        self._min_max_algo = MinMaxQuantization(
            mode=mode,
            preset=preset,
            target_device=target_device,
            model_type=model_type,
            ignored_scope=ignored_scope,
            overflow_fix=overflow_fix,
            quantize_outputs=quantize_outputs,
            activations_quantization_params=activations_quantization_params,
            weights_quantization_params=weights_quantization_params,
            quantizer_propagation_rule=quantizer_propagation_rule,
        )

    def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        self._min_max_algo._set_backend_entity(model)
        return self._min_max_algo.find_quantization_setup(model, nncf_graph)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        nncf_graph = GraphConverter.create_nncf_graph(model)
        quantization_setup = self.get_quantization_setup(model, nncf_graph)

        graph = model.graph
        node_vs_torch_annotation = defaultdict(TorchAOQuantizationAnnotation)

        for qp in quantization_setup.quantization_points.values():
            edge_or_node, annotation = self._get_edge_or_node_and_annotation(
                graph, nncf_graph, qp, node_vs_torch_annotation
            )
            qspec = self._get_torch_ao_qspec_from_qp(qp)
            self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)

        for quantizer_ids in quantization_setup.unified_scale_groups.values():
            root_quantizer_id = self._get_unified_scales_root_quantizer_id(
                nncf_graph, quantizer_ids, quantization_setup
            )
            root_qp = quantization_setup.quantization_points[root_quantizer_id]

            if any(root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig for q_id in quantizer_ids):
                qps = [quantization_setup.quantization_points[q_id] for q_id in quantizer_ids]
                raise nncf.InternalError(
                    "Different quantization configs are set to one unified scale group:"
                    f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}"
                )

            root_target_node = get_graph_node_by_name(graph, root_qp.insertion_point.target_node_name)
            root_edge_or_node = self._get_edge_or_node(root_target_node, root_qp, nncf_graph)

            for quantizer_id in quantizer_ids:
                if quantizer_id == root_quantizer_id:
                    continue

                qspec = TorchAOSharedQuantizationSpec(root_edge_or_node)
                qp = quantization_setup.quantization_points[quantizer_id]
                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
                    graph, nncf_graph, qp, node_vs_torch_annotation
                )
                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)

        for node, annotation in node_vs_torch_annotation.items():
            assert QUANT_ANNOTATION_KEY not in node.meta
            node.meta[QUANT_ANNOTATION_KEY] = annotation

    @staticmethod
    def _get_unified_scales_root_quantizer_id(
        nncf_graph: NNCFGraph, quantizer_ids: List[int], quantizer_setup: SingleConfigQuantizerSetup
    ) -> int:
        """
        Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id`
        in the given NNCFGraph. This is required by the `_get_obs_or_fq_map` function.
        Refer to: https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/prepare.py#L291

        :param nncf_graph: The NNCFGraph instance.
        :param quantizer_ids: The list of quantizer IDs to evaluate.
        :param quantizer_setup: The instance of SingleConfigQuantizerSetup.
        :return: The ID of the earliest quantizer node in terms of `nncf_node.node_id`.
        """
        nncf_node_quantizer_id = None
        root_quantizer_id = None
        for quantizer_id in quantizer_ids:
            target_node_name = quantizer_setup.quantization_points[quantizer_id].insertion_point.target_node_name
            nncf_node = nncf_graph.get_node_by_name(target_node_name)
            if nncf_node_quantizer_id is None or nncf_node.node_id < nncf_node_quantizer_id:
                root_quantizer_id = quantizer_id
                nncf_node_quantizer_id = nncf_node.node_id
        return root_quantizer_id

    @staticmethod
    def _get_edge_or_node_and_annotation(
        graph: torch.fx.Graph,
        nncf_graph: NNCFGraph,
        qp: QuantizationPointBase,
        node_vs_torch_annotation: Dict[torch.fx.Node, TorchAOQuantizationAnnotation],
    ) -> Tuple[EdgeOrNode, TorchAOQuantizationAnnotation]:
        """
        Retrieves the edge or node and its corresponding TorchAOQuantizationAnnotation based on the given graph,
        quantization point, and node-to-annotation mapping.

        :param graph: torch.fx.Graph instance.
        :param nncf_graph: NNCFGraph instance.
        :param qp: QuantizationPointBase instance.
        :param node_vs_torch_annotation: A dictionary mapping torch.fx.GraphNode objects to their respective
            TorchAOQuantizationAnnotations.
        :return: A tuple containing the EdgeOrNode and its associated TorchAOQuantizationAnnotation.
        """
        target_node = get_graph_node_by_name(graph, qp.insertion_point.target_node_name)
        annotation = node_vs_torch_annotation[target_node]
        edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph)
        return edge_or_node, annotation

    @staticmethod
    def _get_edge_or_node(target_node: torch.fx.Node, qp: QuantizationPointBase, nncf_graph: NNCFGraph) -> EdgeOrNode:
        """
        Returns the edge or node based on the given target node and quantization point.

        :param target_node: Target node instance.
        :param qp: QuantizationPointBase instance.
        :param nncf_graph: NNCFGraph instance.
        :return: The corresponding EdgeOrNode derived from the target node and quantization point.
        """
        ip = qp.insertion_point
        if qp.is_weight_quantization_point():
            nncf_node = nncf_graph.get_node_by_name(target_node.name)
            weights_ports_ids = get_weight_tensor_port_ids(nncf_node, nncf_graph)
            if len(weights_ports_ids) > 1:
                # TODO(dlyakhov): support quantization for nodes with several weights
                nncf_logger.warning(
                    f"Quantization of the weighted node {target_node.name}"
                    " is not yet supported by the OpenVINOQuantizer."
                    f" Only the weight on port ID {weights_ports_ids[0]} will be quantized."
                    f" Quantizable weights are located on ports: {weights_ports_ids}."
                )
            weight_node = target_node.all_input_nodes[weights_ports_ids[0]]
            return (weight_node, target_node)

        if ip.input_port_id is None:
            return target_node

        node = target_node.all_input_nodes[ip.input_port_id]
        return (node, target_node)

    @staticmethod
    def _fill_torch_ao_annotation(
        edge_or_node: EdgeOrNode,
        qspec: TorchAOQuantizationSpecBase,
        annotation_to_update: TorchAOQuantizationAnnotation,
    ) -> None:
        """
        Helper method to update the annotation_to_update based on the specified edge_or_node and qspec.

        :param edge_or_node: The target EdgeOrNode to be used for the update.
        :param qspec: An instance of TorchAOQuantizationSpecBase representing the quantization specification to apply.
        :param annotation_to_update: The annotation to update based on the edge_or_node and qspec.
        """
        if isinstance(edge_or_node, torch.fx.Node):
            annotation_to_update.output_qspec = qspec
        else:
            annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec

    @staticmethod
    def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> TorchAOQuantizationSpec:
        """
        Retrieves the quantization configuration from the given quantization point and
        converts it into a TorchAOQuantizationSpec.

        :param qp: An instance of QuantizationPointBase.
        :return: A TorchAOQuantizationSpec retrieved and converted from the quantization point.
        """
        # Eps value is copied from nncf/torch/quantization/layers.py
        extra_args = {"eps": 1e-16}
        qconfig = qp.qconfig
        is_weight = qp.is_weight_quantization_point()

        if qconfig.per_channel:
            torch_qscheme = (
                torch.per_channel_symmetric
                if qconfig.mode is QuantizationScheme.SYMMETRIC
                else torch.per_channel_affine
            )
        else:
            torch_qscheme = (
                torch.per_tensor_symmetric if qconfig.mode is QuantizationScheme.SYMMETRIC else torch.per_tensor_affine
            )
        if is_weight:
            observer = PerChannelMinMaxObserver
            quant_min = -128
            quant_max = 127
            dtype = torch.int8
            channel_axis = 0
        else:
            observer = (
                HistogramObserver
                if torch_qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
                else PerChannelMinMaxObserver
            )
            quant_min = 0
            quant_max = 255
            dtype = torch.int8 if qconfig.signedness_to_force else torch.uint8
            channel_axis = 1  # channel dim for activations
        return TorchAOQuantizationSpec(
            dtype=dtype,
            observer_or_fake_quant_ctr=observer.with_args(**extra_args),
            quant_min=quant_min,
            quant_max=quant_max,
            qscheme=torch_qscheme,
            ch_axis=channel_axis,
            is_dynamic=False,
        )

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fold_constant_except_qdq(model)
        return model
```
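Construction mirrors the familiar `nncf.quantize` parameters. For example, a configuration one might use for a transformer model; the parameter values and the ignored layer name are illustrative, not taken from this commit:

```python
from nncf import IgnoredScope
from nncf import ModelType
from nncf import QuantizationPreset
from nncf import TargetDevice
from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer

quantizer = OpenVINOQuantizer(
    preset=QuantizationPreset.MIXED,   # symmetric weights, asymmetric activations
    model_type=ModelType.TRANSFORMER,  # enables transformer-specific ignored patterns
    target_device=TargetDevice.CPU,
    ignored_scope=IgnoredScope(names=["lm_head"]),  # hypothetical layer kept in floating point
)
```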

nncf/experimental/torch/fx/quantization/quantize_pt2e.py

+12 −2
```diff
@@ -27,9 +27,12 @@
 from nncf.common.logging import nncf_logger
 from nncf.data import Dataset
 from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization
+from nncf.experimental.quantization.quantizers.openvino_adapter import OpenVINOQuantizerAdapter
+from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer
 from nncf.experimental.quantization.quantizers.torch_ao_adapter import TorchAOQuantizerAdapter
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
+from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
 from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
 from nncf.quantization.advanced_parameters import RangeEstimatorParameters
@@ -90,7 +93,11 @@ def quantize_pt2e(
     model = deepcopy(model)

     _fuse_conv_bn_(model)
-    quantizer = TorchAOQuantizerAdapter(quantizer)
+    if isinstance(quantizer, OpenVINOQuantizer):
+        quantizer = OpenVINOQuantizerAdapter(quantizer)
+    else:
+        quantizer = TorchAOQuantizerAdapter(quantizer)
+
     # Call transform_prior_quantization before the NNCFGraph creation
     transformed_model = quantizer.transform_prior_quantization(model)

@@ -114,7 +121,10 @@ def quantize_pt2e(
     quantized_model = GraphModule(quantized_model, quantized_model.graph)

     if fold_quantize:
-        constant_fold(quantized_model, _quant_node_constraint)
+        if isinstance(quantizer, OpenVINOQuantizerAdapter):
+            compress_post_quantize_transformation(quantized_model)
+        else:
+            constant_fold(quantized_model, _quant_node_constraint)

     pm = PassManager([DuplicateDQPass()])
```
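Since `OpenVINOQuantizer` subclasses the torch.ao `Quantizer`, it also plugs into the stock PT2E pipeline, which is the second use case covered by the updated tests. A minimal sketch with an illustrative model and a one-batch calibration pass:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e
from torch.ao.quantization.quantize_pt2e import prepare_pt2e

from nncf.experimental.quantization.quantizers.openvino_quantizer import OpenVINOQuantizer

# Illustrative model and input.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_input = torch.ones(1, 3, 32, 32)
exported_model = torch.export.export(model, (example_input,)).module()

quantizer = OpenVINOQuantizer()
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_model(example_input)  # calibration: observers collect statistics
quantized_model = convert_pt2e(prepared_model)
```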