Skip to content

Commit 17f799e

Browse files
[NNCF][BC] Add test case with depthwise/transpose convolutions to template and backend specific tests (#3004)
### Changes - Added simple Depthwise and Transpose Convolution models in `tests/cross_fw/test_templates/helpers.py`. - Updated `map_references` for Torch FX backend to assign the right reference node names for model classes Depthwise and Transpose Convolutions. - Added `ONNXConvolutionTransposeMetatype` into the list of OPERATIONS_WITH_BIAS for ONNX. - Added the missing target for Transpose Conv in `transformations.py` for Torch FX in the function `_is_conv`. - Added `OVConvolutionBackpropDataMetatype` into the list of OPERATIONS_WITH_BIAS for OpenVino backend - Replaced the unet graph in quantized reference graphs for FX backend. ### Extra Changes - [X] Update and finalize the right changes to make in `tests/openvino/native/test_bias_correction.py` for Transpose Convolution node name and accommodate for the changes in the name after each run. ### Closes issue #2916 --------- Co-authored-by: dlyakhov <daniil.lyakhov@intel.com>
1 parent cb0fe0d commit 17f799e

File tree

9 files changed

+122
-59
lines changed

9 files changed

+122
-59
lines changed

nncf/experimental/torch/fx/transformations.py

+1
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ def _is_conv(n: torch.fx.Node):
569569
return n.op == "call_function" and n.target in (
570570
torch.ops.aten.conv1d.default,
571571
torch.ops.aten.conv2d.default,
572+
torch.ops.aten.conv_transpose2d.input,
572573
)
573574

574575

nncf/onnx/graph/metatypes/groups.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,11 @@
142142
# TODO: Need to add MatMul with the separate bias support (CVS-135433)
143143
]
144144

145-
OPERATIONS_WITH_BIAS = [*OPERATIONS_WITH_BIAS_REDUCED, onnx_metatypes.ONNXDepthwiseConvolutionMetatype]
145+
OPERATIONS_WITH_BIAS = [
146+
*OPERATIONS_WITH_BIAS_REDUCED,
147+
onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
148+
onnx_metatypes.ONNXConvolutionTransposeMetatype,
149+
]
146150

147151

148152
QUANTIZE_DEQUANTIZE_OPERATIONS = [

nncf/openvino/graph/metatypes/groups.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@
192192
ov_metatypes.OVMatMulMetatype,
193193
]
194194

195-
OPERATIONS_WITH_BIAS = [*OPERATIONS_WITH_BIAS_REDUCED, ov_metatypes.OVDepthwiseConvolutionMetatype]
195+
OPERATIONS_WITH_BIAS = [
196+
*OPERATIONS_WITH_BIAS_REDUCED,
197+
ov_metatypes.OVDepthwiseConvolutionMetatype,
198+
ov_metatypes.OVConvolutionBackpropDataMetatype,
199+
]
196200

197201
CONV_OPERATIONS = [
198202
ov_metatypes.OVConvolutionMetatype,

tests/cross_fw/test_templates/helpers.py

+30
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from nncf import Dataset
2020
from tests.torch.helpers import create_bn
2121
from tests.torch.helpers import create_conv
22+
from tests.torch.helpers import create_depthwise_conv
23+
from tests.torch.helpers import create_transpose_conv
2224
from tests.torch.helpers import set_torch_seed
2325

2426
TTensor = TypeVar("TTensor")
@@ -436,3 +438,31 @@ def __init__(self):
436438

437439
def forward(self, query, key, value):
438440
return nn.functional.scaled_dot_product_attention(query, key, value)
441+
442+
443+
class DepthwiseConvTestModel(nn.Module):
444+
INPUT_SIZE = [1, 2, 4, 4]
445+
446+
def __init__(self):
447+
super().__init__()
448+
with set_torch_seed():
449+
self.conv = create_depthwise_conv(2, 1, 1, 1)
450+
self.conv.weight.data = torch.randn([2, 1, 1, 1])
451+
self.conv.bias.data = torch.randn([2])
452+
453+
def forward(self, x):
454+
return self.conv(x)
455+
456+
457+
class TransposeConvTestModel(nn.Module):
458+
INPUT_SIZE = [1, 1, 3, 3]
459+
460+
def __init__(self):
461+
super().__init__()
462+
with set_torch_seed():
463+
self.conv = create_transpose_conv(1, 2, 2, 1, 1, 2)
464+
self.conv.weight.data = torch.randn([1, 2, 2, 2])
465+
self.conv.bias.data = torch.randn([2])
466+
467+
def forward(self, x):
468+
return self.conv(x)

tests/cross_fw/test_templates/test_bias_correction.py

+4
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
from nncf.quantization.algorithms.bias_correction.backend import BiasCorrectionAlgoBackend
2323
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
2424
from tests.cross_fw.test_templates.helpers import ConvTestModel
25+
from tests.cross_fw.test_templates.helpers import DepthwiseConvTestModel
2526
from tests.cross_fw.test_templates.helpers import MultipleConvTestModel
2627
from tests.cross_fw.test_templates.helpers import SplittedModel
2728
from tests.cross_fw.test_templates.helpers import StaticDatasetMock
29+
from tests.cross_fw.test_templates.helpers import TransposeConvTestModel
2830

2931
TModel = TypeVar("TModel")
3032
TTensor = TypeVar("TTensor")
@@ -139,6 +141,8 @@ def quantized_test_model(self, tmpdir) -> TModel:
139141
},
140142
),
141143
(ConvTestModel, {"/conv/Conv": [0.11085186, 1.0017344]}),
144+
(DepthwiseConvTestModel, {"/conv/Conv": [-1.1229, -0.1863]}),
145+
(TransposeConvTestModel, {"/conv/ConvTranspose": [0.66797173, -0.7070703]}),
142146
),
143147
)
144148
def test_update_bias(self, model_cls, ref_biases, tmpdir):

tests/onnx/quantization/test_bias_correction.py

+4
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
from nncf.onnx.graph.node_utils import get_bias_value
2323
from nncf.quantization.algorithms.bias_correction.onnx_backend import ONNXBiasCorrectionAlgoBackend
2424
from tests.cross_fw.test_templates.helpers import ConvTestModel
25+
from tests.cross_fw.test_templates.helpers import DepthwiseConvTestModel
2526
from tests.cross_fw.test_templates.helpers import MultipleConvTestModel
2627
from tests.cross_fw.test_templates.helpers import SplittedModel
28+
from tests.cross_fw.test_templates.helpers import TransposeConvTestModel
2729
from tests.cross_fw.test_templates.test_bias_correction import TemplateTestBCAlgorithm
2830
from tests.onnx.quantization.common import compare_nncf_graph
2931

@@ -211,6 +213,8 @@ def test__get_subgraph_data_for_node(self, quantized_test_model, layer_name, ref
211213
},
212214
),
213215
(ConvTestModel, {("/conv/Conv", 0): ("nncf_model_input_0", 0)}),
216+
(DepthwiseConvTestModel, {("/conv/Conv", 0): ("nncf_model_input_0", 0)}),
217+
(TransposeConvTestModel, {("/conv/ConvTranspose", 0): ("nncf_model_input_0", 0)}),
214218
),
215219
)
216220
def test_verify_collected_stat_inputs_map(self, model_cls, ref_stat_inputs_map, tmpdir):

tests/openvino/native/test_bias_correction.py

+10
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@
2222
from nncf.openvino.graph.node_utils import get_bias_value
2323
from nncf.quantization.algorithms.bias_correction.openvino_backend import OVBiasCorrectionAlgoBackend
2424
from tests.cross_fw.test_templates.helpers import ConvTestModel
25+
from tests.cross_fw.test_templates.helpers import DepthwiseConvTestModel
2526
from tests.cross_fw.test_templates.helpers import MultipleConvTestModel
2627
from tests.cross_fw.test_templates.helpers import SplittedModel
28+
from tests.cross_fw.test_templates.helpers import TransposeConvTestModel
2729
from tests.cross_fw.test_templates.test_bias_correction import TemplateTestBCAlgorithm
2830
from tests.openvino.native.common import compare_nncf_graphs
2931

3032

3133
class TestOVBCAlgorithm(TemplateTestBCAlgorithm):
34+
TRANSPOSE_CONV_NAME = "/conv/ConvTranspose/WithoutBiases"
35+
3236
@staticmethod
3337
def list_to_backend_type(data: List) -> np.ndarray:
3438
return np.array(data)
@@ -42,6 +46,10 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str):
4246
onnx_path = f"{tmp_dir}/model.onnx"
4347
torch.onnx.export(model, torch.rand(model.INPUT_SIZE), onnx_path, opset_version=13, input_names=["input.1"])
4448
ov_model = ov.convert_model(onnx_path)
49+
if isinstance(model, TransposeConvTestModel):
50+
for node in ov_model.get_ops():
51+
if node.get_type_name() == "ConvolutionBackpropData":
52+
node.set_friendly_name(TestOVBCAlgorithm.TRANSPOSE_CONV_NAME)
4553
return ov_model
4654

4755
@staticmethod
@@ -206,6 +214,8 @@ def test__get_subgraph_data_for_node(self, quantized_test_model, layer_name, ref
206214
},
207215
),
208216
(ConvTestModel, {("/conv/Conv/WithoutBiases", 0): ("input.1", 0)}),
217+
(DepthwiseConvTestModel, {("/conv/Conv/WithoutBiases", 0): ("input.1", 0)}),
218+
(TransposeConvTestModel, {(TRANSPOSE_CONV_NAME, 0): ("input.1", 0)}),
209219
),
210220
)
211221
def test_verify_collected_stat_inputs_map(self, model_cls, ref_stat_inputs_map, tmpdir):

0 commit comments

Comments
 (0)