Skip to content

Commit 43c3580

Browse files
authored
Register pt2e static quantization (#1761)
This PR 1) aligns the `W8A8StaticQuantizer` with Quantizer, 2) adds an export API, 3) maps the StaticQuantConfig to X86InductorQuantizer's config. --------- Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 4e31b4d commit 43c3580

File tree

11 files changed

+359
-70
lines changed

11 files changed

+359
-70
lines changed

neural_compressor/torch/algorithms/pt2e_quant/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
16+
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer

neural_compressor/torch/algorithms/pt2e_quant/core.py

+16-59
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Note - The `W8A8StaticQuantizer` is aligned with with the pytorch-labs/ao's unified quantization API.
16-
# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
1715
# Some code snippets are taken from the X86InductorQuantizer tutorial.
1816
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html
1917

2018

21-
from typing import Any, Dict, Optional, Tuple, Union
19+
from typing import Any
2220

2321
import torch
2422
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
@@ -28,71 +26,30 @@
2826
from torch.fx.graph_module import GraphModule
2927

3028
from neural_compressor.common.utils import logger
31-
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
29+
from neural_compressor.torch.algorithms.base_algorithm import Quantizer
30+
from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config
3231

3332

34-
class W8A8StaticQuantizer:
33+
class W8A8StaticQuantizer(Quantizer):
3534

3635
@staticmethod
37-
def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
38-
# TODO: add the logic to update the quantizer based on the quant_config
39-
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
36+
def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
37+
if not quant_config:
38+
quantizer = X86InductorQuantizer()
39+
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
40+
else:
41+
quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
4042
return quantizer
4143

42-
@staticmethod
43-
def export_model(
44-
model,
45-
example_inputs: Tuple[Any],
46-
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
47-
) -> Optional[GraphModule]:
48-
exported_model = None
49-
try:
50-
with torch.no_grad():
51-
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
52-
# updated to use the official `torch.export` API when that is ready.
53-
cur_version = get_torch_version()
54-
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
55-
logger.warning(
56-
(
57-
"`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
58-
"If you want to use `dynamic_shapes` to export model, "
59-
"please upgrade to 2.3.0 or later."
60-
),
61-
cur_version,
62-
)
63-
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
64-
else: # pragma: no cover
65-
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
66-
model, args=example_inputs, dynamic_shapes=dynamic_shapes
67-
)
68-
except Exception as e:
69-
logger.error(f"Failed to export the model: {e}")
70-
return exported_model
71-
72-
def prepare(
73-
self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
74-
) -> GraphModule:
44+
def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
7545
"""Prepare the model for calibration.
7646
77-
There are two steps in this process:
78-
1) export the eager model into model with Aten IR.
79-
2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
47+
Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
8048
"""
81-
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
82-
# Set the model to eval mode
83-
model = model.eval()
84-
85-
# 1) Capture the FX Graph to be quantized
86-
dynamic_shapes = kwargs.get("dynamic_shapes", None)
87-
exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
88-
logger.info("Exported the model to Aten IR successfully.")
89-
if exported_model is None:
90-
return
91-
92-
# 2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
93-
quantizer = X86InductorQuantizer()
94-
quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
95-
prepared_model = prepare_pt2e(exported_model, quantizer)
49+
quant_config = self.quant_config
50+
assert model._exported, "The model should be exported before preparing it for calibration."
51+
quantizer = self.update_quantizer_based_on_quant_config(quant_config)
52+
prepared_model = prepare_pt2e(model, quantizer)
9653
return prepared_model
9754

9855
def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from neural_compressor.torch.export._export import export_model_for_pt2e_quant, export
+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Optional, Tuple, Union
16+
17+
import torch
18+
from torch._export import capture_pre_autograd_graph
19+
from torch.fx.graph_module import GraphModule
20+
21+
from neural_compressor.common.utils import logger
22+
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported
23+
24+
__all__ = ["export", "export_model_for_pt2e_quant"]
25+
26+
27+
def export_model_for_pt2e_quant(
28+
model: torch.nn.Module,
29+
example_inputs: Tuple[Any],
30+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
31+
) -> Optional[GraphModule]:
32+
"""Export the eager model into model with Aten IR."""
33+
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
34+
# Set the model to eval mode
35+
model = model.eval()
36+
exported_model = None
37+
try:
38+
with torch.no_grad():
39+
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
40+
# updated to use the official `torch.export` API when that is ready.
41+
cur_version = get_torch_version()
42+
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
43+
logger.warning(
44+
(
45+
"`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
46+
"If you want to use `dynamic_shapes` to export model, "
47+
"please upgrade to 2.3.0 or later."
48+
),
49+
cur_version,
50+
)
51+
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
52+
else:
53+
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
54+
model, args=example_inputs, dynamic_shapes=dynamic_shapes
55+
)
56+
exported_model._exported = True
57+
logger.info("Exported the model to Aten IR successfully.")
58+
except Exception as e:
59+
logger.error(f"Failed to export the model: {e}")
60+
61+
return exported_model
62+
63+
64+
def export(
65+
model: torch.nn.Module,
66+
example_inputs: Tuple[Any],
67+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
68+
) -> Optional[GraphModule]:
69+
if not is_ipex_imported():
70+
return export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes)
71+
else:
72+
# TODO, add `export` for ipex
73+
pass

neural_compressor/torch/quantization/algorithm_entry.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
StaticQuantConfig,
3131
TEQConfig,
3232
)
33-
from neural_compressor.torch.utils import Mode, logger, register_algo
33+
from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
34+
from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
3435

3536

3637
###################### RTN Algo Entry ##################################
@@ -147,6 +148,8 @@ def static_quant_entry(
147148
*args,
148149
**kwargs,
149150
) -> torch.nn.Module:
151+
if not is_ipex_imported():
152+
return pt2e_static_quant_entry(model, configs_mapping, mode, *args, **kwargs)
150153
logger.info("Quantize model with the static quant algorithm.")
151154
from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
152155

@@ -191,6 +194,25 @@ def static_quant_entry(
191194
return model
192195

193196

197+
###################### PT2E Static Quant Algo Entry ##################################
198+
@register_algo(name=PT2E_STATIC_QUANT)
199+
@torch.no_grad()
200+
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
201+
logger.info("Quantize model with the PT2E static quant algorithm.")
202+
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
203+
204+
run_fn = kwargs.get("run_fn", None)
205+
example_inputs = kwargs.get("example_inputs", None)
206+
inplace = kwargs.get("inplace", True)
207+
for _, quant_config in configs_mapping.items():
208+
if quant_config.name == STATIC_QUANT:
209+
w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
210+
model = w8a8_quantizer.execute(
211+
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
212+
)
213+
return model
214+
215+
194216
###################### Smooth Quant Algo Entry ##################################
195217
@register_algo(name=SMOOTH_QUANT)
196218
@torch.no_grad()

neural_compressor/torch/quantization/config.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# pylint:disable=import-error
1818

1919
from collections import OrderedDict
20-
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
20+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
2121

2222
import torch
2323

@@ -40,7 +40,7 @@
4040
STATIC_QUANT,
4141
TEQ,
4242
)
43-
from neural_compressor.torch.utils import is_hpex_available, logger
43+
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, logger
4444
from neural_compressor.torch.utils.constants import (
4545
PRIORITY_AUTOROUND,
4646
PRIORITY_AWQ,
@@ -820,19 +820,31 @@ def __init__(
820820
@classmethod
821821
def register_supported_configs(cls) -> List[OperatorConfig]:
822822
supported_configs = []
823-
# TODO(Yi)
824823
linear_static_config = StaticQuantConfig()
825824
operators = [torch.nn.Linear]
826825
supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
827826
cls.supported_configs = supported_configs
828827

829828
@staticmethod
830-
def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
829+
def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
831830
from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively
832831

833832
_, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
834833
return model_info
835834

835+
@staticmethod
836+
def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
837+
if is_ipex_imported():
838+
return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
839+
840+
def to_config_mapping(
841+
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
842+
) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
843+
if is_ipex_imported():
844+
return super().to_config_mapping(config_list, model_info)
845+
config_mapping = OrderedDict({self.name: self})
846+
return config_mapping
847+
836848
@classmethod
837849
def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
838850
return StaticQuantConfig(act_sym=[True, False], act_algo=["kl", "minmax"])
@@ -844,6 +856,8 @@ def get_default_static_config() -> StaticQuantConfig:
844856
Returns:
845857
the default static quant config.
846858
"""
859+
if not is_ipex_imported():
860+
return StaticQuantConfig(w_granularity="per_tensor")
847861
return StaticQuantConfig()
848862

849863

neural_compressor/torch/utils/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,6 @@
4949
PRIORITY_AWQ = 70
5050
PRIORITY_TEQ = 60
5151
PRIORITY_AUTOROUND = 50
52+
53+
54+
PT2E_STATIC_QUANT = "pt2e_static_quant"

neural_compressor/torch/utils/environ.py

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import sys
17+
1618
import torch
1719
from packaging.version import Version
1820

@@ -65,6 +67,13 @@ def get_torch_version():
6567
return version
6668

6769

70+
def is_ipex_imported() -> bool:
71+
for name, _ in sys.modules.items():
72+
if name == "intel_extension_for_pytorch":
73+
return True
74+
return False
75+
76+
6877
def get_device(device_name="auto"):
6978
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
7079

0 commit comments

Comments
 (0)