Skip to content

Commit cf36f3f

Browse files
[Conformance][TorchFX] GPU quantization support (#3010)
### Changes

* --validate-in-backend CLI option is added
* CUDA_FX_TORCH backend is added to conformance test
* FXSQMultiply is updated to work on both CPU and GPU

### Tests

Local run:

CLI: `python -m pytest test_quantize_conformance.py -k CUDA_FX --data path/to/imagenet`

| Model | Backend | Metric name | Metric value | Metric diff | Num FQ | Num int4 | Num int8 | Compr. time | Total time | RAM MiB | Status | Build url |
|-----------------------------------|---------------|-------------|--------------|-------------|--------|----------|----------|-------------|------------|---------|--------|-----------|
| torchvision/resnet18 | CUDA_FX_TORCH | Acc@1 | 0.6942 | -0.0036 | 30 | 0 | 21 | 0:00:02 | 0:04:14 | 1560 | | |
| torchvision/swin_v2_s | CUDA_FX_TORCH | Acc@1 | 0.83572 | -0.0014 | 149 | 0 | 101 | 0:00:55 | 0:17:24 | 3161 | | |
| torchvision/vit_b_16 | CUDA_FX_TORCH | Acc@1 | 0.80962 | -0.00108 | 62 | 0 | 50 | 0:00:19 | 0:13:39 | 2876 | | |
| torchvision/mobilenet_v3_small_BC | CUDA_FX_TORCH | Acc@1 | 0.66642 | -0.01018 | 61 | 0 | 36 | 0:00:05 | 0:04:09 | 1653 | | |
1 parent 1fe479c commit cf36f3f

File tree

12 files changed

+208
-46
lines changed

12 files changed

+208
-46
lines changed

nncf/experimental/torch/fx/constant_folding.py

+27-24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import torch.fx
1616
import torch.utils._pytree as pytree
1717

18+
from nncf.torch.utils import get_model_device
19+
1820
aten = torch.ops.aten
1921

2022

@@ -246,28 +248,29 @@ def constant_fold(
246248
:param constraint_fn: Constraint function which takes a node and returns the constraint:
247249
should the node be constant folded or not.
248250
"""
249-
with torch.no_grad():
250-
with torch.utils._python_dispatch._disable_current_modes():
251-
cf = ConstantFolder(gm)
252-
cf.run()
251+
with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
252+
cf = ConstantFolder(gm)
253+
cf.run()
253254

254-
for node, constant in cf.node_replacements.items():
255-
if constraint_fn is not None and not constraint_fn(node):
256-
continue
257-
_replace_node_with_constant(gm, node, constant)
258-
259-
erased_params = []
260-
for node in gm.graph.find_nodes(op="get_attr"):
261-
if len(node.users) == 0:
262-
if hasattr(gm, node.target):
263-
delattr(gm, node.target)
264-
erased_params.append(node)
265-
266-
for node in erased_params:
267-
gm.graph.erase_node(node)
268-
269-
# Custom _is_impure function allows to eliminate all layers with zero
270-
# users including inplace ops like relu_ besides output and placeholders.
271-
gm.graph.eliminate_dead_code(_is_impure)
272-
gm.graph.lint()
273-
gm.recompile()
255+
device = get_model_device(gm)
256+
for node, constant in cf.node_replacements.items():
257+
if constraint_fn is not None and not constraint_fn(node):
258+
continue
259+
constant = constant.to(device)
260+
_replace_node_with_constant(gm, node, constant)
261+
262+
erased_params = []
263+
for node in gm.graph.find_nodes(op="get_attr"):
264+
if len(node.users) == 0:
265+
if hasattr(gm, node.target):
266+
delattr(gm, node.target)
267+
erased_params.append(node)
268+
269+
for node in erased_params:
270+
gm.graph.erase_node(node)
271+
272+
# Custom _is_impure function allows to eliminate all layers with zero
273+
# users including inplace ops like relu_ besides output and placeholders.
274+
gm.graph.eliminate_dead_code(_is_impure)
275+
gm.graph.lint()
276+
gm.recompile()

nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
class FXSQMultiply(torch.nn.Module):
4343
def __init__(self, scale: torch.Tensor):
4444
super().__init__()
45-
self._scale_value = scale
45+
self.register_buffer("_scale_value", scale)
46+
self._scale_value: torch.Tensor
4647

4748
def forward(self, x: torch.Tensor) -> torch.Tensor:
4849
return torch.mul(x, self._scale_value)

tests/post_training/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def pytest_addoption(parser):
1919
parser.addoption("--fp32", action="store_true", help="Test original model")
2020
parser.addoption("--cuda", action="store_true", help="Enable CUDA_TORCH backend")
2121
parser.addoption("--benchmark", action="store_true", help="Run benchmark_app")
22+
parser.addoption(
23+
"--torch-compile-validation",
24+
action="store_true",
25+
help='Validate TorchFX quantized models via torch.compile(..., backend="openvino")',
26+
)
2227
parser.addoption(
2328
"--extra-columns",
2429
action="store_true",

tests/post_training/data/ptq_reference_data.yaml

+8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ torchvision/resnet18_backend_CUDA_TORCH:
3838
metric_value: 0.69152
3939
torchvision/resnet18_backend_FX_TORCH:
4040
metric_value: 0.6946
41+
torchvision/resnet18_backend_CUDA_FX_TORCH:
42+
metric_value: 0.6946
4143
torchvision/mobilenet_v3_small_BC_backend_FP32:
4244
metric_value: 0.6766
4345
torchvision/mobilenet_v3_small_BC_backend_OV:
@@ -46,18 +48,24 @@ torchvision/mobilenet_v3_small_BC_backend_ONNX:
4648
metric_value: 0.6679
4749
torchvision/mobilenet_v3_small_BC_backend_FX_TORCH:
4850
metric_value: 0.6679
51+
torchvision/mobilenet_v3_small_BC_backend_CUDA_FX_TORCH:
52+
metric_value: 0.6664
4953
torchvision/vit_b_16_backend_FP32:
5054
metric_value: 0.8107
5155
torchvision/vit_b_16_backend_OV:
5256
metric_value: 0.80948
5357
torchvision/vit_b_16_backend_FX_TORCH:
5458
metric_value: 0.80922
59+
torchvision/vit_b_16_backend_CUDA_FX_TORCH:
60+
metric_value: 0.80922
5561
torchvision/swin_v2_s_backend_FP32:
5662
metric_value: 0.83712
5763
torchvision/swin_v2_s_backend_OV:
5864
metric_value: 0.83638
5965
torchvision/swin_v2_s_backend_FX_TORCH:
6066
metric_value: 0.8360
67+
torchvision/swin_v2_s_backend_CUDA_FX_TORCH:
68+
metric_value: 0.8360
6169
timm/crossvit_9_240_backend_CUDA_TORCH:
6270
metric_value: 0.7275
6371
timm/crossvit_9_240_backend_FP32:

tests/post_training/model_scope.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@
8787
"model_id": "resnet18",
8888
"pipeline_cls": ImageClassificationTorchvision,
8989
"compression_params": {},
90-
"backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
90+
"backends": [
91+
BackendType.FX_TORCH,
92+
BackendType.CUDA_FX_TORCH,
93+
BackendType.TORCH,
94+
BackendType.CUDA_TORCH,
95+
BackendType.OV,
96+
BackendType.ONNX,
97+
],
9198
"batch_size": 128,
9299
},
93100
{
@@ -98,7 +105,7 @@
98105
"fast_bias_correction": False,
99106
"preset": QuantizationPreset.MIXED,
100107
},
101-
"backends": [BackendType.FX_TORCH, BackendType.OV, BackendType.ONNX],
108+
"backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV, BackendType.ONNX],
102109
"batch_size": 128,
103110
},
104111
{
@@ -109,7 +116,7 @@
109116
"model_type": ModelType.TRANSFORMER,
110117
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.15),
111118
},
112-
"backends": [BackendType.FX_TORCH, BackendType.OV],
119+
"backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV],
113120
"batch_size": 1,
114121
},
115122
{
@@ -120,7 +127,7 @@
120127
"model_type": ModelType.TRANSFORMER,
121128
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
122129
},
123-
"backends": [BackendType.FX_TORCH, BackendType.OV],
130+
"backends": [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH, BackendType.OV],
124131
"batch_size": 1,
125132
},
126133
# Timm models

tests/post_training/pipelines/base.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class BackendType(Enum):
5555
TORCH = "TORCH"
5656
CUDA_TORCH = "CUDA_TORCH"
5757
FX_TORCH = "FX_TORCH"
58+
CUDA_FX_TORCH = "CUDA_FX_TORCH"
5859
ONNX = "ONNX"
5960
OV = "OV"
6061
OPTIMUM = "OPTIMUM"
@@ -63,6 +64,7 @@ class BackendType(Enum):
6364
NNCF_PTQ_BACKENDS = [BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.ONNX, BackendType.OV]
6465
ALL_PTQ_BACKENDS = NNCF_PTQ_BACKENDS
6566
PT_BACKENDS = [BackendType.TORCH, BackendType.CUDA_TORCH]
67+
FX_BACKENDS = [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH]
6668
OV_BACKENDS = [BackendType.OV, BackendType.OPTIMUM]
6769

6870
LIMIT_LENGTH_OF_STATUS = 120
@@ -222,6 +224,7 @@ def __init__(
222224
reference_data: dict,
223225
no_eval: bool,
224226
run_benchmark_app: bool,
227+
torch_compile_validation: bool = False,
225228
params: dict = None,
226229
batch_size: int = 1,
227230
memory_monitor: bool = False,
@@ -238,6 +241,7 @@ def __init__(
238241
self.memory_monitor = memory_monitor
239242
self.no_eval = no_eval
240243
self.run_benchmark_app = run_benchmark_app
244+
self.torch_compile_validation = torch_compile_validation
241245
self.output_model_dir: Path = self.output_dir / self.reported_name / self.backend.value
242246
self.output_model_dir.mkdir(parents=True, exist_ok=True)
243247
self.model_name = f"{self.reported_name}_{self.backend.value}"
@@ -436,12 +440,17 @@ def save_compressed_model(self) -> None:
436440
)
437441
self.path_compressed_ir = self.output_model_dir / "model.xml"
438442
ov.serialize(ov_model, self.path_compressed_ir)
439-
elif self.backend == BackendType.FX_TORCH:
440-
exported_model = torch.export.export(self.compressed_model, (self.dummy_tensor,))
443+
elif self.backend in FX_BACKENDS:
444+
exported_model = torch.export.export(self.compressed_model.cpu(), (self.dummy_tensor.cpu(),))
441445
ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor.cpu(), input=self.input_size)
442446
ov_model.reshape(self.input_size)
443447
self.path_compressed_ir = self.output_model_dir / "model.xml"
444448
ov.serialize(ov_model, self.path_compressed_ir)
449+
450+
if BackendType.CUDA_FX_TORCH:
451+
self.model = self.model.cuda()
452+
self.dummy_tensor = self.dummy_tensor.cuda()
453+
445454
elif self.backend == BackendType.ONNX:
446455
onnx_path = self.output_model_dir / "model.onnx"
447456
onnx.save(self.compressed_model, str(onnx_path))

tests/post_training/pipelines/image_classification_base.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import nncf
2323
from nncf.common.logging.track_progress import track
2424
from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
25+
from tests.post_training.pipelines.base import FX_BACKENDS
2526
from tests.post_training.pipelines.base import ErrorReport
2627
from tests.post_training.pipelines.base import PTQTestPipeline
2728

@@ -35,18 +36,15 @@ def prepare_calibration_dataset(self):
3536

3637
self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())
3738

38-
def _validate(self) -> List[ErrorReport]:
39-
val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
40-
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)
41-
42-
dataset_size = len(val_loader)
43-
44-
# Initialize result tensors for async inference support.
45-
predictions = np.zeros(dataset_size)
46-
references = -1 * np.ones(dataset_size)
39+
def _validate_ov(
40+
self,
41+
val_loader: torch.utils.data.DataLoader,
42+
predictions: np.ndarray,
43+
references: np.ndarray,
44+
dataset_size: int,
45+
):
4746

4847
core = ov.Core()
49-
5048
if os.environ.get("INFERENCE_NUM_THREADS"):
5149
# Set CPU_THREADS_NUM for OpenVINO inference
5250
inference_num_threads = os.environ.get("INFERENCE_NUM_THREADS")
@@ -75,6 +73,34 @@ def process_result(request, userdata):
7573
references[i] = target
7674

7775
infer_queue.wait_all()
76+
return predictions, references
77+
78+
def _validate_torch_compile(
79+
self, val_loader: torch.utils.data.DataLoader, predictions: np.ndarray, references: np.ndarray
80+
):
81+
compiled_model = torch.compile(self.compressed_model.cpu(), backend="openvino")
82+
for i, (images, target) in enumerate(val_loader):
83+
# W/A for memory leaks when using torch DataLoader and OpenVINO
84+
pred = compiled_model(images)
85+
pred = torch.argmax(pred, dim=1)
86+
predictions[i] = pred.numpy()
87+
references[i] = target.numpy()
88+
return predictions, references
89+
90+
def _validate(self) -> List[ErrorReport]:
91+
val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
92+
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False)
93+
94+
dataset_size = len(val_loader)
95+
96+
# Initialize result tensors for async inference support.
97+
predictions = np.zeros(dataset_size)
98+
references = -1 * np.ones(dataset_size)
99+
100+
if self.backend in FX_BACKENDS and self.torch_compile_validation:
101+
predictions, references = self._validate_torch_compile(val_loader, predictions, references)
102+
else:
103+
predictions, references = self._validate_ov(val_loader, predictions, references, dataset_size)
78104

79105
acc_top1 = accuracy_score(predictions, references)
80106

tests/post_training/pipelines/image_classification_torchvision.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torchvision import models
2020

2121
from nncf.torch import disable_patching
22+
from tests.post_training.pipelines.base import FX_BACKENDS
2223
from tests.post_training.pipelines.base import PT_BACKENDS
2324
from tests.post_training.pipelines.base import BackendType
2425
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase
@@ -74,9 +75,12 @@ def prepare_model(self) -> None:
7475
if self.batch_size > 1: # Dynamic batch_size shape export
7576
self.input_size[0] = -1
7677

77-
if self.backend == BackendType.FX_TORCH:
78+
if self.backend in FX_BACKENDS:
7879
with torch.no_grad():
7980
with disable_patching():
81+
if self.backend is BackendType.CUDA_FX_TORCH:
82+
model = model.cuda()
83+
self.dummy_tensor = self.dummy_tensor.cuda()
8084
self.model = self.model_params.export_fn(model, (self.dummy_tensor,))
8185

8286
elif self.backend in PT_BACKENDS:
@@ -120,20 +124,26 @@ def _dump_model_fp32(self) -> None:
120124
)
121125
ov.serialize(ov_model, self.fp32_model_dir / "model_fp32.xml")
122126

123-
if self.backend == BackendType.FX_TORCH:
124-
exported_model = torch.export.export(self.model, (self.dummy_tensor,))
127+
if self.backend in FX_BACKENDS:
128+
exported_model = torch.export.export(self.model.cpu(), (self.dummy_tensor.cpu(),))
125129
ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
126130
ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml")
127131

132+
if self.backend is BackendType.CUDA_FX_TORCH:
133+
self.model = self.model.cuda()
134+
self.dummy_tensor = self.dummy_tensor.cuda()
135+
128136
if self.backend in [BackendType.FP32, BackendType.OV]:
129137
ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")
130138

131139
def prepare_preprocessor(self) -> None:
132140
self.transform = self.model_params.weights.transforms()
133141

134142
def get_transform_calibration_fn(self):
135-
if self.backend in [BackendType.FX_TORCH] + PT_BACKENDS:
136-
device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")
143+
if self.backend in FX_BACKENDS + PT_BACKENDS:
144+
device = torch.device(
145+
"cuda" if self.backend in [BackendType.CUDA_TORCH, BackendType.CUDA_FX_TORCH] else "cpu"
146+
)
137147

138148
def transform_fn(data_item):
139149
images, _ = data_item

tests/post_training/test_quantize_conformance.py

+7
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def fixture_run_benchmark_app(pytestconfig):
9090
return pytestconfig.getoption("benchmark")
9191

9292

93+
@pytest.fixture(scope="session", name="torch_compile_validation")
94+
def fixture_torch_compile_validation(pytestconfig):
95+
return pytestconfig.getoption("torch_compile_validation")
96+
97+
9398
@pytest.fixture(scope="session", name="extra_columns")
9499
def fixture_extra_columns(pytestconfig):
95100
return pytestconfig.getoption("extra_columns")
@@ -281,6 +286,7 @@ def test_ptq_quantization(
281286
run_torch_cuda_backend: bool,
282287
subset_size: Optional[int],
283288
run_benchmark_app: bool,
289+
torch_compile_validation: bool,
284290
capsys: pytest.CaptureFixture,
285291
extra_columns: bool,
286292
memory_monitor: bool,
@@ -309,6 +315,7 @@ def test_ptq_quantization(
309315
"data_dir": data_dir,
310316
"no_eval": no_eval,
311317
"run_benchmark_app": run_benchmark_app,
318+
"torch_compile_validation": torch_compile_validation,
312319
"batch_size": batch_size,
313320
"memory_monitor": memory_monitor,
314321
}

0 commit comments

Comments
 (0)