|
11 | 11 |
|
12 | 12 | from dataclasses import dataclass
|
13 | 13 | from functools import partial
|
14 |
| -from typing import List, Tuple, Union |
| 14 | +from typing import Callable, List, Tuple, Union |
15 | 15 |
|
16 | 16 | import networkx as nx
|
17 | 17 | import pytest
|
18 | 18 | import torch
|
| 19 | +import torch.nn.functional as F |
19 | 20 | import torchvision.models as models
|
20 | 21 |
|
21 | 22 | from nncf.common.graph.layer_attributes import Dtype
|
|
24 | 25 | from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
|
25 | 26 | from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
|
26 | 27 | from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
|
| 28 | +from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta |
| 29 | +from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes |
27 | 30 | from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
|
28 | 31 | from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import convert_to_nncf_graph
|
29 | 32 | from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_dtype
|
30 | 33 | from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_name_of_node
|
31 | 34 | from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_node_type
|
32 | 35 | from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
|
33 | 36 | from nncf.experimental.torch2.function_hook.wrapper import wrap_model
|
| 37 | +from nncf.torch.graph.graph import PTNNCFGraph |
| 38 | +from nncf.torch.graph.operator_metatypes import PTCatMetatype |
| 39 | +from nncf.torch.graph.operator_metatypes import PTConv2dMetatype |
34 | 40 | from tests.cross_fw.shared.paths import TEST_ROOT
|
35 | 41 | from tests.torch2.function_hook import helpers
|
36 | 42 | from tests.torch2.utils import compare_with_reference_file
|
@@ -155,3 +161,81 @@ def test_model_graph_with_shared_parameters(regen_ref_data):
|
155 | 161 | nx_nncf_graph = nx.nx_pydot.to_pydot(graph)
|
156 | 162 | ref_file = REF_DIR / "model_graph_with_shared_parameters.dot"
|
157 | 163 | compare_with_reference_file(str(nx_nncf_graph), ref_file, regen_ref_data)
|
| 164 | + |
| 165 | + |
| 166 | +def _missed_input_edge_for_conv() -> PTNNCFGraph: |
| 167 | + graph = PTNNCFGraph() |
| 168 | + graph.add_nncf_node( |
| 169 | + node_name="conv", |
| 170 | + node_type="conv", |
| 171 | + node_metatype=PTConv2dMetatype, |
| 172 | + layer_attributes=PT2OpLayerAttributes( |
| 173 | + func=F.conv2d, |
| 174 | + op_args=( |
| 175 | + TensorMeta(shape=(1,), dtype=torch.float), |
| 176 | + TensorMeta(shape=(1,), dtype=torch.float), |
| 177 | + ), |
| 178 | + op_kwargs={}, |
| 179 | + constant_port_ids=set([1]), |
| 180 | + ), |
| 181 | + ) |
| 182 | + return graph |
| 183 | + |
| 184 | + |
| 185 | +def _missed_input_edge_for_concat() -> PTNNCFGraph: |
| 186 | + graph = PTNNCFGraph() |
| 187 | + graph.add_nncf_node( |
| 188 | + node_name="concat", |
| 189 | + node_type="concat", |
| 190 | + node_metatype=PTCatMetatype, |
| 191 | + layer_attributes=PT2OpLayerAttributes( |
| 192 | + func=torch.concat, |
| 193 | + op_args=( |
| 194 | + [ |
| 195 | + TensorMeta(shape=(1,), dtype=torch.float), |
| 196 | + TensorMeta(shape=(1,), dtype=torch.float), |
| 197 | + ] |
| 198 | + ), |
| 199 | + op_kwargs={}, |
| 200 | + constant_port_ids=set(), |
| 201 | + ), |
| 202 | + ) |
| 203 | + return graph |
| 204 | + |
| 205 | + |
| 206 | +def _no_missed_input_edge_for_conv() -> PTNNCFGraph: |
| 207 | + graph = PTNNCFGraph() |
| 208 | + node_input = graph.add_nncf_node("input", "input", None, None) |
| 209 | + node_weight = graph.add_nncf_node("weight", "weight", None, None) |
| 210 | + node_conv = graph.add_nncf_node( |
| 211 | + node_name="conv", |
| 212 | + node_type="conv", |
| 213 | + node_metatype=PTConv2dMetatype, |
| 214 | + layer_attributes=PT2OpLayerAttributes( |
| 215 | + func=F.conv2d, |
| 216 | + op_args=( |
| 217 | + TensorMeta(shape=(1, 1, 1, 1), dtype=torch.float), |
| 218 | + TensorMeta(shape=(1, 1, 1, 1), dtype=torch.float), |
| 219 | + ), |
| 220 | + op_kwargs={}, |
| 221 | + constant_port_ids=set([1]), |
| 222 | + ), |
| 223 | + ) |
| 224 | + graph.add_edge_between_nncf_nodes(node_input.node_id, node_conv.node_id, (1,), 0, 0, Dtype.FLOAT) |
| 225 | + graph.add_edge_between_nncf_nodes(node_weight.node_id, node_conv.node_id, (1,), 1, 0, Dtype.FLOAT) |
| 226 | + return graph |
| 227 | + |
| 228 | + |
| 229 | +@pytest.mark.parametrize( |
| 230 | + "graph_builder, ref", |
| 231 | + ( |
| 232 | + (_missed_input_edge_for_conv, ["conv"]), |
| 233 | + (_missed_input_edge_for_concat, ["concat"]), |
| 234 | + (_no_missed_input_edge_for_conv, []), |
| 235 | + ), |
| 236 | +) |
| 237 | +def test_get_nodes_with_missed_input_edges(graph_builder: Callable[[], PTNNCFGraph], ref: List[str]): |
| 238 | + graph = graph_builder() |
| 239 | + ret = graph.get_nodes_with_missed_input_edges() |
| 240 | + ret_names = [node.node_name for node in ret] |
| 241 | + assert ret_names == ref |
0 commit comments