17
17
prepare ,
18
18
quantize ,
19
19
)
20
- from neural_compressor .torch .utils import TORCH_VERSION_2_2_2 , get_torch_version
20
+ from neural_compressor .torch .utils import GT_TORCH_VERSION_2_3_2 , TORCH_VERSION_2_2_2 , get_torch_version
21
21
22
22
torch .manual_seed (0 )
23
23
@@ -119,6 +119,42 @@ def calib_fn(model):
119
119
logger .warning ("out shape is %s" , out .shape )
120
120
assert out is not None
121
121
122
+ @pytest .mark .skipif (not GT_TORCH_VERSION_2_3_2 , reason = "Requires torch>=2.3.2" )
123
+ def test_quantize_simple_model_with_set_local (self , force_not_import_ipex ):
124
+ model , example_inputs = self .build_simple_torch_model_and_example_inputs ()
125
+ float_model_output = model (* example_inputs )
126
+ quant_config = None
127
+
128
+ def calib_fn (model ):
129
+ for i in range (4 ):
130
+ model (* example_inputs )
131
+
132
+ quant_config = get_default_static_config ()
133
+ quant_config .set_local ("fc1" , StaticQuantConfig (w_dtype = "fp32" , act_dtype = "fp32" ))
134
+ q_model = quantize (model = model , quant_config = quant_config , run_fn = calib_fn )
135
+
136
+ # check the half node
137
+ expected_node_occurrence = {
138
+ # Only quantize the `fc2`
139
+ torch .ops .quantized_decomposed .quantize_per_tensor .default : 2 ,
140
+ torch .ops .quantized_decomposed .quantize_per_tensor .default : 2 ,
141
+ }
142
+ expected_node_occurrence = {
143
+ torch_test_quant_common .NodeSpec .call_function (k ): v for k , v in expected_node_occurrence .items ()
144
+ }
145
+ node_in_graph = self .get_node_in_graph (q_model )
146
+ for node , cnt in expected_node_occurrence .items ():
147
+ assert node_in_graph .get (node , 0 ) == cnt , f"Node { node } should occur { cnt } times, but { node_in_graph [node ]} "
148
+
149
+ from torch ._inductor import config
150
+
151
+ config .freezing = True
152
+ q_model_out = q_model (* example_inputs )
153
+ assert torch .allclose (float_model_output , q_model_out , atol = 1e-2 ), "Quantization failed!"
154
+ opt_model = torch .compile (q_model )
155
+ out = opt_model (* example_inputs )
156
+ assert out is not None
157
+
122
158
@pytest .mark .skipif (get_torch_version () <= TORCH_VERSION_2_2_2 , reason = "Requires torch>=2.3.0" )
123
159
@pytest .mark .parametrize ("is_dynamic" , [False , True ])
124
160
def test_prepare_and_convert_on_simple_model (self , is_dynamic , force_not_import_ipex ):
@@ -193,9 +229,9 @@ def get_node_in_graph(graph_module):
193
229
nodes_in_graph [n ] += 1
194
230
else :
195
231
nodes_in_graph [n ] = 1
196
- return
232
+ return nodes_in_graph
197
233
198
- @pytest .mark .skipif (get_torch_version () <= TORCH_VERSION_2_2_2 , reason = "Requires torch>=2.3.0" )
234
+ @pytest .mark .skipif (not GT_TORCH_VERSION_2_3_2 , reason = "Requires torch>=2.3.0" )
199
235
def test_mixed_fp16_and_int8 (self , force_not_import_ipex ):
200
236
model , example_inputs = self .build_model_include_conv_and_linear ()
201
237
model = export (model , example_inputs = example_inputs )
@@ -221,9 +257,7 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
221
257
}
222
258
node_in_graph = self .get_node_in_graph (converted_model )
223
259
for node , cnt in expected_node_occurrence .items ():
224
- assert (
225
- expected_node_occurrence .get (node , 0 ) == cnt
226
- ), f"Node { node } should occur { cnt } times, but { node_in_graph [node ]} "
260
+ assert node_in_graph .get (node , 0 ) == cnt , f"Node { node } should occur { cnt } times, but { node_in_graph [node ]} "
227
261
228
262
# inference
229
263
from torch ._inductor import config
0 commit comments