|
| 1 | +# Copyright (c) 2025 Intel Corporation |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from collections import defaultdict |
| 13 | +from typing import Dict, List, Optional, Tuple, Union |
| 14 | + |
| 15 | +import torch.fx |
| 16 | +from torch.ao.quantization.observer import HistogramObserver |
| 17 | +from torch.ao.quantization.observer import PerChannelMinMaxObserver |
| 18 | +from torch.ao.quantization.quantizer.quantizer import EdgeOrNode |
| 19 | +from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation |
| 20 | +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec |
| 21 | +from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase |
| 22 | +from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer |
| 23 | +from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec |
| 24 | + |
| 25 | +import nncf |
| 26 | +from nncf.common.graph.graph import NNCFGraph |
| 27 | +from nncf.common.logging import nncf_logger |
| 28 | +from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule |
| 29 | +from nncf.common.quantization.quantizer_setup import QuantizationPointBase |
| 30 | +from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup |
| 31 | +from nncf.common.quantization.structs import QuantizationPreset |
| 32 | +from nncf.common.quantization.structs import QuantizationScheme |
| 33 | +from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter |
| 34 | +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name |
| 35 | +from nncf.experimental.torch.fx.transformations import fold_constant_except_qdq |
| 36 | +from nncf.parameters import ModelType |
| 37 | +from nncf.parameters import QuantizationMode |
| 38 | +from nncf.parameters import TargetDevice |
| 39 | +from nncf.quantization.advanced_parameters import FP8QuantizationParameters |
| 40 | +from nncf.quantization.advanced_parameters import OverflowFix |
| 41 | +from nncf.quantization.advanced_parameters import QuantizationParameters |
| 42 | +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization |
| 43 | +from nncf.scopes import IgnoredScope |
| 44 | +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids |
| 45 | + |
| 46 | +QUANT_ANNOTATION_KEY = "quantization_annotation" |
| 47 | + |
| 48 | + |
| 49 | +class OpenVINOQuantizer(TorchAOQuantizer): |
| 50 | + """ |
| 51 | + Implementation of the Torch AO quantizer which annotates models with quantization annotations |
| 52 | + optimally for the inference via OpenVINO. |
| 53 | + """ |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + *, |
| 58 | + mode: Optional[QuantizationMode] = None, |
| 59 | + preset: Optional[QuantizationPreset] = None, |
| 60 | + target_device: TargetDevice = TargetDevice.ANY, |
| 61 | + model_type: Optional[ModelType] = None, |
| 62 | + ignored_scope: Optional[IgnoredScope] = None, |
| 63 | + overflow_fix: Optional[OverflowFix] = None, |
| 64 | + quantize_outputs: bool = False, |
| 65 | + activations_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None, |
| 66 | + weights_quantization_params: Optional[Union[QuantizationParameters, FP8QuantizationParameters]] = None, |
| 67 | + quantizer_propagation_rule: QuantizerPropagationRule = QuantizerPropagationRule.MERGE_ALL_IN_ONE, |
| 68 | + ): |
| 69 | + """ |
| 70 | + :param mode: Defines optimization mode for the algorithm. None by default. |
| 71 | + :param preset: A preset controls the quantization mode (symmetric and asymmetric). |
| 72 | + It can take the following values: |
| 73 | + - `performance`: Symmetric quantization of weights and activations. |
| 74 | + - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations. |
| 75 | + Default value is None. In this case, `mixed` preset is used for `transformer` |
| 76 | + model type otherwise `performance`. |
| 77 | + :param target_device: A target device the specificity of which will be taken |
| 78 | + into account while compressing in order to obtain the best performance |
| 79 | + for this type of device, defaults to TargetDevice.ANY. |
| 80 | + :param model_type: Model type is needed to specify additional patterns |
| 81 | + in the model. Supported only `transformer` now. |
| 82 | + :param ignored_scope: An ignored scope that defined the list of model control |
| 83 | + flow graph nodes to be ignored during quantization. |
| 84 | + :param overflow_fix: This option controls whether to apply the overflow issue |
| 85 | + fix for the 8-bit quantization. |
| 86 | + :param quantize_outputs: Whether to insert additional quantizers right before |
| 87 | + each of the model outputs. |
| 88 | + :param activations_quantization_params: Quantization parameters for model |
| 89 | + activations. |
| 90 | + :param weights_quantization_params: Quantization parameters for model weights. |
| 91 | + :param quantizer_propagation_rule: The strategy to be used while propagating and merging quantizers. |
| 92 | + MERGE_ALL_IN_ONE by default. |
| 93 | + """ |
| 94 | + self._min_max_algo = MinMaxQuantization( |
| 95 | + mode=mode, |
| 96 | + preset=preset, |
| 97 | + target_device=target_device, |
| 98 | + model_type=model_type, |
| 99 | + ignored_scope=ignored_scope, |
| 100 | + overflow_fix=overflow_fix, |
| 101 | + quantize_outputs=quantize_outputs, |
| 102 | + activations_quantization_params=activations_quantization_params, |
| 103 | + weights_quantization_params=weights_quantization_params, |
| 104 | + quantizer_propagation_rule=quantizer_propagation_rule, |
| 105 | + ) |
| 106 | + |
| 107 | + def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup: |
| 108 | + self._min_max_algo._set_backend_entity(model) |
| 109 | + return self._min_max_algo.find_quantization_setup(model, nncf_graph) |
| 110 | + |
| 111 | + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 112 | + nncf_graph = GraphConverter.create_nncf_graph(model) |
| 113 | + quantization_setup = self.get_quantization_setup(model, nncf_graph) |
| 114 | + |
| 115 | + graph = model.graph |
| 116 | + node_vs_torch_annotation = defaultdict(TorchAOQuantizationAnnotation) |
| 117 | + |
| 118 | + for qp in quantization_setup.quantization_points.values(): |
| 119 | + edge_or_node, annotation = self._get_edge_or_node_and_annotation( |
| 120 | + graph, nncf_graph, qp, node_vs_torch_annotation |
| 121 | + ) |
| 122 | + qspec = self._get_torch_ao_qspec_from_qp(qp) |
| 123 | + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) |
| 124 | + |
| 125 | + for quantizer_ids in quantization_setup.unified_scale_groups.values(): |
| 126 | + |
| 127 | + root_quantizer_id = self._get_unified_scales_root_quantizer_id( |
| 128 | + nncf_graph, quantizer_ids, quantization_setup |
| 129 | + ) |
| 130 | + root_qp = quantization_setup.quantization_points[root_quantizer_id] |
| 131 | + |
| 132 | + if any(root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig for q_id in quantizer_ids): |
| 133 | + qps = [quantization_setup.quantization_points[q_id] for q_id in quantizer_ids] |
| 134 | + raise nncf.InternalError( |
| 135 | + "Different quantization configs are set to one unified scale group:" |
| 136 | + f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" |
| 137 | + ) |
| 138 | + |
| 139 | + root_target_node = get_graph_node_by_name(graph, root_qp.insertion_point.target_node_name) |
| 140 | + root_edge_or_node = self._get_edge_or_node(root_target_node, root_qp, nncf_graph) |
| 141 | + |
| 142 | + for quantizer_id in quantizer_ids: |
| 143 | + if quantizer_id == root_quantizer_id: |
| 144 | + continue |
| 145 | + |
| 146 | + qspec = TorchAOSharedQuantizationSpec(root_edge_or_node) |
| 147 | + qp = quantization_setup.quantization_points[quantizer_id] |
| 148 | + edge_or_node, annotation = self._get_edge_or_node_and_annotation( |
| 149 | + graph, nncf_graph, qp, node_vs_torch_annotation |
| 150 | + ) |
| 151 | + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) |
| 152 | + |
| 153 | + for node, annotation in node_vs_torch_annotation.items(): |
| 154 | + assert QUANT_ANNOTATION_KEY not in node.meta |
| 155 | + node.meta[QUANT_ANNOTATION_KEY] = annotation |
| 156 | + |
| 157 | + @staticmethod |
| 158 | + def _get_unified_scales_root_quantizer_id( |
| 159 | + nncf_graph: NNCFGraph, quantizer_ids: List[int], quantizer_setup: SingleConfigQuantizerSetup |
| 160 | + ) -> int: |
| 161 | + """ |
| 162 | + Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id` |
| 163 | + in the given NNCFGraph. This is required by the `_get_obs_or_fq_map` function. |
| 164 | + Refer to: https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/prepare.py#L291 |
| 165 | +
|
| 166 | + :param nncf_graph: The NNCFGraph instance. |
| 167 | + :param quantizer_ids: The list of quantizer IDs to evaluate. |
| 168 | + :param quantizer_setup: The instance of SingleConfigQuantizerSetup. |
| 169 | + :return: The ID of the earliest quantizer node in terms of `nncf_node.node_id`. |
| 170 | + """ |
| 171 | + nncf_node_quantizer_id = None |
| 172 | + root_quantizer_id = None |
| 173 | + for quantizer_id in quantizer_ids: |
| 174 | + target_node_name = quantizer_setup.quantization_points[quantizer_id].insertion_point.target_node_name |
| 175 | + nncf_node = nncf_graph.get_node_by_name(target_node_name) |
| 176 | + if nncf_node_quantizer_id is None or nncf_node.node_id < nncf_node_quantizer_id: |
| 177 | + root_quantizer_id = quantizer_id |
| 178 | + nncf_node_quantizer_id = nncf_node.node_id |
| 179 | + return root_quantizer_id |
| 180 | + |
| 181 | + @staticmethod |
| 182 | + def _get_edge_or_node_and_annotation( |
| 183 | + graph: torch.fx.Graph, |
| 184 | + nncf_graph: NNCFGraph, |
| 185 | + qp: QuantizationPointBase, |
| 186 | + node_vs_torch_annotation: Dict[torch.fx.Node, TorchAOQuantizationAnnotation], |
| 187 | + ) -> Tuple[EdgeOrNode, TorchAOQuantizationAnnotation]: |
| 188 | + """ |
| 189 | + Retrieves the edge or node and its corresponding TorchAOQuantizationAnnotation based on the given graph, |
| 190 | + quantization point, and node-to-annotation mapping. |
| 191 | +
|
| 192 | + :param graph: torch.fx.Graph instance. |
| 193 | + :param nncf_graph: NNCFGraph instance. |
| 194 | + :param qp: QuantizationPointBase instance. |
| 195 | + :param node_vs_torch_annotation: A dictionary mapping torch.fx.GraphNode objects to their respective |
| 196 | + TorchAOQuantizationAnnotations. |
| 197 | + :return: A tuple containing the EdgeOrNode and its associated TorchAOQuantizationAnnotation. |
| 198 | + """ |
| 199 | + target_node = get_graph_node_by_name(graph, qp.insertion_point.target_node_name) |
| 200 | + annotation = node_vs_torch_annotation[target_node] |
| 201 | + edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph) |
| 202 | + return edge_or_node, annotation |
| 203 | + |
| 204 | + @staticmethod |
| 205 | + def _get_edge_or_node(target_node: torch.fx.Node, qp: QuantizationPointBase, nncf_graph: NNCFGraph) -> EdgeOrNode: |
| 206 | + """ |
| 207 | + Returns the edge or node based on the given target node and quantization point. |
| 208 | +
|
| 209 | + :param target_node: Target node instance. |
| 210 | + :param qp: QuantizationPointBase instance. |
| 211 | + :param graph: NNCFGraph instance. |
| 212 | + :return: The corresponding EdgeOrNode derived from the target node and quantization point. |
| 213 | + """ |
| 214 | + ip = qp.insertion_point |
| 215 | + if qp.is_weight_quantization_point(): |
| 216 | + nncf_node = nncf_graph.get_node_by_name(target_node.name) |
| 217 | + weights_ports_ids = get_weight_tensor_port_ids(nncf_node, nncf_graph) |
| 218 | + if len(weights_ports_ids) > 1: |
| 219 | + # TODO(dlyakhov): support quantization for nodes with several weights |
| 220 | + nncf_logger.warning( |
| 221 | + f"Quantization of the weighted node {target_node.name}" |
| 222 | + " is not yet supported by the OpenVINOQuantizer." |
| 223 | + f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." |
| 224 | + f" Quantizable weights are located on ports: {weights_ports_ids}." |
| 225 | + ) |
| 226 | + weight_node = target_node.all_input_nodes[weights_ports_ids[0]] |
| 227 | + return (weight_node, target_node) |
| 228 | + |
| 229 | + if ip.input_port_id is None: |
| 230 | + return target_node |
| 231 | + |
| 232 | + node = target_node.all_input_nodes[ip.input_port_id] |
| 233 | + return (node, target_node) |
| 234 | + |
| 235 | + @staticmethod |
| 236 | + def _fill_torch_ao_annotation( |
| 237 | + edge_or_node: EdgeOrNode, |
| 238 | + qspec: TorchAOQuantizationSpecBase, |
| 239 | + annotation_to_update: TorchAOQuantizationAnnotation, |
| 240 | + ) -> None: |
| 241 | + """ |
| 242 | + Helper method to update the annotation_to_update based on the specified edge_or_node and qspec. |
| 243 | +
|
| 244 | + :param edge_or_node: The target EdgeOrNode to be used for the update. |
| 245 | + :param qspec: An instance of TorchAOQuantizationSpecBase representing the quantization specification to apply. |
| 246 | + :param annotation_to_update: The annotation to update based on the edge_or_node and qspec. |
| 247 | + """ |
| 248 | + if isinstance(edge_or_node, torch.fx.Node): |
| 249 | + annotation_to_update.output_qspec = qspec |
| 250 | + else: |
| 251 | + annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec |
| 252 | + |
| 253 | + @staticmethod |
| 254 | + def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> TorchAOQuantizationSpec: |
| 255 | + """ |
| 256 | + Retrieves the quantization configuration from the given quantization point and |
| 257 | + converts it into a TorchAOQuantizationSpec. |
| 258 | +
|
| 259 | + :param qp: An instance of QuantizationPointBase. |
| 260 | + :return: A TorchAOQuantizationSpec retrieved and converted from the quantization point. |
| 261 | + """ |
| 262 | + # Eps value is copied from nncf/torch/quantization/layers.py |
| 263 | + extra_args = {"eps": 1e-16} |
| 264 | + qconfig = qp.qconfig |
| 265 | + is_weight = qp.is_weight_quantization_point() |
| 266 | + |
| 267 | + if qconfig.per_channel: |
| 268 | + torch_qscheme = ( |
| 269 | + torch.per_channel_symmetric |
| 270 | + if qconfig.mode is QuantizationScheme.SYMMETRIC |
| 271 | + else torch.per_channel_affine |
| 272 | + ) |
| 273 | + else: |
| 274 | + torch_qscheme = ( |
| 275 | + torch.per_tensor_symmetric if qconfig.mode is QuantizationScheme.SYMMETRIC else torch.per_tensor_affine |
| 276 | + ) |
| 277 | + if is_weight: |
| 278 | + observer = PerChannelMinMaxObserver |
| 279 | + quant_min = -128 |
| 280 | + quant_max = 127 |
| 281 | + dtype = torch.int8 |
| 282 | + channel_axis = 0 |
| 283 | + else: |
| 284 | + observer = ( |
| 285 | + HistogramObserver |
| 286 | + if torch_qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] |
| 287 | + else PerChannelMinMaxObserver |
| 288 | + ) |
| 289 | + quant_min = 0 |
| 290 | + quant_max = 255 |
| 291 | + dtype = torch.int8 if qconfig.signedness_to_force else torch.uint8 |
| 292 | + channel_axis = 1 # channel dim for activations |
| 293 | + return TorchAOQuantizationSpec( |
| 294 | + dtype=dtype, |
| 295 | + observer_or_fake_quant_ctr=observer.with_args(**extra_args), |
| 296 | + quant_min=quant_min, |
| 297 | + quant_max=quant_max, |
| 298 | + qscheme=torch_qscheme, |
| 299 | + ch_axis=channel_axis, |
| 300 | + is_dynamic=False, |
| 301 | + ) |
| 302 | + |
| 303 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 304 | + pass |
| 305 | + |
| 306 | + def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 307 | + fold_constant_except_qdq(model) |
| 308 | + return model |
0 commit comments