
Commit 42fd1c3

WIP
1 parent 4857a74 commit 42fd1c3


2 files changed: +81 −338 lines changed

backends/openvino/tests/quantizer/test_pt2e_quantization.py

+81 −196
@@ -158,12 +158,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
             ),
         )
 
-    def test_composable_quantizer_linear_conv(self) -> None:
-        # TODO: impelment whe dynamic quantization will be supported by OpenVINOQuantizer
-        pass
-
     def test_embedding_conv_linear_quantization(self) -> None:
-        # Mark
         m_eager = TestHelperModules.EmbeddingConvLinearModule().eval()
         indices = torch.tensor(
             [
@@ -203,57 +198,87 @@ def test_embedding_conv_linear_quantization(self) -> None:
         )
         indices = torch.unsqueeze(indices, 0)
         example_inputs = (indices,)
+        quantizer = OpenVINOQuantizer()
 
-        embedding_quantizer = EmbeddingQuantizer()
-        dynamic_quantizer = XNNPACKQuantizer()
-        quantization_config_dynamic = get_symmetric_quantization_config(
-            is_per_channel=True, is_dynamic=True
-        )
-        dynamic_quantizer.set_global(quantization_config_dynamic)
-        static_quantizer = XNNPACKQuantizer()
-        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        static_quantizer.set_global(quantization_config)
-        composed_quantizer = ComposableQuantizer(
-            [embedding_quantizer, dynamic_quantizer, static_quantizer]
-        )
+        m = self._quantize(m_eager, quantizer, example_inputs, is_qat=False)
 
-        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
-            dtype=torch.qint8,
-            qscheme=torch.per_tensor_affine,
-            quant_min=-128,
-            quant_max=127,
-            eps=2**-12,
-            is_dynamic=True,
-        )
-        dynamic_qconfig = QConfig(
-            activation=act_affine_quant_obs,
-            weight=per_channel_weight_observer_range_neg_127_to_127,
-        )
-        qconfig = default_per_channel_symmetric_qnnpack_qconfig
-        qconfig_mapping = QConfigMapping().set_global(qconfig)
-        qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig)
-        qconfig_mapping = qconfig_mapping.set_object_type(
-            torch.nn.Embedding, float_qparams_weight_only_qconfig
-        )
-
-        node_occurrence = {
-            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
-            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
-            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
-            # note: quantize op for weights are const propagated
-            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+        ref_q = {
+            # First conv
+            "quantize_per_tensor_default": (
+                None,
+                0.01585131697356701,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_tensor_default": (
+                None,
+                0.01585131697356701,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_channel_default": (
+                None,
+                torch.tensor(
+                    [
+                        0.0015,
+                        0.0015,
+                        0.0015,
+                        0.0016,
+                        0.0015,
+                        0.0016,
+                        0.0014,
+                        0.0014,
+                        0.0015,
+                        0.0015,
+                        0.0016,
+                        0.0015,
+                        0.0015,
+                        0.0016,
+                        0.0016,
+                        0.0015,
+                    ]
+                ),
+                torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+                0,
+                -128,
+                127,
+                torch.int8,
+            ),
+            # First linear
+            "quantize_per_tensor_default_1": (
+                None,
+                0.016017982736229897,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_tensor_default_1": (
+                None,
+                0.016017982736229897,
+                127,
+                0,
+                255,
+                torch.uint8,
+            ),
+            "dequantize_per_channel_default_1": (
+                None,
+                torch.tensor(
+                    [0.0019, 0.0019, 0.0020, 0.0018, 0.0019, 0.0019, 0.0018, 0.0018]
+                ),
+                torch.tensor([0, 0, 0, 0, 0, 0, 0, 0]),
+                0,
+                -128,
+                127,
+                torch.int8,
+            ),
+            # TODO: embedding
         }
-        self._test_quantizer(
-            m_eager,
-            example_inputs,
-            composed_quantizer,
-            node_occurrence,
-            [],
-            True,
-            qconfig_mapping,
-        )
+        self._check_quantization_with_ref(m, ref_q)
 
     def test_disallow_eval_train(self) -> None:
         m = TestHelperModules.ConvWithBNRelu(relu=True)
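Note: the new `ref_q` dictionary maps node names in the converted graph to the expected positional arguments of the corresponding quantize/dequantize calls, in the order (input, scale, zero_point, quant_min, quant_max, dtype); `None` marks slots that are not compared. As a rough illustration only (not the repository's `_check_quantization_with_ref` helper, whose matching logic is not shown in this diff), a comparison of that shape could be written as:

import torch
import torch.fx


def check_against_ref(model: torch.fx.GraphModule, ref: dict) -> None:
    # Walk the converted graph and compare each referenced q/dq node's args
    # against the reference tuple. Assumes per-channel scales/zero-points are
    # stored as flat get_attr buffers on the module (an assumption, not taken
    # from this commit).
    matches = 0
    for node in model.graph.nodes:
        expected = ref.get(node.name)
        if expected is None:
            continue
        for actual, want in zip(node.args, expected):
            if want is None:
                continue  # slot intentionally left unchecked (e.g. the producing node)
            if isinstance(actual, torch.fx.Node) and actual.op == "get_attr":
                actual = getattr(model, actual.target)
            if isinstance(want, torch.Tensor):
                assert torch.allclose(torch.as_tensor(actual), want, atol=1e-4)
            elif isinstance(want, float):
                assert abs(actual - want) < 1e-6
            else:
                assert actual == want  # zero_point, quant_min/max, dtype
        matches += 1
    assert matches == len(ref)  # every reference entry must appear in the graph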
@@ -272,7 +297,7 @@ def test_disallow_eval_train(self) -> None:
 
         # After prepare: still not OK
         quantizer = OpenVINOQuantizer()
-        m = prepare_qat_pt2e(m, quantizer)  # pyre-ignore[6]
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):
@@ -308,11 +333,9 @@ class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.bn = torch.nn.BatchNorm2d(3)
-                self.dropout = torch.nn.Dropout(0.5)
 
             def forward(self, x):
                 x = self.bn(x)
-                x = self.dropout(x)
                 return x
 
         m = M().train()
@@ -324,8 +347,6 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
             bn_op = bn_train_op if train else bn_eval_op
             bn_node = self._get_node(m, bn_op)
             self.assertTrue(bn_node is not None)
-            dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
-            self.assertEqual(dropout_node.args[2], train)
 
         # Before wrapping: this is not OK
         with self.assertRaises(NotImplementedError):
@@ -341,8 +362,8 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
         _assert_ops_are_correct(m, train=True)  # pyre-ignore[6]
 
         # After prepare but before wrapping: this is not OK
-        quantizer = XNNPACKQuantizer()
-        m = prepare_qat_pt2e(m, quantizer)  # pyre-ignore[6]
+        quantizer = OpenVINOQuantizer()
+        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):
@@ -677,142 +698,6 @@ def _check_quantization_with_ref(self, model: torch.fx.GraphModule, ref: Dict):
 
         assert len(ref) == matches
 
-    def _get_backend_config(self):
-        def _get_linear_configs():
-            observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
-            dtype_configs = [
-                DTypeConfig(
-                    input_dtype=torch.quint8,
-                    output_dtype=torch.float,
-                    weight_dtype=torch.qint8,
-                    bias_dtype=torch.float,
-                )
-            ]
-            linear_configs: list[BackendPatternConfig] = []
-            # linear module
-            linear_configs.append(
-                BackendPatternConfig(torch.nn.Linear)
-                .set_observation_type(observation_type)  # noqa: E131
-                .set_dtype_configs(dtype_configs)
-                .set_root_module(torch.nn.Linear)
-                .set_reference_quantized_module(nnqr.Linear)
-            )
-            # functional linear
-            linear_configs.append(
-                BackendPatternConfig(torch.nn.functional.linear)
-                .set_observation_type(observation_type)  # noqa: E131
-                .set_dtype_configs(dtype_configs)
-                ._set_input_type_to_index({"weight": 1, "bias": 2})
-            )
-            return linear_configs
-
-        def _get_conv_configs():
-            pass
-
-        return BackendConfig("OpenVINO").set_backend_pattern_configs(
-            _get_linear_configs()
-        )
-        # .set_backend_pattern_configs(_get_conv_configs())
-
-    def _test_quantizer(
-        self,
-        model,
-        example_inputs,
-        quantizer,
-        expected_node_occurrence,
-        expected_node_list=None,
-        check_against_fx_quant=False,
-        fx_qconfig_mapping=None,
-        export_with_dynamic_shape=False,
-        is_qat=False,
-        is_debug_mode=False,
-        training_ir_node_occurrence=None,
-    ):
-        # resetting dynamo cache
-        torch._dynamo.reset()
-        m_eager = model.eval()
-
-        # program capture
-        m = copy.deepcopy(m_eager)
-        dynamic_shapes = tuple(
-            {0: torch.export.Dim("dim")} if i == 0 else None
-            for i in range(len(example_inputs))
-        )
-        m = export_for_training(
-            m,
-            example_inputs,
-            dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
-        ).module()
-
-        if is_qat:
-            m = prepare_qat_pt2e(m, quantizer)
-        else:
-            m = prepare_pt2e(m, quantizer)
-        if is_debug_mode:
-            print("prepared model:", m)
-        # Calibrate
-        m(*example_inputs)
-        m = convert_pt2e(m)
-        if is_debug_mode:
-            print("quantized model", m)
-
-        pt2_quant_output = m(*example_inputs)
-        node_occurrence = {
-            ns.call_function(k): v for k, v in expected_node_occurrence.items()
-        }
-        if expected_node_list is None:
-            expected_node_list = []
-        node_list = [ns.call_function(n) for n in expected_node_list]
-        self.checkGraphModuleNodes(
-            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
-        )
-        if check_against_fx_quant:
-            qconfig_mapping = fx_qconfig_mapping
-            backend_config = self._get_backend_config()
-            m_copy = copy.deepcopy(m_eager)
-            m_fx = prepare_fx(
-                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
-            )
-            m_fx(*example_inputs)
-            m_fx = _convert_to_reference_decomposed_fx(
-                m_fx, backend_config=backend_config
-            )
-            m_fx = export_for_training(
-                m_fx,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
-            ).module()
-            node_occurrence = {}
-            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
-                if k in expected_node_occurrence:
-                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
-            if training_ir_node_occurrence is not None:
-                node_occurrence = {
-                    ns.call_function(k): v
-                    for k, v in training_ir_node_occurrence.items()
-                }
-            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
-            fx_quant_output = m_fx(*example_inputs)
-            self.assertEqual(fx_quant_output, pt2_quant_output)
-        return m
-        # activation_observer = observer.HistogramObserver
-        default_qconfig = QConfig(
-            activation=activation_observer, weight=weight_observer
-        )
-        qconfig_mapping = QConfigMapping()
-        qconfig_mapping.set_global(QConfig(activation=None, weight=None))
-        qconfig_mapping.set_object_type(torch.nn.Linear, default_qconfig)
-        self._quantize()
-        self._test_quantizer(
-            m,
-            example_inputs,
-            quantizer,
-            node_occurrence,
-            check_against_fx_quant=True,
-            fx_qconfig_mapping=qconfig_mapping,
-        )
-        # self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, )
-
     def test_save_load(self) -> None:
         """Test save/load a quantized model"""
         m = self._get_linear()
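With the FX cross-check and its `_get_backend_config`/`_test_quantizer` helpers removed, the remaining tests rely on the plain PT2E flow (export, prepare, calibrate, convert) plus the reference check above. A minimal sketch of that flow, assuming `self._quantize` wraps roughly the same steps as the deleted `_test_quantizer` body (the OpenVINOQuantizer import path below is also an assumption):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export_for_training

from executorch.backends.openvino.quantizer import OpenVINOQuantizer  # import path assumed


def quantize_static_pt2e(model: torch.nn.Module, example_inputs: tuple) -> torch.fx.GraphModule:
    # Post-training static quantization via the PT2E flow used throughout this test file.
    torch._dynamo.reset()  # reset the dynamo cache, as the deleted helper did
    m = export_for_training(model.eval(), example_inputs).module()  # program capture
    m = prepare_pt2e(m, OpenVINOQuantizer())  # insert observers
    m(*example_inputs)  # calibrate on sample data
    return convert_pt2e(m)  # fold observers into quantize/dequantize ops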
