Skip to content

Commit d4daf6d

Browse files
Comments
1 parent 708aeb7 commit d4daf6d

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

nncf/quantization/algorithms/min_max/torch_backend.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from nncf.torch.quantization.layers import BaseQuantizer
5252
from nncf.torch.quantization.layers import PTQuantizerSpec
5353
from nncf.torch.quantization.layers import get_scale_shape
54-
from nncf.torch.utils import get_weight_nodes_in_inference_grpah
54+
from nncf.torch.utils import get_weight_nodes_in_inference_graph
5555

5656

5757
class PTMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -342,7 +342,7 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
342342
return set()
343343

344344
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
345-
return get_weight_nodes_in_inference_grpah(inference_nncf_graph, self.mat_mul_metatypes)
345+
return get_weight_nodes_in_inference_graph(inference_nncf_graph, self.mat_mul_metatypes)
346346

347347
def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
348348
return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0

nncf/quantization/algorithms/min_max/torch_fx_backend.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from nncf.torch.quantization.layers import PTQuantizerSpec
5050
from nncf.torch.quantization.layers import get_scale_shape
5151
from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
52-
from nncf.torch.utils import get_weight_nodes_in_inference_grpah
52+
from nncf.torch.utils import get_weight_nodes_in_inference_graph
5353

5454

5555
class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -306,7 +306,7 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
306306
return set()
307307

308308
def get_weight_nodes(self, inference_nncf_graph: NNCFGraph) -> List[NNCFNode]:
309-
return get_weight_nodes_in_inference_grpah(inference_nncf_graph, self.mat_mul_metatypes)
309+
return get_weight_nodes_in_inference_graph(inference_nncf_graph, self.mat_mul_metatypes)
310310

311311
def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
312312
return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0

nncf/torch/utils.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212
from collections import OrderedDict
1313
from contextlib import contextmanager
14-
from typing import Any, Dict, Generator, List
14+
from typing import Any, Dict, Generator, List, Type
1515

1616
import numpy as np
1717
import torch
@@ -472,7 +472,7 @@ def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
472472
return dtype
473473

474474

475-
def get_weight_nodes_in_inference_grpah(
475+
def get_weight_nodes_in_inference_graph(
476476
inference_nncf_graph: NNCFGraph, mat_mul_metatypes: List[Type[om.PTOperatorMetatype]]
477477
) -> List[NNCFNode]:
478478
"""
@@ -510,6 +510,7 @@ def is_matmul_with_constant_in_inference_graph(node: NNCFNode, inference_nncf_gr
510510

511511
# Inference graph does not contain constants, so
512512
# any missed input edge means it is a constant branch.
513-
return node.metatype in [om.PTMatMulMetatype, om.PTAddmmMetatype] and len(
514-
inference_nncf_graph.get_input_edges(node)
515-
) < len(node.metatype.weight_port_ids)
513+
is_matmul_metatype = node.metatype in [om.PTMatMulMetatype, om.PTAddmmMetatype]
514+
inputs_missed = 1 <= len(inference_nncf_graph.get_input_edges(node)) < len(node.metatype.weight_port_ids)
515+
516+
return is_matmul_metatype and inputs_missed

0 commit comments

Comments (0)