Skip to content

Commit 4efc079

Browse files
[PT2] Detect missed inputs for noed based on TensorMeta in arguments (#3360)
### Changes For new tracing detect missed inputs based on count of TensorMeta instance in op inputs. ### Reason for changes Failed on detection missed input for concat nodes.
1 parent cb9e9ac commit 4efc079

File tree

2 files changed

+117
-12
lines changed

2 files changed

+117
-12
lines changed

nncf/torch/graph/graph.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from itertools import chain
1213
from typing import Dict, List, Tuple
1314

1415
import nncf
1516
from nncf.common.graph import NNCFGraph
1617
from nncf.common.graph import NNCFNode
1718
from nncf.common.graph import NNCFNodeName
1819
from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes
20+
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
21+
from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta
22+
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
1923
from nncf.torch.dynamic_graph.scope import Scope
2024
from nncf.torch.graph.transformations.commands import PTTargetPoint
2125

@@ -102,16 +106,33 @@ def get_nodes_with_missed_input_edges(self) -> List[NNCFNode]:
102106
:return: List of NNCFNodes that are identified as disconnected.
103107
"""
104108
input_nodes = set()
105-
for node in self.get_all_nodes():
106-
num_expected_input_edges = None
107-
if hasattr(node.metatype, "num_expected_input_edges"):
108-
num_expected_input_edges = node.metatype.num_expected_input_edges
109-
if node.layer_attributes is not None and isinstance(node.layer_attributes, MultipleInputLayerAttributes):
110-
num_expected_input_edges = node.layer_attributes.num_inputs
111-
if num_expected_input_edges:
112-
input_edges = self.get_input_edges(node)
113-
if len(input_edges) < num_expected_input_edges:
114-
# If node has missed input edges we assume this node is an input node
115-
# that was disconnected from an activation input.
109+
if is_experimental_torch_tracing_enabled():
110+
# Check expected number of input edges by counting TensorMeta in op_args and op_kwargs.
111+
for node in self.get_all_nodes():
112+
input_edges = len(self.get_input_edges(node))
113+
if not isinstance(node.layer_attributes, PT2OpLayerAttributes):
114+
continue
115+
num_expected_input_edges = 0
116+
for val in chain(node.layer_attributes.op_args, node.layer_attributes.op_kwargs.values()):
117+
if isinstance(val, TensorMeta):
118+
num_expected_input_edges += 1
119+
if isinstance(val, (list, tuple)):
120+
num_expected_input_edges += sum(isinstance(v, TensorMeta) for v in val)
121+
if input_edges < num_expected_input_edges:
116122
input_nodes.add(node)
123+
else:
124+
for node in self.get_all_nodes():
125+
num_expected_input_edges = None
126+
if hasattr(node.metatype, "num_expected_input_edges"):
127+
num_expected_input_edges = node.metatype.num_expected_input_edges
128+
if node.layer_attributes is not None and isinstance(
129+
node.layer_attributes, MultipleInputLayerAttributes
130+
):
131+
num_expected_input_edges = node.layer_attributes.num_inputs
132+
if num_expected_input_edges:
133+
input_edges = self.get_input_edges(node)
134+
if len(input_edges) < num_expected_input_edges:
135+
# If node has missed input edges we assume this node is an input node
136+
# that was disconnected from an activation input.
137+
input_nodes.add(node)
117138
return list(input_nodes)

tests/torch2/function_hook/nncf_graph/test_nncf_graph.py

+85-1
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
from dataclasses import dataclass
1313
from functools import partial
14-
from typing import List, Tuple, Union
14+
from typing import Callable, List, Tuple, Union
1515

1616
import networkx as nx
1717
import pytest
1818
import torch
19+
import torch.nn.functional as F
1920
import torchvision.models as models
2021

2122
from nncf.common.graph.layer_attributes import Dtype
@@ -24,13 +25,18 @@
2425
from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
2526
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
2627
from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
28+
from nncf.experimental.torch2.function_hook.graph.graph_utils import TensorMeta
29+
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
2730
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
2831
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import convert_to_nncf_graph
2932
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_dtype
3033
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_name_of_node
3134
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_node_type
3235
from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
3336
from nncf.experimental.torch2.function_hook.wrapper import wrap_model
37+
from nncf.torch.graph.graph import PTNNCFGraph
38+
from nncf.torch.graph.operator_metatypes import PTCatMetatype
39+
from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
3440
from tests.cross_fw.shared.paths import TEST_ROOT
3541
from tests.torch2.function_hook import helpers
3642
from tests.torch2.utils import compare_with_reference_file
@@ -155,3 +161,81 @@ def test_model_graph_with_shared_parameters(regen_ref_data):
155161
nx_nncf_graph = nx.nx_pydot.to_pydot(graph)
156162
ref_file = REF_DIR / "model_graph_with_shared_parameters.dot"
157163
compare_with_reference_file(str(nx_nncf_graph), ref_file, regen_ref_data)
164+
165+
166+
def _missed_input_edge_for_conv() -> PTNNCFGraph:
167+
graph = PTNNCFGraph()
168+
graph.add_nncf_node(
169+
node_name="conv",
170+
node_type="conv",
171+
node_metatype=PTConv2dMetatype,
172+
layer_attributes=PT2OpLayerAttributes(
173+
func=F.conv2d,
174+
op_args=(
175+
TensorMeta(shape=(1,), dtype=torch.float),
176+
TensorMeta(shape=(1,), dtype=torch.float),
177+
),
178+
op_kwargs={},
179+
constant_port_ids=set([1]),
180+
),
181+
)
182+
return graph
183+
184+
185+
def _missed_input_edge_for_concat() -> PTNNCFGraph:
186+
graph = PTNNCFGraph()
187+
graph.add_nncf_node(
188+
node_name="concat",
189+
node_type="concat",
190+
node_metatype=PTCatMetatype,
191+
layer_attributes=PT2OpLayerAttributes(
192+
func=torch.concat,
193+
op_args=(
194+
[
195+
TensorMeta(shape=(1,), dtype=torch.float),
196+
TensorMeta(shape=(1,), dtype=torch.float),
197+
]
198+
),
199+
op_kwargs={},
200+
constant_port_ids=set(),
201+
),
202+
)
203+
return graph
204+
205+
206+
def _no_missed_input_edge_for_conv() -> PTNNCFGraph:
207+
graph = PTNNCFGraph()
208+
node_input = graph.add_nncf_node("input", "input", None, None)
209+
node_weight = graph.add_nncf_node("weight", "weight", None, None)
210+
node_conv = graph.add_nncf_node(
211+
node_name="conv",
212+
node_type="conv",
213+
node_metatype=PTConv2dMetatype,
214+
layer_attributes=PT2OpLayerAttributes(
215+
func=F.conv2d,
216+
op_args=(
217+
TensorMeta(shape=(1, 1, 1, 1), dtype=torch.float),
218+
TensorMeta(shape=(1, 1, 1, 1), dtype=torch.float),
219+
),
220+
op_kwargs={},
221+
constant_port_ids=set([1]),
222+
),
223+
)
224+
graph.add_edge_between_nncf_nodes(node_input.node_id, node_conv.node_id, (1,), 0, 0, Dtype.FLOAT)
225+
graph.add_edge_between_nncf_nodes(node_weight.node_id, node_conv.node_id, (1,), 1, 0, Dtype.FLOAT)
226+
return graph
227+
228+
229+
@pytest.mark.parametrize(
230+
"graph_builder, ref",
231+
(
232+
(_missed_input_edge_for_conv, ["conv"]),
233+
(_missed_input_edge_for_concat, ["concat"]),
234+
(_no_missed_input_edge_for_conv, []),
235+
),
236+
)
237+
def test_get_nodes_with_missed_input_edges(graph_builder: Callable[[], PTNNCFGraph], ref: List[str]):
238+
graph = graph_builder()
239+
ret = graph.get_nodes_with_missed_input_edges()
240+
ret_names = [node.node_name for node in ret]
241+
assert ret_names == ref

0 commit comments

Comments
 (0)