|
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Note - The `W8A8StaticQuantizer` is aligned with with the pytorch-labs/ao's unified quantization API.
-# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
 # Some code snippets are taken from the X86InductorQuantizer tutorial.
 # https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html
 
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any
 
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch.fx.graph_module import GraphModule
 
 from neural_compressor.common.utils import logger
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
+from neural_compressor.torch.algorithms.base_algorithm import Quantizer
+from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config
 
 
-class W8A8StaticQuantizer:
+class W8A8StaticQuantizer(Quantizer):
 
     @staticmethod
-    def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
-        # TODO: add the logic to update the quantizer based on the quant_config
-        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+    def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
+        if not quant_config:
+            quantizer = X86InductorQuantizer()
+            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+        else:
+            quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
         return quantizer
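The net effect of this hunk: with no config, the helper now builds a default `X86InductorQuantizer` itself (the old signature required the caller to pass one in); with a config, it delegates to `create_xiq_quantizer_from_pt2e_config`. A minimal sketch of a caller, assuming only the imports shown above; `my_quant_config` in the commented call is a hypothetical placeholder, since no concrete config type appears in this diff:

    # Sketch: resolving a quantizer from an optional quant_config.
    # No argument -> X86InductorQuantizer carrying the default X86 Inductor config.
    quantizer = W8A8StaticQuantizer.update_quantizer_based_on_quant_config()

    # With a pt2e-style config object it would instead return
    # create_xiq_quantizer_from_pt2e_config(my_quant_config):
    # quantizer = W8A8StaticQuantizer.update_quantizer_based_on_quant_config(my_quant_config)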
|
 
-    @staticmethod
-    def export_model(
-        model,
-        example_inputs: Tuple[Any],
-        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-    ) -> Optional[GraphModule]:
-        exported_model = None
-        try:
-            with torch.no_grad():
-                # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
-                # updated to use the official `torch.export` API when that is ready.
-                cur_version = get_torch_version()
-                if cur_version <= TORCH_VERSION_2_2_2:  # pragma: no cover
-                    logger.warning(
-                        (
-                            "`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
-                            "If you want to use `dynamic_shapes` to export model, "
-                            "please upgrade to 2.3.0 or later."
-                        ),
-                        cur_version,
-                    )
-                    exported_model = capture_pre_autograd_graph(model, args=example_inputs)
-                else:  # pragma: no cover
-                    exported_model = capture_pre_autograd_graph(  # pylint: disable=E1123
-                        model, args=example_inputs, dynamic_shapes=dynamic_shapes
-                    )
-        except Exception as e:
-            logger.error(f"Failed to export the model: {e}")
-        return exported_model
-
-    def prepare(
-        self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
-    ) -> GraphModule:
+    def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
         """Prepare the model for calibration.
 
-        There are two steps in this process:
-            1) export the eager model into model with Aten IR.
-            2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
+        Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
         """
-        assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
-        # Set the model to eval mode
-        model = model.eval()
-
-        # 1) Capture the FX Graph to be quantized
-        dynamic_shapes = kwargs.get("dynamic_shapes", None)
-        exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
-        logger.info("Exported the model to Aten IR successfully.")
-        if exported_model is None:
-            return
-
-        # 2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
-        quantizer = X86InductorQuantizer()
-        quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
-        prepared_model = prepare_pt2e(exported_model, quantizer)
+        quant_config = self.quant_config
+        assert model._exported, "The model should be exported before preparing it for calibration."
+        quantizer = self.update_quantizer_based_on_quant_config(quant_config)
+        prepared_model = prepare_pt2e(model, quantizer)
         return prepared_model
 
     def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
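With `export_model` deleted, exporting to Aten IR becomes a precondition of `prepare` rather than a step inside it, which is what the new `model._exported` assertion enforces. A rough end-to-end sketch of the resulting calling convention, under stated assumptions: the `Quantizer` base class stores `quant_config` on the instance via its constructor (implied by `self.quant_config` above but not shown in this diff), the export step still uses the `capture_pre_autograd_graph` API from the deleted code, and the `_exported` flag is normally set by whatever export utility replaces `export_model` (set by hand here):

    import torch
    from torch._export import capture_pre_autograd_graph  # export API the deleted code used

    # 1) Export the eager model to Aten IR up front; this used to happen inside prepare().
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(2, 8),)
    with torch.no_grad():
        exported_model = capture_pre_autograd_graph(model, args=example_inputs)
    exported_model._exported = True  # assumption: normally set by the export utility

    # 2) Prepare (insert observers), calibrate, convert.
    quantizer = W8A8StaticQuantizer(quant_config=None)  # None -> default X86 Inductor config
    prepared_model = quantizer.prepare(exported_model)
    for _ in range(4):  # toy calibration loop
        prepared_model(torch.randn(2, 8))
    converted_model = quantizer.convert(prepared_model)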
|
|