22
22
from tests .onnx .quantization .common import mock_collect_statistics
23
23
from tests .onnx .weightless_model import load_model_topology_with_zeros_weights
24
24
25
+
26
+ def model_builder (model_name ):
27
+ if model_name == "resnet18" :
28
+ return models .resnet18 (weights = models .ResNet18_Weights .IMAGENET1K_V1 )
29
+ if model_name == "resnet50_cpu_spr" :
30
+ return models .resnet50 (weights = models .ResNet50_Weights .IMAGENET1K_V1 )
31
+ if model_name == "mobilenet_v2" :
32
+ return models .mobilenet_v2 (weights = models .MobileNet_V2_Weights .IMAGENET1K_V1 )
33
+ if model_name == "mobilenet_v3_small" :
34
+ return models .mobilenet_v3_small (weights = models .MobileNet_V3_Small_Weights .IMAGENET1K_V1 )
35
+ if model_name == "inception_v3" :
36
+ return models .inception_v3 (weights = models .Inception_V3_Weights .IMAGENET1K_V1 )
37
+ if model_name == "googlenet" :
38
+ return models .googlenet (weights = models .GoogLeNet_Weights .IMAGENET1K_V1 )
39
+ if model_name == "vgg16" :
40
+ return models .vgg16 (weights = models .VGG16_Weights .IMAGENET1K_V1 )
41
+ if model_name == "shufflenet_v2_x1_0" :
42
+ return models .shufflenet_v2_x1_0 (weights = models .ShuffleNet_V2_X1_0_Weights .IMAGENET1K_V1 )
43
+ if model_name == "squeezenet1_0" :
44
+ return models .squeezenet1_0 (weights = models .SqueezeNet1_0_Weights .IMAGENET1K_V1 )
45
+ if model_name == "densenet121" :
46
+ return models .densenet121 (weights = models .DenseNet121_Weights .IMAGENET1K_V1 )
47
+ if model_name == "mnasnet0_5" :
48
+ return models .mnasnet0_5 (weights = models .MNASNet0_5_Weights .IMAGENET1K_V1 )
49
+ raise ValueError (f"Unknown model name { model_name } " )
50
+
51
+
25
52
TORCHVISION_TEST_DATA = [
26
- (
27
- ModelToTest ("resnet18" , [1 , 3 , 224 , 224 ]),
28
- models .resnet18 (weights = models .ResNet18_Weights .IMAGENET1K_V1 ),
29
- {},
30
- ),
31
- (
32
- ModelToTest ("resnet50_cpu_spr" , [1 , 3 , 224 , 224 ]),
33
- models .resnet50 (weights = models .ResNet50_Weights .IMAGENET1K_V1 ),
34
- {"target_device" : TargetDevice .CPU_SPR },
35
- ),
36
- (
37
- ModelToTest ("mobilenet_v2" , [1 , 3 , 224 , 224 ]),
38
- models .mobilenet_v2 (weights = models .MobileNet_V2_Weights .IMAGENET1K_V1 ),
39
- {},
40
- ),
41
- (
42
- ModelToTest ("mobilenet_v3_small" , [1 , 3 , 224 , 224 ]),
43
- models .mobilenet_v3_small (weights = models .MobileNet_V3_Small_Weights .IMAGENET1K_V1 ),
44
- {},
45
- ),
46
- (
47
- ModelToTest ("inception_v3" , [1 , 3 , 224 , 224 ]),
48
- models .inception_v3 (weights = models .Inception_V3_Weights .IMAGENET1K_V1 ),
49
- {},
50
- ),
51
- (
52
- ModelToTest ("googlenet" , [1 , 3 , 224 , 224 ]),
53
- models .googlenet (weights = models .GoogLeNet_Weights .IMAGENET1K_V1 ),
54
- {},
55
- ),
56
- (
57
- ModelToTest ("vgg16" , [1 , 3 , 224 , 224 ]),
58
- models .vgg16 (weights = models .VGG16_Weights .IMAGENET1K_V1 ),
59
- {},
60
- ),
61
- (
62
- ModelToTest ("shufflenet_v2_x1_0" , [1 , 3 , 224 , 224 ]),
63
- models .shufflenet_v2_x1_0 (weights = models .ShuffleNet_V2_X1_0_Weights .IMAGENET1K_V1 ),
64
- {},
65
- ),
66
- (
67
- ModelToTest ("squeezenet1_0" , [1 , 3 , 224 , 224 ]),
68
- models .squeezenet1_0 (weights = models .SqueezeNet1_0_Weights .IMAGENET1K_V1 ),
69
- {},
70
- ),
71
- (
72
- ModelToTest ("densenet121" , [1 , 3 , 224 , 224 ]),
73
- models .densenet121 (weights = models .DenseNet121_Weights .IMAGENET1K_V1 ),
74
- {},
75
- ),
76
- (
77
- ModelToTest ("mnasnet0_5" , [1 , 3 , 224 , 224 ]),
78
- models .mnasnet0_5 (weights = models .MNASNet0_5_Weights .IMAGENET1K_V1 ),
79
- {},
80
- ),
53
+ (ModelToTest ("resnet18" , [1 , 3 , 224 , 224 ]), {}),
54
+ (ModelToTest ("resnet50_cpu_spr" , [1 , 3 , 224 , 224 ]), {"target_device" : TargetDevice .CPU_SPR }),
55
+ (ModelToTest ("mobilenet_v2" , [1 , 3 , 224 , 224 ]), {}),
56
+ (ModelToTest ("mobilenet_v3_small" , [1 , 3 , 224 , 224 ]), {}),
57
+ (ModelToTest ("inception_v3" , [1 , 3 , 224 , 224 ]), {}),
58
+ (ModelToTest ("googlenet" , [1 , 3 , 224 , 224 ]), {}),
59
+ (ModelToTest ("vgg16" , [1 , 3 , 224 , 224 ]), {}),
60
+ (ModelToTest ("shufflenet_v2_x1_0" , [1 , 3 , 224 , 224 ]), {}),
61
+ (ModelToTest ("squeezenet1_0" , [1 , 3 , 224 , 224 ]), {}),
62
+ (ModelToTest ("densenet121" , [1 , 3 , 224 , 224 ]), {}),
63
+ (ModelToTest ("mnasnet0_5" , [1 , 3 , 224 , 224 ]), {}),
81
64
]
82
65
83
66
84
67
@pytest .mark .parametrize (
85
- ("model_to_test" , "model" , " quantization_parameters" ),
68
+ ("model_to_test" , "quantization_parameters" ),
86
69
TORCHVISION_TEST_DATA ,
87
70
ids = [model_to_test [0 ].model_name for model_to_test in TORCHVISION_TEST_DATA ],
88
71
)
89
- def test_min_max_quantization_graph_torchvision_models (tmp_path , mocker , model_to_test , model , quantization_parameters ):
72
+ def test_min_max_quantization_graph_torchvision_models (tmp_path , mocker , model_to_test , quantization_parameters ):
90
73
mock_collect_statistics (mocker )
74
+ model = model_builder (model_to_test .model_name )
91
75
onnx_model_path = tmp_path / (model_to_test .model_name + ".onnx" )
92
76
x = torch .randn (model_to_test .input_shape , requires_grad = False )
93
77
torch .onnx .export (model , x , onnx_model_path , opset_version = 13 )
@@ -105,6 +89,7 @@ def test_min_max_quantization_graph_torchvision_models(tmp_path, mocker, model_t
105
89
)
106
90
def test_min_max_quantization_graph_onnx_model (tmp_path , mocker , model_to_test ):
107
91
mock_collect_statistics (mocker )
92
+
108
93
onnx_model_path = ONNX_MODEL_DIR / (model_to_test .model_name + ".onnx" )
109
94
original_model = load_model_topology_with_zeros_weights (onnx_model_path )
110
95
0 commit comments