@@ -3,7 +3,6 @@
 
 import pytest
 import torch
-import torchvision
 import transformers
 
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedConv2d, PatchedLinear
@@ -36,14 +35,12 @@ def calib_func(model):
 
 
 @pytest.mark.skipif(not is_hpex_available(), reason="HPU environment is required!")
-class TestFP8StaticQuant:
+class TestFP8StaticQuantNLP:
     def setup_class(self):
         change_to_cur_file_dir()
         config = transformers.AutoConfig.from_pretrained("./model_configs/tiny_gptj.json")
         self.tiny_gptj = transformers.AutoModelForCausalLM.from_config(config)
         self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to("hpu")
-        self.resnet18 = torchvision.models.resnet18()
-        self.cv_dummy_inputs = torch.randn([1, 3, 224, 224]).to("hpu")
 
     def teardown_class(self):
         shutil.rmtree("test_ouputs", ignore_errors=True)
@@ -72,6 +69,38 @@ def test_one_step_quant_nlp(self):
         ), "k_proj input dtype is not torch.float8_e4m3fn."
         assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output"
 
+    @torch.no_grad()
+    def test_two_step_quant_nlp(self):
+        # step 1: measurement
+        model = copy.deepcopy(self.tiny_gptj)
+        config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
+        model = prepare(model, config)
+        calib_func(model)
+        finalize_calibration(model)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not observed."
+        # step 2: quantize based on measurement
+        model = copy.deepcopy(self.tiny_gptj)
+        config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
+        model = convert(model, config)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized."
+        assert (
+            model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn
+        ), "k_proj input dtype is not torch.float8_e4m3fn."
+
+
+@pytest.mark.xfail(reason="[SW-219514] RuntimeError: operator torchvision::nms does not exist")
+@pytest.mark.skipif(not is_hpex_available(), reason="HPU environment is required!")
+class TestFP8StaticQuantCV:
+    def setup_class(self):
+        change_to_cur_file_dir()
+        import torchvision
+        self.resnet18 = torchvision.models.resnet18()
+        self.cv_dummy_inputs = torch.randn([1, 3, 224, 224]).to("hpu")
+
+    def teardown_class(self):
+        shutil.rmtree("test_ouputs", ignore_errors=True)
+        shutil.rmtree("saved_results", ignore_errors=True)
+
     @torch.no_grad()
     def test_one_step_quant_cv(self):
         model = copy.deepcopy(self.resnet18)
@@ -94,24 +123,6 @@ def test_one_step_quant_cv(self):
         ), "model is not quantized to torch.float8_e4m3fn."
         assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output"
 
-    @torch.no_grad()
-    def test_two_step_quant_nlp(self):
-        # step 1: measurement
-        model = copy.deepcopy(self.tiny_gptj)
-        config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
-        model = prepare(model, config)
-        calib_func(model)
-        finalize_calibration(model)
-        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not observed."
-        # step 2: quantize based on measurement
-        model = copy.deepcopy(self.tiny_gptj)
-        config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
-        model = convert(model, config)
-        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized."
-        assert (
-            model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn
-        ), "k_proj input dtype is not torch.float8_e4m3fn."
-
     @torch.no_grad()
     def test_two_step_quant_cv(self):
         # step 1: measurement
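For reference, the two-step flow that `test_two_step_quant_nlp` exercises (and that `test_two_step_quant_cv` mirrors for the ResNet model) reduces to the sketch below. It assumes the `neural_compressor.torch.quantization` entry points used in this test file (`FP8Config`, `prepare`, `convert`, `finalize_calibration`) and an available HPU device; the model, JSON paths, and calibration tensor are illustrative stand-ins taken from the tests, not a prescribed recipe.

```python
import copy

import torch
import transformers
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare

# Build a small FP32 model to quantize (placeholder, mirrors the test fixture).
config = transformers.AutoConfig.from_pretrained("./model_configs/tiny_gptj.json")
fp32_model = transformers.AutoModelForCausalLM.from_config(config)
example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to("hpu")

# Step 1: measurement. prepare() wraps supported modules with observers;
# running calibration data collects input statistics, and
# finalize_calibration() saves them for the convert step.
measure_config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
measured = prepare(copy.deepcopy(fp32_model), measure_config)
measured(example_inputs)
finalize_calibration(measured)

# Step 2: quantization. convert() on a fresh FP32 copy picks up the saved
# measurements and swaps Linear/Conv2d for Patched* modules whose inputs
# are quantized to torch.float8_e4m3fn.
quant_config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
fp8_model = convert(copy.deepcopy(fp32_model), quant_config)
fp8_out = fp8_model(example_inputs)
```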