Skip to content

Commit 0d673b5

Browse files
Use both nncf_graph and inference_nncf_graph
1 parent d4daf6d commit 0d673b5

File tree

7 files changed

+36
-32
lines changed

7 files changed

+36
-32
lines changed

nncf/quantization/algorithms/min_max/backend.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,12 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
297297
"""
298298

299299
@abstractmethod
300-
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
300+
def get_weight_nodes(self, nncf_graph: NNCFGraph, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
301301
"""
302302
Returns nodes that have weights.
303303
304+
:param nncf_graph: Instance of original NNCFGraph,
305+
which contains shape of and constant subgraphs.
304306
:param inference_nncf_graph: Instance of inference NNCFGraph,
305307
which does not contain shape of and constant subgraphs.
306308
:return: All nodes with weights.

nncf/quantization/algorithms/min_max/onnx_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
217217
def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
218218
return set()
219219

220-
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
220+
def get_weight_nodes(self, nncf_graph: NNCFGraph, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
221221
return [node for node in inference_nncf_graph.get_all_nodes() if node.layer_attributes.has_weight()]
222222

223223
@staticmethod

nncf/quantization/algorithms/min_max/openvino_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
215215
ignored_names.add(node.node_name)
216216
return ignored_names
217217

218-
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
218+
def get_weight_nodes(self, nncf_graph: NNCFGraph, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
219219
return [
220220
node
221221
for node in inference_nncf_graph.get_all_nodes()

nncf/quantization/algorithms/min_max/torch_backend.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from nncf.torch.graph.graph import PTNNCFGraph
3838
from nncf.torch.graph.graph import PTTargetPoint
3939
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
40+
from nncf.torch.graph.operator_metatypes import MATMUL_METATYPES
4041
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
4142
from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command
4243
from nncf.torch.graph.transformations.commands import PTInsertionCommand
@@ -51,7 +52,8 @@
5152
from nncf.torch.quantization.layers import BaseQuantizer
5253
from nncf.torch.quantization.layers import PTQuantizerSpec
5354
from nncf.torch.quantization.layers import get_scale_shape
54-
from nncf.torch.utils import get_weight_nodes_in_inference_graph
55+
from nncf.torch.utils import get_weight_nodes
56+
from nncf.torch.utils import is_matmul_with_constant
5557

5658

5759
class PTMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -66,7 +68,7 @@ def preserved_metatypes(self) -> List[OperatorMetatype]:
6668

6769
@property
6870
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
69-
return [om.PTLinearMetatype, om.PTMatMulMetatype, om.PTAddmmMetatype]
71+
return MATMUL_METATYPES
7072

7173
@property
7274
def post_processing_metatypes(self) -> List[OperatorMetatype]:
@@ -341,8 +343,8 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
341343
def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
342344
return set()
343345

344-
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
345-
return get_weight_nodes_in_inference_graph(inference_nncf_graph, self.mat_mul_metatypes)
346+
def get_weight_nodes(self, nncf_graph: NNCFGraph, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
347+
return get_weight_nodes(nncf_graph, inference_nncf_graph)
346348

347349
def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
348-
return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0
350+
return is_matmul_with_constant(node, nncf_graph)

nncf/quantization/algorithms/min_max/torch_fx_backend.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from nncf.torch.graph.graph import PTNNCFGraph
3939
from nncf.torch.graph.graph import PTTargetPoint
4040
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
41+
from nncf.torch.graph.operator_metatypes import MATMUL_METATYPES
4142
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
4243
from nncf.torch.hardware.config import PTHWConfig
4344
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
@@ -49,7 +50,8 @@
4950
from nncf.torch.quantization.layers import PTQuantizerSpec
5051
from nncf.torch.quantization.layers import get_scale_shape
5152
from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
52-
from nncf.torch.utils import get_weight_nodes_in_inference_graph
53+
from nncf.torch.utils import get_weight_nodes
54+
from nncf.torch.utils import is_matmul_with_constant
5355

5456

5557
class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -59,7 +61,7 @@ def preserved_metatypes(self) -> List[OperatorMetatype]:
5961

6062
@property
6163
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
62-
return [om.PTLinearMetatype, om.PTMatMulMetatype]
64+
return MATMUL_METATYPES
6365

6466
@property
6567
def post_processing_metatypes(self) -> List[OperatorMetatype]:
@@ -305,8 +307,8 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
305307
def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
306308
return set()
307309

308-
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
309-
return get_weight_nodes_in_inference_graph(inference_nncf_graph, self.mat_mul_metatypes)
310+
def get_weight_nodes(self, nncf_graph: NNCFGraph, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
311+
return get_weight_nodes(nncf_graph, inference_nncf_graph)
310312

311313
def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
312-
return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0
314+
return is_matmul_with_constant(node, nncf_graph)

nncf/torch/graph/operator_metatypes.py

+2
Original file line numberDiff line numberDiff line change
@@ -1219,3 +1219,5 @@ def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
12191219
PTModuleEmbeddingBagMetatype,
12201220
PTModuleEmbeddingMetatype,
12211221
]
1222+
1223+
MATMUL_METATYPES = [PTLinearMetatype, PTMatMulMetatype, PTAddmmMetatype]

nncf/torch/utils.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212
from collections import OrderedDict
1313
from contextlib import contextmanager
14-
from typing import Any, Dict, Generator, List, Type
14+
from typing import Any, Dict, Generator, List
1515

1616
import numpy as np
1717
import torch
@@ -31,7 +31,9 @@
3131
from nncf.torch.dynamic_graph.scope import Scope
3232
from nncf.torch.dynamic_graph.scope import ScopeElement
3333
from nncf.torch.dynamic_graph.trace_tensor import TracedTensorMixin
34+
from nncf.torch.graph.operator_metatypes import MATMUL_METATYPES
3435
from nncf.torch.layer_utils import _NNCFModuleMixin
36+
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
3537
from nncf.torch.structures import ExecutionParameters
3638

3739

@@ -472,14 +474,18 @@ def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
472474
return dtype
473475

474476

475-
def get_weight_nodes_in_inference_graph(
476-
inference_nncf_graph: NNCFGraph, mat_mul_metatypes: List[Type[om.PTOperatorMetatype]]
477+
def get_weight_nodes(
478+
nncf_graph: NNCFGraph,
479+
inference_nncf_graph: NNCFGraph,
477480
) -> List[NNCFNode]:
478481
"""
479482
Returns nodes that have weights.
480483
481484
:param nncf_graph: Instance of original NNCFGraph,
485+
which contains shape of and constant subgraphs.
486+
:param inference_nncf_graph: Instance of inference NNCFGraph,
482487
which does not contain shape of and constant subgraphs.
488+
483489
:return: All nodes with weights.
484490
"""
485491
weight_nodes_candidates = [
@@ -489,28 +495,18 @@ def get_weight_nodes_in_inference_graph(
489495
]
490496
weight_nodes = []
491497
for node in weight_nodes_candidates:
492-
if node.metatype in mat_mul_metatypes and not is_matmul_with_constant_in_inference_graph(
493-
node, inference_nncf_graph
494-
):
495-
continue
496-
weight_nodes.append(node)
498+
if node.metatype not in MATMUL_METATYPES or is_matmul_with_constant(node, nncf_graph):
499+
weight_nodes.append(node)
497500
return weight_nodes
498501

499502

500-
def is_matmul_with_constant_in_inference_graph(node: NNCFNode, inference_nncf_graph: NNCFGraph) -> bool:
503+
def is_matmul_with_constant(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
501504
"""
502505
Determines whether the given node in the NNCF graph represents a matmul with a constant input.
503506
504507
:param node: A NNCFNode instance.
505-
:param inference_nncf_graph: An inference NNCFGraph instance.
508+
:param nncf_graph: Instance of original NNCFGraph,
509+
which contains shape of and constant subgraphs.
506510
:return: True if given node is a matmul with a constant input, False otherwise.
507511
"""
508-
if node.metatype == om.PTLinearMetatype:
509-
return True
510-
511-
# Inference graph does not contain constants, so
512-
# any missed input edge means it is a constant branch.
513-
is_matmul_metatype = node.metatype in [om.PTMatMulMetatype, om.PTAddmmMetatype]
514-
inputs_missed = 1 <= len(inference_nncf_graph.get_input_edges(node)) < len(node.metatype.weight_port_ids)
515-
516-
return is_matmul_metatype and inputs_missed
512+
return node.metatype in MATMUL_METATYPES and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0

0 commit comments

Comments
 (0)