|
15 | 15 | from nncf.common.utils.backend import BackendType
|
16 | 16 | from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
|
17 | 17 | from nncf.torch.graph.graph import PTNNCFGraph
|
| 18 | +from tests.cross_fw.test_templates.models import NNCFGraphConstantBranchWithWeightedNode |
| 19 | +from tests.cross_fw.test_templates.models import NNCFGraphModelWithEmbeddingsConstantPath |
18 | 20 | from tests.cross_fw.test_templates.models import NNCFGraphToTest
|
19 | 21 | from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv
|
20 | 22 | from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation
|
@@ -54,3 +56,30 @@ def transformer_nncf_graph(self) -> NNCFGraphToTest:
|
54 | 56 | transpose_metatype=om.PTTransposeMetatype,
|
55 | 57 | nncf_graph_cls=PTNNCFGraph,
|
56 | 58 | )
|
| 59 | + |
| 60 | + @pytest.fixture |
| 61 | + def embedding_nncf_graph_shape_of(self) -> NNCFGraphToTest: |
| 62 | + return None |
| 63 | + |
| 64 | + @pytest.mark.skip("Torch2 does not have shape of subgraphs") |
| 65 | + def test_embedding_model_qconfig_shape_of(self, embedding_nncf_graph_shape_of): |
| 66 | + pass |
| 67 | + |
| 68 | + @pytest.fixture |
| 69 | + def embedding_nncf_graph_constant_path(self) -> NNCFGraphToTest: |
| 70 | + return NNCFGraphModelWithEmbeddingsConstantPath( |
| 71 | + const_metatype=om.PTConstNoopMetatype, |
| 72 | + embedding_metatype=om.PTModuleEmbeddingMetatype, |
| 73 | + conv_metatype=om.PTModuleConv2dMetatype, |
| 74 | + add_metatype=om.PTAddMetatype, |
| 75 | + nncf_graph_cls=PTNNCFGraph, |
| 76 | + ) |
| 77 | + |
| 78 | + @pytest.fixture |
| 79 | + def constant_branch_nncf_graph(self) -> NNCFGraphToTest: |
| 80 | + return NNCFGraphConstantBranchWithWeightedNode( |
| 81 | + const_metatype=om.PTConstNoopMetatype, |
| 82 | + conv_metatype=om.PTModuleConv2dMetatype, |
| 83 | + add_metatype=om.PTAddMetatype, |
| 84 | + nncf_graph_cls=PTNNCFGraph, |
| 85 | + ) |
0 commit comments