
Commit c418a9a

ZhiweiYan-96 authored and pytorchmergebot committed
[Intel GPU] XPUInductorQuantizer for XPU int8 recipe customization (pytorch#139578)
# Motivation

This PR adds `XPUInductorQuantizer`, which defines the int8 quantization recipe for the XPU backend.

# Detailed

`XPUInductorQuantizer` is derived from `X86InductorQuantizer`, as both quantizers take advantage of the highly optimized operators in the oneDNN library (qconv, qlinear, and qconv/qlinear fusion). Since we share the same recipe as `X86InductorQuantizer`, we also share the same `annotate_xxx` methods; ideally, `XPUInductorQuantizer` would have an empty class body, with every implementation inherited from the base class. In this PR, we override the `annotate_xxx` methods for operators that the XPU backend has NOT yet implemented. Any operator the XPU backend does not implement falls back to its fp32 implementation, because its node in the graph is wrapped in a `dq-op-q` pair. This provides good out-of-the-box usability for the XPU backend. The implemented operators, on the other hand, use the `annotate_op` methods from the base class and can be lowered successfully.

Pull Request resolved: pytorch#139578
Approved by: https://github.com/EikanWang, https://github.com/leslie-fang-intel, https://github.com/CuiYifeng, https://github.com/jerryzh168
ghstack dependencies: pytorch#133080
1 parent 5318bf8 commit c418a9a
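
For readers landing on this commit, the sketch below (not part of the diff) shows how `XPUInductorQuantizer` is intended to slot into the PT2E quantization flow described above. The toy model, the `"xpu"` device placement, and the use of `torch.export.export_for_training` as the capture API are assumptions that may vary by PyTorch version and build.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer

# Hypothetical toy model; any conv/linear-heavy eager model would do.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval().to("xpu")
example_inputs = (torch.randn(1, 3, 32, 32, device="xpu"),)

quantizer = XPUInductorQuantizer()
quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())

# Capture, insert observers, calibrate, then rewrite into a q/dq graph.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

# Lowering through Inductor fuses the supported q/dq patterns into oneDNN
# int8 kernels (qconv/qlinear); operators without an XPU int8 kernel keep
# their dq-op-q surroundings and run in fp32.
optimized = torch.compile(converted)
optimized(*example_inputs)
```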

File tree

5 files changed, +183 -11 lines changed


docs/source/conf.py

+2
@@ -2457,6 +2457,8 @@
     "SharedQuantizationSpec",
     # torch.ao.quantization.quantizer.x86_inductor_quantizer
     "X86InductorQuantizer",
+    # torch.ao.quantization.quantizer.xpu_inductor_quantizer
+    "XPUInductorQuantizer",
     # torch.ao.quantization.quantizer.xnnpack_quantizer
     "XNNPACKQuantizer",
     # torch.ao.quantization.quantizer.xnnpack_quantizer_utils

docs/source/quantization.rst

+1
@@ -1353,6 +1353,7 @@ Please take a look at `Limitations of Symbolic Tracing <https://pytorch.org/docs
 .. py:module:: torch.ao.quantization.quantizer.quantizer
 .. py:module:: torch.ao.quantization.quantizer.utils
 .. py:module:: torch.ao.quantization.quantizer.x86_inductor_quantizer
+.. py:module:: torch.ao.quantization.quantizer.xpu_inductor_quantizer
 .. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer
 .. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer_utils
 .. py:module:: torch.ao.quantization.stubs

torch/_inductor/fx_passes/quantization.py

+12-4
@@ -63,7 +63,7 @@ def _get_pattern_output_dtype(match: Match):
     output_node = pattern_output_nodes[0]
     assert isinstance(output_node, torch.fx.Node)
     output_dtype = output_node.meta["val"].dtype
-    assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
+    assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
     return output_dtype


@@ -335,10 +335,18 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
+        assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
         # Output QParams
-        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
-        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
+        o_inv_scale = (
+            kwargs["o_inv_scale"]
+            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
+            else 1.0
+        )
+        o_zero_point = (
+            kwargs["o_zp"]
+            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
+            else 0
+        )
         assert (
             kwargs["attr"] == "none"
         )  # Expected no post op fused in weight prepack phase
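
A side note on the change above: the XPU recipe emits signed `torch.int8` activations (see the new quantizer file below), whereas the existing x86 path uses `torch.uint8`, so the lowering now accepts both and only pulls real output qparams when the fused op actually produces a quantized tensor; for fp32/bf16 outputs the defaults `1.0`/`0` are effectively pass-through. The standalone snippet below (not from the PR, illustrative values only) sketches the affine requantization that those output scale/zero-point values parameterize; note the fused kernels are passed an inverse scale (`o_inv_scale`), but the underlying mapping is the same.

```python
import torch

def requantize(x_fp32, scale, zero_point, dtype):
    # Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax)
    info = torch.iinfo(dtype)
    q = torch.round(x_fp32 / scale) + zero_point
    return q.clamp(info.min, info.max).to(dtype)

x = torch.randn(4)
print(requantize(x, scale=0.02, zero_point=0, dtype=torch.int8))     # signed int8 output (XPU recipe dtype)
print(requantize(x, scale=0.02, zero_point=128, dtype=torch.uint8))  # unsigned uint8 output (existing x86 path)
```
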
torch/ao/quantization/quantizer/xpu_inductor_quantizer.py

+152
@@ -0,0 +1,152 @@
+# mypy: allow-untyped-defs
+import functools
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+import torch
+from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
+from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
+    _is_any_annotated,
+    FilterFn,
+    int8_in_int8_out_ops,
+    X86InductorQuantizer,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
+from torch.fx import Node
+
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
+__all__ = [
+    "XPUInductorQuantizer",
+    "get_default_xpu_inductor_quantization_config",
+]
+
+
+@functools.lru_cache
+def get_default_xpu_inductor_quantization_config():
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    act_observer_or_fake_quant_ctr = HistogramObserver
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PerChannelMinMaxObserver
+    )
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    bias_quantization_spec = None  # will use placeholder observer by default
+    quantization_config = QuantizationConfig(
+        act_quantization_spec,
+        act_quantization_spec,
+        weight_quantization_spec,
+        bias_quantization_spec,
+        False,
+    )
+    return quantization_config
+
+
+class XPUInductorQuantizer(X86InductorQuantizer):
+    """
+    XPUInductorQuantizer is a class designed to facilitate
+    quantization capability at Intel GPU backend. The class
+    highly reuses the existing implementation of
+    X86InductorQuantizer as both are intended to take advantage
+    of the optimized kernels in oneDNN library.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    """
+    The following annotate_xxx methods override the implementations in the
+    base class, as there is currently no XPU implementation for these
+    operators. We will gradually enable the XPU implementations and remove
+    these overrides. We keep the annotate methods but make the function
+    bodies empty, aiming to let `_generate_qdq_quantized_model` generate
+    qdq around the op so that the graph executes in fp32 for unsupported
+    operators.
+    """
+
+    def _annotate_qat_conv2d_fusion_pattern(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        pass
+
+    def _annotate_conv2d_binary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        pass
+
+    def _annotate_conv2d_binary_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        pass
+
+    def _annotate_linear_fusion_pattern(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        pass
+
+    def _annotate_matmul(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        pass
+
+    def _annotate_maxpool2d(
+        self,
+        node: Node,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> None:
+        """
+        Here we skip the annotate logic for maxpool at XPU backend
+        as the quantized::max_pool2d is only implemented for CPU.
+        """
+        return
+
+    def _annotate_output_for_int8_in_int8_out_pattern(
+        self,
+        node: Node,
+    ) -> None:
+        if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
+            if node.target == torch.ops.aten.max_pool2d.default:
+                return
+            else:
+                input_node = node.all_input_nodes[0]
+                self._annotate_output_share_observer_as_input(input_node, node)
+        return
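
To make the weight recipe above concrete, here is a small sketch (not part of the new file) of roughly what `PerChannelMinMaxObserver` computes for the per-channel symmetric spec: one scale per output channel of a conv weight shaped `(oc, ic, kh, kw)`, a zero-point fixed at 0, and the `[-128, 127]` range. The helper name and the exact scale formula are approximations of the observer's behavior, not the PR's code.

```python
import torch

def per_channel_symmetric_qparams(weight, ch_axis=0, quant_min=-128, quant_max=127):
    # Per-channel amax over every dim except the channel axis.
    dims = [d for d in range(weight.dim()) if d != ch_axis]
    amax = weight.abs().amax(dim=dims)
    # Symmetric scheme: scale spans the signed range, zero-point stays 0.
    scale = amax / (float(quant_max - quant_min) / 2)
    zero_point = torch.zeros_like(scale, dtype=torch.int64)
    return scale, zero_point

w = torch.randn(8, 3, 3, 3)  # conv weight (oc, ic, kh, kw)
scale, zp = per_channel_symmetric_qparams(w)
print(scale.shape, zp.unique())  # one scale per output channel, zero-points all 0
```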

torch/testing/_internal/common_quantization.py

+16-7
@@ -75,7 +75,9 @@
 from typing import Callable, Tuple, Dict, Any, Union, Type, Optional
 import torch._dynamo as torchdynamo
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
 import contextlib

 class NodeSpec:
@@ -2940,13 +2942,20 @@ def _generate_qdq_quantized_model(
     mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
 ):

-    def get_default_quantizer(is_qat, is_dynamic):
-        quantizer = X86InductorQuantizer()
-        quantizer.set_global(
-            xiq.get_default_x86_inductor_quantization_config(
-                is_qat=is_qat, is_dynamic=is_dynamic
+    def get_default_quantizer(is_qat, is_dynamic, inputs):
+        has_xpu = any(isinstance(input, torch.Tensor) and input.device.type == "xpu"
+                      for input in inputs)
+        if has_xpu:
+            quantizer = XPUInductorQuantizer()
+            assert (not is_qat) and (not is_dynamic), "QAT and dynamic quantization is not supported at XPU backend currently"
+            quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())
+        else:
+            quantizer = X86InductorQuantizer()
+            quantizer.set_global(
+                xiq.get_default_x86_inductor_quantization_config(
+                    is_qat=is_qat, is_dynamic=is_dynamic
+                )
             )
-        )
         return quantizer

     maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
@@ -2956,7 +2965,7 @@ def get_default_quantizer(is_qat, is_dynamic):
         inputs,
     ).module()
     quantizer = (
-        quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic)
+        quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic, inputs)
     )
     prepare_model = (
         prepare_qat_pt2e(export_model, quantizer)
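
With the test-helper change above, XPU test cases no longer need to construct a quantizer explicitly; a hypothetical call (the model and shapes are made up) would look like the sketch below, keeping in mind that only static post-training quantization is supported on XPU for now.

```python
import torch
from torch.testing._internal.common_quantization import _generate_qdq_quantized_model

# Inputs on the "xpu" device make the helper pick XPUInductorQuantizer; QAT and
# dynamic quantization would trip the assert added above.
mod = torch.nn.Conv2d(3, 8, kernel_size=3).eval().to("xpu")
inputs = (torch.randn(1, 3, 32, 32, device="xpu"),)
converted = _generate_qdq_quantized_model(mod, inputs)
```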
