
Commit 77d2556

[FX] Dynamic Shapes Support (#3225)
### Changes
Modify the NNCF Graph Builder for the FX backend to correctly retrieve dynamic shapes and insert them into the NNCFGraph.

### Reason for changes
To support quantization of Torch FX models exported with dynamic shapes.

### Tests
A test is added to `test_quantized_models()` in `tests/torch/fx/test_models.py`. Currently only the synthetic transformer is tested, because `torch.export.dynamic_shapes.Dim.DYNAMIC` is not supported in the current PyTorch release but will be in upcoming releases: https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes

`test_dynamic_edge()` is also added in `tests/torch/fx/test_models.py` to check that the tensor shapes stored on NNCFGraph edges contain only values of type `int` or `str`, never `torch.SymInt`.

---------

Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
1 parent 1762c5c commit 77d2556

17 files changed: +20845 −10 lines
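For context, here is a minimal sketch of the dynamic-shape workflow this change targets. The toy model, the dimension name `seq`, and the calibration data are illustrative assumptions, not taken from the tests; `torch.export.Dim("seq")` is used because `Dim.DYNAMIC` requires a newer PyTorch, and the call to `nncf.quantize` assumes the experimental FX path accepts the exported `torch.fx.GraphModule` as in NNCF's FX examples.

```python
import torch
import nncf


class TinyMLP(torch.nn.Module):
    """Illustrative toy model, not the synthetic transformer from the tests."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 12)

    def forward(self, x):
        return self.fc(x)


# Export with a symbolic (dynamic) sequence dimension.
seq = torch.export.Dim("seq", min=2, max=128)
exported = torch.export.export(
    TinyMLP(), (torch.randn(1, 8, 4),), dynamic_shapes={"x": {1: seq}}
)

# Quantize the captured torch.fx.GraphModule; after this change the symbolic
# dimension is recorded as -1 on the corresponding NNCFGraph edges.
calibration = nncf.Dataset([torch.randn(1, n, 4) for n in (4, 8, 16)])
quantized = nncf.quantize(exported.module(), calibration)
```

The -1 values visible in the reference .dot graphs below are exactly these substituted symbolic dimensions.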

nncf/experimental/torch/fx/nncf_graph_builder.py

+3 −1

@@ -196,7 +196,9 @@ def get_edge_params(
             else:
                 tensor = source_node.meta["val"]
             if isinstance(tensor, torch.Tensor):
-                tensor_shape = tuple(tensor.shape)
+                tensor_shape = tuple(-1 if isinstance(i, torch.SymInt) else i for i in tensor.shape)
+            elif isinstance(tensor, torch.SymInt):
+                tensor_shape = (-1,)
 
         if tensor_shape is None:
             # TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes.
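Read in isolation, the new branch amounts to the following helper (an illustrative sketch; `_as_edge_shape` is a hypothetical name, not part of the patch):

```python
import torch


def _as_edge_shape(meta_val):
    # Hypothetical standalone mirror of the patched get_edge_params branch:
    # symbolic dimensions (torch.SymInt) become -1, so NNCFGraph edges only
    # ever carry concrete ints.
    if isinstance(meta_val, torch.Tensor):
        return tuple(-1 if isinstance(d, torch.SymInt) else d for d in meta_val.shape)
    if isinstance(meta_val, torch.SymInt):
        # A node whose meta["val"] is itself a symbolic scalar is reported
        # as a one-element dynamic shape.
        return (-1,)
    return None
```

Mapping symbolic dimensions to -1 lets the rest of the quantization pipeline keep working on plain integer shapes.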

tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/mobilenet_v3_small.dot

+1,110
Large diffs are not rendered by default.

tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/resnet18.dot

+497
Large diffs are not rendered by default.

tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/swin_v2_s.dot

+5,924
Large diffs are not rendered by default.
@@ -0,0 +1,49 @@
+strict digraph {
+"0 wte_weight" [id=0, type=get_attr];
+"1 linear_bias" [id=1, type=get_attr];
+"2 lm_head_bias" [id=2, type=get_attr];
+"3 input_ids" [id=3, type=input];
+"4 embedding" [id=4, type=embedding];
+"5 embedding_0_0_nncf_smooth_quant_0" [id=5, type=call_module];
+"6 quantize_per_tensor_default" [id=6, type=quantize_per_tensor];
+"7 dequantize_per_tensor_default" [id=7, type=dequantize_per_tensor];
+"8 scale_updated_constant0" [id=8, type=get_attr];
+"9 compressed_weight_updated_constant0" [id=9, type=get_attr];
+"10 mul_tensor" [id=10, type=mul];
+"11 zero_point_updated_constant0" [id=11, type=get_attr];
+"12 sub_tensor" [id=12, type=sub];
+"13 linear" [id=13, type=linear];
+"14 linear_0_0_nncf_smooth_quant_0" [id=14, type=call_module];
+"15 quantize_per_tensor_default_1" [id=15, type=quantize_per_tensor];
+"16 dequantize_per_tensor_default_1" [id=16, type=dequantize_per_tensor];
+"17 scale_updated_constant1" [id=17, type=get_attr];
+"18 compressed_weight_updated_constant1" [id=18, type=get_attr];
+"19 mul_tensor_1" [id=19, type=mul];
+"20 zero_point_updated_constant1" [id=20, type=get_attr];
+"21 sub_tensor_1" [id=21, type=sub];
+"22 linear_1" [id=22, type=linear];
+"23 output" [id=23, type=output];
+"0 wte_weight" -> "4 embedding" [label="(10, 5)", style=solid];
+"1 linear_bias" -> "13 linear" [label="(5,)", style=solid];
+"2 lm_head_bias" -> "22 linear_1" [label="(10,)", style=solid];
+"3 input_ids" -> "4 embedding" [label="(-1,)", style=solid];
+"4 embedding" -> "5 embedding_0_0_nncf_smooth_quant_0" [label="(-1, 5)", style=solid];
+"5 embedding_0_0_nncf_smooth_quant_0" -> "6 quantize_per_tensor_default" [label="(-1, 5)", style=solid];
+"6 quantize_per_tensor_default" -> "7 dequantize_per_tensor_default" [label="(-1, 5)", style=solid];
+"7 dequantize_per_tensor_default" -> "13 linear" [label="(-1, 5)", style=solid];
+"8 scale_updated_constant0" -> "10 mul_tensor" [label="(5, 1)", style=solid];
+"9 compressed_weight_updated_constant0" -> "10 mul_tensor" [label="(5, 5)", style=solid];
+"10 mul_tensor" -> "12 sub_tensor" [label="(5, 5)", style=solid];
+"11 zero_point_updated_constant0" -> "12 sub_tensor" [label="(5, 1)", style=solid];
+"12 sub_tensor" -> "13 linear" [label="(5, 5)", style=solid];
+"13 linear" -> "14 linear_0_0_nncf_smooth_quant_0" [label="(-1, 5)", style=solid];
+"14 linear_0_0_nncf_smooth_quant_0" -> "15 quantize_per_tensor_default_1" [label="(-1, 5)", style=solid];
+"15 quantize_per_tensor_default_1" -> "16 dequantize_per_tensor_default_1" [label="(-1, 5)", style=solid];
+"16 dequantize_per_tensor_default_1" -> "22 linear_1" [label="(-1, 5)", style=solid];
+"17 scale_updated_constant1" -> "19 mul_tensor_1" [label="(10, 1)", style=solid];
+"18 compressed_weight_updated_constant1" -> "19 mul_tensor_1" [label="(10, 5)", style=solid];
+"19 mul_tensor_1" -> "21 sub_tensor_1" [label="(10, 5)", style=solid];
+"20 zero_point_updated_constant1" -> "21 sub_tensor_1" [label="(10, 1)", style=solid];
+"21 sub_tensor_1" -> "22 linear_1" [label="(10, 5)", style=solid];
+"22 linear_1" -> "23 output" [label="(-1, 10)", style=solid];
+}

tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/unet.dot

+515
Large diffs are not rendered by default.

tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/vit_b_16.dot

+2,011
Large diffs are not rendered by default.
@@ -0,0 +1,63 @@
+strict digraph {
+"0 x" [id=0, type=input];
+"1 x_0_0_nncf_smooth_quant_0" [id=1, type=call_module];
+"2 quantize_per_tensor_default" [id=2, type=quantize_per_tensor];
+"3 dequantize_per_tensor_default" [id=3, type=dequantize_per_tensor];
+"4 scale_updated_constant0" [id=4, type=get_attr];
+"5 compressed_weight_updated_constant0" [id=5, type=get_attr];
+"6 mul_tensor" [id=6, type=mul];
+"7 zero_point_updated_constant0" [id=7, type=get_attr];
+"8 sub_tensor" [id=8, type=sub];
+"9 linear" [id=9, type=linear];
+"10 quantize_per_tensor_default_1" [id=10, type=quantize_per_tensor];
+"11 dequantize_per_tensor_default_1" [id=11, type=dequantize_per_tensor];
+"12 slice_1" [id=12, type=slice];
+"13 slice_2" [id=13, type=slice];
+"14 slice_3" [id=14, type=slice];
+"15 quantize_per_tensor_default_2" [id=15, type=quantize_per_tensor];
+"16 dequantize_per_tensor_default_2" [id=16, type=dequantize_per_tensor];
+"17 slice_4" [id=17, type=slice];
+"18 slice_5" [id=18, type=slice];
+"19 slice_6" [id=19, type=slice];
+"20 slice_7" [id=20, type=slice];
+"21 slice_8" [id=21, type=slice];
+"22 slice_9" [id=22, type=slice];
+"23 transpose" [id=23, type=transpose];
+"24 matmul" [id=24, type=matmul];
+"25 div_" [id=25, type=div_];
+"26 softmax" [id=26, type=softmax];
+"27 transpose_1" [id=27, type=transpose];
+"28 matmul_1" [id=28, type=matmul];
+"29 output" [id=29, type=output];
+"0 x" -> "1 x_0_0_nncf_smooth_quant_0" [label="(1, -1, -1)", style=solid];
+"1 x_0_0_nncf_smooth_quant_0" -> "2 quantize_per_tensor_default" [label="(1, -1, 4)", style=solid];
+"2 quantize_per_tensor_default" -> "3 dequantize_per_tensor_default" [label="(1, -1, 4)", style=solid];
+"3 dequantize_per_tensor_default" -> "9 linear" [label="(1, -1, 4)", style=solid];
+"4 scale_updated_constant0" -> "6 mul_tensor" [label="(12, 1)", style=solid];
+"5 compressed_weight_updated_constant0" -> "6 mul_tensor" [label="(12, 4)", style=solid];
+"6 mul_tensor" -> "8 sub_tensor" [label="(12, 4)", style=solid];
+"7 zero_point_updated_constant0" -> "8 sub_tensor" [label="(12, 1)", style=solid];
+"8 sub_tensor" -> "9 linear" [label="(12, 4)", style=solid];
+"9 linear" -> "10 quantize_per_tensor_default_1" [label="(1, -1, 12)", style=solid];
+"9 linear" -> "15 quantize_per_tensor_default_2" [label="(1, -1, 12)", style=solid];
+"9 linear" -> "20 slice_7" [label="(1, -1, 12)", style=solid];
+"10 quantize_per_tensor_default_1" -> "11 dequantize_per_tensor_default_1" [label="(1, -1, 12)", style=solid];
+"11 dequantize_per_tensor_default_1" -> "12 slice_1" [label="(1, -1, 12)", style=solid];
+"12 slice_1" -> "13 slice_2" [label="(1, -1, 12)", style=solid];
+"13 slice_2" -> "14 slice_3" [label="(1, -1, 12)", style=solid];
+"14 slice_3" -> "24 matmul" [label="(1, -1, 4)", style=solid];
+"15 quantize_per_tensor_default_2" -> "16 dequantize_per_tensor_default_2" [label="(1, -1, 12)", style=solid];
+"16 dequantize_per_tensor_default_2" -> "17 slice_4" [label="(1, -1, 12)", style=solid];
+"17 slice_4" -> "18 slice_5" [label="(1, -1, 12)", style=solid];
+"18 slice_5" -> "19 slice_6" [label="(1, -1, 12)", style=solid];
+"19 slice_6" -> "23 transpose" [label="(1, -1, 4)", style=solid];
+"20 slice_7" -> "21 slice_8" [label="(1, -1, 12)", style=solid];
+"21 slice_8" -> "22 slice_9" [label="(1, -1, 12)", style=solid];
+"22 slice_9" -> "28 matmul_1" [label="(1, -1, 4)", style=solid];
+"23 transpose" -> "24 matmul" [label="(1, 4, -1)", style=solid];
+"24 matmul" -> "25 div_" [label="(1, -1, -1)", style=solid];
+"25 div_" -> "26 softmax" [label="(1, -1, -1)", style=solid];
+"26 softmax" -> "27 transpose_1" [label="(1, -1, -1)", style=solid];
+"27 transpose_1" -> "28 matmul_1" [label="(1, -1, -1)", style=solid];
+"28 matmul_1" -> "29 output" [label="(1, -1, 4)", style=solid];
+}
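The -1 entries in the edge labels above are the substituted symbolic dimensions. A check in the spirit of `test_dynamic_edge()` could look like this (an illustrative sketch, assuming the `NNCFGraph.get_all_edges()` / `tensor_shape` API; not the actual test code):

```python
import torch


def assert_edge_shapes_are_static(nncf_graph):
    # Every dimension recorded on an NNCFGraph edge must be a plain int or
    # str, never a torch.SymInt (sketch of what test_dynamic_edge verifies).
    for edge in nncf_graph.get_all_edges():
        for dim in edge.tensor_shape:
            assert not isinstance(dim, torch.SymInt)
            assert isinstance(dim, (int, str))
```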
