Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3a566ba

Browse files
committedMar 21, 2025·
Shape dim could be only torch.SymInt
1 parent 1662f2a commit 3a566ba

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed
 

‎nncf/experimental/torch/fx/nncf_graph_builder.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,7 @@ def get_edge_params(
198198
else:
199199
tensor = source_node.meta["val"]
200200
if isinstance(tensor, torch.Tensor):
201-
tensor_shape = tuple(
202-
-1 if isinstance(i, GraphConverter.TORCH_SYMBOLIC_TYPES) else i for i in tensor.shape
203-
)
201+
tensor_shape = tuple(-1 if isinstance(i, torch.SymInt) else i for i in tensor.shape)
204202
elif isinstance(tensor, GraphConverter.TORCH_SYMBOLIC_TYPES):
205203
tensor_shape = (-1,)
206204

‎tests/torch2/function_hook/quantization/test_quantizer_config.py

+29
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from nncf.common.utils.backend import BackendType
1616
from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
1717
from nncf.torch.graph.graph import PTNNCFGraph
18+
from tests.cross_fw.test_templates.models import NNCFGraphConstantBranchWithWeightedNode
19+
from tests.cross_fw.test_templates.models import NNCFGraphModelWithEmbeddingsConstantPath
1820
from tests.cross_fw.test_templates.models import NNCFGraphToTest
1921
from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv
2022
from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation
@@ -54,3 +56,30 @@ def transformer_nncf_graph(self) -> NNCFGraphToTest:
5456
transpose_metatype=om.PTTransposeMetatype,
5557
nncf_graph_cls=PTNNCFGraph,
5658
)
59+
60+
@pytest.fixture
61+
def embedding_nncf_graph_shape_of(self) -> NNCFGraphToTest:
62+
return None
63+
64+
@pytest.mark.skip("Torch2 does not have shape of subgraphs")
65+
def test_embedding_model_qconfig_shape_of(self, embedding_nncf_graph_shape_of):
66+
pass
67+
68+
@pytest.fixture
69+
def embedding_nncf_graph_constant_path(self) -> NNCFGraphToTest:
70+
return NNCFGraphModelWithEmbeddingsConstantPath(
71+
const_metatype=om.PTConstNoopMetatype,
72+
embedding_metatype=om.PTModuleEmbeddingMetatype,
73+
conv_metatype=om.PTModuleConv2dMetatype,
74+
add_metatype=om.PTAddMetatype,
75+
nncf_graph_cls=PTNNCFGraph,
76+
)
77+
78+
@pytest.fixture
79+
def constant_branch_nncf_graph(self) -> NNCFGraphToTest:
80+
return NNCFGraphConstantBranchWithWeightedNode(
81+
const_metatype=om.PTConstNoopMetatype,
82+
conv_metatype=om.PTModuleConv2dMetatype,
83+
add_metatype=om.PTAddMetatype,
84+
nncf_graph_cls=PTNNCFGraph,
85+
)

0 commit comments

Comments
 (0)
Please sign in to comment.