@@ -158,12 +158,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
            ),
        )

-    def test_composable_quantizer_linear_conv(self) -> None:
-        # TODO: implement when dynamic quantization is supported by OpenVINOQuantizer
-        pass
-
    def test_embedding_conv_linear_quantization(self) -> None:
-        # Mark
        m_eager = TestHelperModules.EmbeddingConvLinearModule().eval()
        indices = torch.tensor(
            [
@@ -203,57 +198,87 @@ def test_embedding_conv_linear_quantization(self) -> None:
        )
        indices = torch.unsqueeze(indices, 0)
        example_inputs = (indices,)
+        quantizer = OpenVINOQuantizer()

-        embedding_quantizer = EmbeddingQuantizer()
-        dynamic_quantizer = XNNPACKQuantizer()
-        quantization_config_dynamic = get_symmetric_quantization_config(
-            is_per_channel=True, is_dynamic=True
-        )
-        dynamic_quantizer.set_global(quantization_config_dynamic)
-        static_quantizer = XNNPACKQuantizer()
-        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        static_quantizer.set_global(quantization_config)
-        composed_quantizer = ComposableQuantizer(
-            [embedding_quantizer, dynamic_quantizer, static_quantizer]
-        )
+        m = self._quantize(m_eager, quantizer, example_inputs, is_qat=False)

-        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
-            dtype=torch.qint8,
-            qscheme=torch.per_tensor_affine,
-            quant_min=-128,
-            quant_max=127,
-            eps=2**-12,
-            is_dynamic=True,
-        )
-        dynamic_qconfig = QConfig(
-            activation=act_affine_quant_obs,
-            weight=per_channel_weight_observer_range_neg_127_to_127,
-        )
-        qconfig = default_per_channel_symmetric_qnnpack_qconfig
-        qconfig_mapping = QConfigMapping().set_global(qconfig)
-        qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig)
-        qconfig_mapping = qconfig_mapping.set_object_type(
-            torch.nn.Embedding, float_qparams_weight_only_qconfig
-        )
-
-        node_occurrence = {
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
-            # note: quantize op for weights are const propagated
-            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+        ref_q = {
+            # First conv
+            "quantize_per_tensor_default": (
+                None,
+                0.01585131697356701,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_tensor_default": (
+                None,
+                0.01585131697356701,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_channel_default": (
+                None,
+                torch.tensor(
+                    [
+                        0.0015,
+                        0.0015,
+                        0.0015,
+                        0.0016,
+                        0.0015,
+                        0.0016,
+                        0.0014,
+                        0.0014,
+                        0.0015,
+                        0.0015,
+                        0.0016,
+                        0.0015,
+                        0.0015,
+                        0.0016,
+                        0.0016,
+                        0.0015,
+                    ]
+                ),
+                torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+                0,
+                -128,
+                127,
+                torch.int8,
+            ),
+            # First linear
+            "quantize_per_tensor_default_1": (
+                None,
+                0.016017982736229897,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_tensor_default_1": (
+                None,
+                0.016017982736229897,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_channel_default_1": (
+                None,
+                torch.tensor(
+                    [0.0019, 0.0019, 0.0020, 0.0018, 0.0019, 0.0019, 0.0018, 0.0018]
+                ),
+                torch.tensor([0, 0, 0, 0, 0, 0, 0, 0]),
+                0,
+                -128,
+                127,
+                torch.int8,
+            ),
+            # TODO: embedding
        }
-        self._test_quantizer(
-            m_eager,
-            example_inputs,
-            composed_quantizer,
-            node_occurrence,
-            [],
-            True,
-            qconfig_mapping,
-        )
+        self._check_quantization_with_ref(m, ref_q)

    def test_disallow_eval_train(self) -> None:
        m = TestHelperModules.ConvWithBNRelu(relu=True)
@@ -272,7 +297,7 @@ def test_disallow_eval_train(self) -> None:

        # After prepare: still not OK
        quantizer = OpenVINOQuantizer()
-        m = prepare_qat_pt2e(m, quantizer)  # pyre-ignore[6]
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
@@ -308,11 +333,9 @@ class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)
-                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                x = self.bn(x)
-                x = self.dropout(x)
                return x

        m = M().train()
@@ -324,8 +347,6 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
            bn_op = bn_train_op if train else bn_eval_op
            bn_node = self._get_node(m, bn_op)
            self.assertTrue(bn_node is not None)
-            dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
-            self.assertEqual(dropout_node.args[2], train)

        # Before wrapping: this is not OK
        with self.assertRaises(NotImplementedError):
@@ -341,8 +362,8 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
        _assert_ops_are_correct(m, train=True)  # pyre-ignore[6]

        # After prepare but before wrapping: this is not OK
-        quantizer = XNNPACKQuantizer()
-        m = prepare_qat_pt2e(m, quantizer)  # pyre-ignore[6]
+        quantizer = OpenVINOQuantizer()
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
@@ -677,142 +698,6 @@ def _check_quantization_with_ref(self, model: torch.fx.GraphModule, ref: Dict):

        assert len(ref) == matches

-    def _get_backend_config(self):
-        def _get_linear_configs():
-            observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
-            dtype_configs = [
-                DTypeConfig(
-                    input_dtype=torch.quint8,
-                    output_dtype=torch.float,
-                    weight_dtype=torch.qint8,
-                    bias_dtype=torch.float,
-                )
-            ]
-            linear_configs: list[BackendPatternConfig] = []
-            # linear module
-            linear_configs.append(
-                BackendPatternConfig(torch.nn.Linear)
-                .set_observation_type(observation_type)  # noqa: E131
-                .set_dtype_configs(dtype_configs)
-                .set_root_module(torch.nn.Linear)
-                .set_reference_quantized_module(nnqr.Linear)
-            )
-            # functional linear
-            linear_configs.append(
-                BackendPatternConfig(torch.nn.functional.linear)
-                .set_observation_type(observation_type)  # noqa: E131
-                .set_dtype_configs(dtype_configs)
-                ._set_input_type_to_index({"weight": 1, "bias": 2})
-            )
-            return linear_configs
-
-        def _get_conv_configs():
-            pass
-
-        return BackendConfig("OpenVINO").set_backend_pattern_configs(
-            _get_linear_configs()
-        )
-        # .set_backend_pattern_configs(_get_conv_configs())
-
-    def _test_quantizer(
-        self,
-        model,
-        example_inputs,
-        quantizer,
-        expected_node_occurrence,
-        expected_node_list=None,
-        check_against_fx_quant=False,
-        fx_qconfig_mapping=None,
-        export_with_dynamic_shape=False,
-        is_qat=False,
-        is_debug_mode=False,
-        training_ir_node_occurrence=None,
-    ):
-        # resetting dynamo cache
-        torch._dynamo.reset()
-        m_eager = model.eval()
-
-        # program capture
-        m = copy.deepcopy(m_eager)
-        dynamic_shapes = tuple(
-            {0: torch.export.Dim("dim")} if i == 0 else None
-            for i in range(len(example_inputs))
-        )
-        m = export_for_training(
-            m,
-            example_inputs,
-            dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
-        ).module()
-
-        if is_qat:
-            m = prepare_qat_pt2e(m, quantizer)
-        else:
-            m = prepare_pt2e(m, quantizer)
-        if is_debug_mode:
-            print("prepared model:", m)
-        # Calibrate
-        m(*example_inputs)
-        m = convert_pt2e(m)
-        if is_debug_mode:
-            print("quantized model", m)
-
-        pt2_quant_output = m(*example_inputs)
-        node_occurrence = {
-            ns.call_function(k): v for k, v in expected_node_occurrence.items()
-        }
-        if expected_node_list is None:
-            expected_node_list = []
-        node_list = [ns.call_function(n) for n in expected_node_list]
-        self.checkGraphModuleNodes(
-            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
-        )
-        if check_against_fx_quant:
-            qconfig_mapping = fx_qconfig_mapping
-            backend_config = self._get_backend_config()
-            m_copy = copy.deepcopy(m_eager)
-            m_fx = prepare_fx(
-                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
-            )
-            m_fx(*example_inputs)
-            m_fx = _convert_to_reference_decomposed_fx(
-                m_fx, backend_config=backend_config
-            )
-            m_fx = export_for_training(
-                m_fx,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
-            ).module()
-            node_occurrence = {}
-            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
-                if k in expected_node_occurrence:
-                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
-            if training_ir_node_occurrence is not None:
-                node_occurrence = {
-                    ns.call_function(k): v
-                    for k, v in training_ir_node_occurrence.items()
-                }
-            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
-            fx_quant_output = m_fx(*example_inputs)
-            self.assertEqual(fx_quant_output, pt2_quant_output)
-        return m
-        # activation_observer = observer.HistogramObserver
-        default_qconfig = QConfig(
-            activation=activation_observer, weight=weight_observer
-        )
-        qconfig_mapping = QConfigMapping()
-        qconfig_mapping.set_global(QConfig(activation=None, weight=None))
-        qconfig_mapping.set_object_type(torch.nn.Linear, default_qconfig)
-        self._quantize()
-        self._test_quantizer(
-            m,
-            example_inputs,
-            quantizer,
-            node_occurrence,
-            check_against_fx_quant=True,
-            fx_qconfig_mapping=qconfig_mapping,
-        )
-        # self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, )
-

    def test_save_load(self) -> None:
        """Test save/load a quantized model"""
        m = self._get_linear()
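
For context, the updated tests drive the standard PT2E flow (export, prepare, calibrate, convert) through the `self._quantize` helper. A minimal sketch of that flow under stated assumptions: the `OpenVINOQuantizer` import path below and the toy model are assumptions for illustration, not taken from this diff.

# Sketch only: OpenVINOQuantizer import path is assumed; a toy linear model
# stands in for TestHelperModules.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export_for_training

from executorch.backends.openvino.quantizer import OpenVINOQuantizer  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(16, 8)).eval()
example_inputs = (torch.randn(1, 16),)

exported = export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, OpenVINOQuantizer())  # insert observers
prepared(*example_inputs)                                # calibrate
quantized = convert_pt2e(prepared)                       # materialize q/dq nodes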