
Commit 47dbe38: strip in one commit
1 parent: 5f4378e

File tree (5 files changed: +285 -3 lines changed)

  nncf/torch/quantization/strip.py         +170 -2
  tests/torch/helpers.py                   +11
  tests/torch/ptq/test_fq_lora.py          +61 -1
  tests/torch/quantization/test_strip.py   +39
  tests/torch/requirements.txt             +4

nncf/torch/quantization/strip.py (+170 -2)
@@ -10,16 +10,37 @@
 # limitations under the License.


+from typing import List
+
 import numpy as np
 import torch
 from torch.quantization.fake_quantize import FakeQuantize

 import nncf
+from nncf.common.graph.transformations.commands import Command
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
+from nncf.experimental.torch2.commands import PT2InsertionCommand
+from nncf.torch.dynamic_graph.scope import Scope
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.model_graph_manager import get_const_node
+from nncf.torch.model_graph_manager import get_module_by_name
+from nncf.torch.model_graph_manager import split_const_name
+from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import AsymmetricQuantizer
 from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import SymmetricLoraQuantizer
 from nncf.torch.quantization.layers import SymmetricQuantizer
+from nncf.torch.quantization.quantize_functions import TuneRange

 SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8]

@@ -171,6 +192,153 @@ def strip_quantized_model(model: NNCFNetwork):
     :param model: Compressed model.
     :return: The modified NNCF network.
     """
-    model = replace_quantizer_to_torch_native_module(model)
-    model = remove_disabled_quantizers(model)
+    model_layout = model.nncf.transformation_layout()
+    transformations = model_layout.transformations
+    if any(type(q.fn) in [AsymmetricLoraQuantizer, SymmetricLoraQuantizer] for q in transformations):
+        model = replace_with_decompressors(model, transformations)
+    else:
+        model = replace_quantizer_to_torch_native_module(model)
+        model = remove_disabled_quantizers(model)
     return model
+
+
+def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork:
+    """
+    Performs the transformation from the fake-quantize (FQ) format to the dequantization (DQ) one.
+    The former takes a floating-point input, quantizes and dequantizes it, and returns a floating-point value,
+    while the latter takes a quantized integer representation, dequantizes it, and outputs a floating-point result.
+
+    Mathematically, both methods lead to the same outcome, but due to differences in the order of operations and
+    rounding errors the actual results may differ. In particular, this error can occur for values that lie exactly
+    at the midpoint between two quantized values ("quants").
+
+    The FQ format may round such a value to one quant, while the DQ format rounds it to another.
+    To avoid this, the compressed representation is obtained not by quantizing the original input directly,
+    but by quantizing the pre-processed, fake-quantized, floating-point representation.
+
+    :param model: Compressed model.
+    :param transformations: Commands holding the quantizers to replace with decompressors.
+    :return: The modified NNCF network.
+    """
+    transformation_layout = TransformationLayout()
+    model = model.nncf.get_clean_shallow_copy()
+    graph = model.nncf.get_graph()
+
+    for command in transformations:
+        quantizer = command.fn
+
+        if len(command.target_points) > 1:
+            msg = "Command contains more than one target point!"
+            raise nncf.ValidationError(msg)
+
+        tp = command.target_points[0]
+        node_with_weight = graph.get_node_by_name(tp.target_node_name)
+        weight_node = get_const_node(node_with_weight, tp.input_port_id, graph)
+
+        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
+        module = get_module_by_name(module_name, model)
+        original_weight = getattr(module, weight_attr_name)
+
+        original_dtype = original_weight.dtype
+        original_shape = original_weight.shape
+        original_eps = torch.finfo(original_dtype).eps
+
+        qdq_weight = quantizer.quantize(original_weight)
+        if hasattr(quantizer, "_lspec"):
+            # Special reshape for the LoRA-grouped output
+            qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape)
+        qdq_weight = qdq_weight.to(original_dtype)
+
+        if isinstance(quantizer, AsymmetricQuantizer):
+            input_range_safe = abs(quantizer.input_range) + quantizer.eps
+            input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels)
+
+            integer_dtype = torch.uint8
+
+            input_low = input_low.to(original_dtype)
+            input_range = input_range.to(original_dtype)
+
+            scale = input_range / quantizer.level_high
+            scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
+            scale = scale.to(original_dtype)
+
+            zero_point = quantizer.level_low - torch.round(input_low / scale)
+            zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high)
+            zero_point = zero_point.to(integer_dtype)
+
+            q_weight = qdq_weight / scale
+            q_weight = q_weight + zero_point
+            q_weight = torch.round(q_weight)
+            q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
+            q_weight = q_weight.to(integer_dtype)
+
+            if quantizer.num_bits == 8:
+                decompressor = INT8AsymmetricWeightsDecompressor(
+                    scale=scale, zero_point=zero_point, result_dtype=original_dtype
+                )
+            else:
+                decompressor = INT4AsymmetricWeightsDecompressor(
+                    scale=scale,
+                    zero_point=zero_point,
+                    compressed_weight_shape=q_weight.shape,
+                    result_shape=original_shape,
+                    result_dtype=original_dtype,
+                )
+        elif isinstance(quantizer, SymmetricQuantizer):
+            integer_dtype = torch.int8
+
+            scale = quantizer.scale / abs(quantizer.level_low)
+            scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
+            scale = scale.to(original_dtype)
+
+            q_weight = qdq_weight / scale
+            q_weight = torch.round(q_weight)
+            q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
+            q_weight = q_weight.to(integer_dtype)
+
+            if quantizer.num_bits == 8:
+                decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=original_dtype)
+            else:
+                decompressor = INT4SymmetricWeightsDecompressor(
+                    scale=scale,
+                    compressed_weight_shape=q_weight.shape,
+                    result_shape=original_shape,
+                    result_dtype=original_dtype,
+                )
+
+        packed_tensor = decompressor.pack_weight(q_weight)
+
+        # Set the compressed tensor in place of the original weight
+        compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
+        setattr(module, weight_attr_name, compressed_parameter)
+
+        # Rebind any other modules that share the same weight parameter
+        consumer_nodes = graph.get_next_nodes(weight_node)
+        if len(consumer_nodes) > 1:
+            for consumer_node in consumer_nodes:
+                consumer_module = model.nncf.get_module_by_scope(Scope.from_str(consumer_node.layer_name))
+                for name, param in consumer_module.named_parameters(recurse=False, remove_duplicate=False):
+                    if id(param) == id(original_weight):
+                        setattr(consumer_module, name, compressed_parameter)
+
+        if is_experimental_torch_tracing_enabled():
+            transformation_layout.register(
+                PT2InsertionCommand(
+                    [
+                        PTTargetPoint(
+                            TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
+                        )
+                    ],
+                    decompressor,
+                )
+            )
+        else:
+            decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
+            transformation_layout.register(
+                PTSharedFnInsertionCommand(
+                    [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
+                    decompressor,
+                    decompressor_name,
+                )
+            )
+
+    return PTModelTransformer(model).transform(transformation_layout)
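Why re-quantizing the fake-quantized weight makes the two formats agree, as the docstring above argues: once a tensor has passed through FQ it already sits on the integer grid, so re-quantizing it cannot land on a rounding midpoint, and the DQ path reproduces the FQ output exactly. A minimal sketch of the symmetric 8-bit case (illustrative code, not part of the commit; names are ours):

import torch

def fake_quantize_sym(w: torch.Tensor, scale: torch.Tensor, level_low: int = -128, level_high: int = 127) -> torch.Tensor:
    # FQ: float in -> snap to the integer grid -> float out
    q = torch.clamp(torch.round(w / scale), level_low, level_high)
    return q * scale

w = torch.randn(4, 8)
scale = torch.tensor(0.05)
fq_out = fake_quantize_sym(w, scale)

# DQ path, derived from the fake-quantized weight the way replace_with_decompressors does:
q_weight = torch.clamp(torch.round(fq_out / scale), -128, 127).to(torch.int8)
dq_out = q_weight.to(torch.float32) * scale

# fq_out is already on the grid, so the round trip is lossless
assert torch.equal(fq_out, dq_out)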

tests/torch/helpers.py (+11)

@@ -773,3 +773,14 @@ def _check_pre_post_hooks(
     assert len(actual_hooks) == len(ref_hooks)
     for actual_hook, ref_hook in zip(actual_hooks, ref_hooks):
         assert actual_hook is ref_hook
+
+
+class LinearModel(nn.Module):
+    def __init__(self, input_shape: List[int]):
+        super().__init__()
+        with set_torch_seed():
+            self.linear = nn.Linear(input_shape[1], input_shape[0], bias=False)
+            self.linear.weight.data = torch.randn(input_shape) - 0.5
+
+    def forward(self, x):
+        return self.linear(x)
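For context, the new helper builds a seeded, bias-free nn.Linear whose weight matrix has shape input_shape, so a hypothetical usage looks like:

import torch
from tests.torch.helpers import LinearModel

model = LinearModel(input_shape=[1, 16])  # nn.Linear(in_features=16, out_features=1), weight shape [1, 16]
out = model(torch.ones([1, 16]))          # output shape [1, 1]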

tests/torch/ptq/test_fq_lora.py (+61 -1)

@@ -11,6 +11,10 @@

 import pytest
 import torch
+from optimum.exporters.openvino.convert import export_from_model
+from optimum.intel.openvino import OVModelForCausalLM
+from sentence_transformers import SentenceTransformer
+from sentence_transformers import util
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer

@@ -20,6 +24,44 @@
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ


+class ValidationMock:
+    def __init__(self) -> None:
+        model_id = "sentence-transformers/all-mpnet-base-v2"
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+        self.model = SentenceTransformer(
+            model_id, tokenizer_kwargs={"pad_token": self.tokenizer.pad_token}, trust_remote_code=True
+        )
+
+    def calculate_similarity(self, gold: str, prediction: str) -> torch.Tensor:
+        embeddings = self.model.encode([gold, prediction])
+        cos_sim = util.cos_sim(embeddings, embeddings)
+        return torch.mean(cos_sim)
+
+    @property
+    def validation_ref(self) -> torch.Tensor:
+        return torch.tensor(1.0)
+
+
+def generate_control_output(model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> str:
+    control_input = tokenizer("What is Pytorch?", return_tensors="pt")
+    control_input = control_input.to(model.device)
+    control_output = model.generate(**control_input, do_sample=False)
+    return tokenizer.batch_decode(control_output, skip_special_tokens=True)[0]
+
+
+def get_ov_model(model: AutoModelForCausalLM, tmp_path: str) -> OVModelForCausalLM:
+    model = model.cpu()
+    export_from_model(model, tmp_path)
+    return OVModelForCausalLM.from_pretrained(
+        model_id=tmp_path,
+        trust_remote_code=True,
+        load_in_8bit=False,
+        compile=True,
+        ov_config={"KV_CACHE_PRECISION": "f16", "DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"},
+    )
+
+
 @pytest.mark.parametrize(
     "compression_kwargs",
     (dict(scale_estimation=True, awq=True), dict(scale_estimation=False, awq=False)),
@@ -33,7 +75,7 @@
     ),
     ids=["asym", "sym"],
 )
-def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable, _seed):
+def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num_trainable, _seed):
     model_id = "facebook/opt-125m"
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device)
@@ -80,3 +122,21 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable

     assert first_loss > 8
     assert float(loss) < 1
+
+    tuned_output = generate_control_output(model, tokenizer)
+
+    # Workaround until export from optimum is fixed (CVS-164159)
+    model = model.to(torch.float32)
+
+    model = nncf.strip(model)
+    stripped_output = generate_control_output(model, tokenizer)
+
+    model = get_ov_model(model, tmp_path)
+    stripped_ov_output = generate_control_output(model, tokenizer)
+
+    vm = ValidationMock()
+    tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output)
+    tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output)
+
+    assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=0.01)
+    assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=0.01)
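A note on why comparing the mean against validation_ref = 1.0 works: for two texts, util.cos_sim returns a 2x2 matrix with ones on the diagonal, so its mean is (1 + s) / 2, where s is the gold-vs-prediction similarity. The mean reaches 1.0 exactly when the stripped output embeds identically to the tuned one. A small self-contained check (our illustration, not part of the commit):

import torch
from sentence_transformers import util

emb = torch.nn.functional.normalize(torch.randn(2, 8), dim=1)
sim = util.cos_sim(emb, emb)                      # [[1, s], [s, 1]]
assert torch.allclose(sim.mean(), (1 + sim[0, 1]) / 2)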

tests/torch/quantization/test_strip.py (+39)

@@ -34,6 +34,7 @@
 from tests.common.quantization.data_generators import generate_sweep_data
 from tests.common.quantization.data_generators import get_quant_len_by_range
 from tests.torch.helpers import BasicConvTestModel
+from tests.torch.helpers import LinearModel
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 from tests.torch.helpers import register_bn_adaptation_init_args
 from tests.torch.quantization.test_functions import get_test_data

@@ -325,3 +326,41 @@ def test_nncf_strip_api(strip_type, do_copy):

     assert isinstance(strip_model.conv.get_pre_op("0").op, FakeQuantize)
     assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)
+
+
+@pytest.mark.parametrize(
+    ("mode", "torch_dtype", "atol"),
+    (
+        (nncf.CompressWeightsMode.INT4_ASYM, torch.float32, 0.0005),
+        (nncf.CompressWeightsMode.INT4_ASYM, torch.float16, 0.0005),
+        (nncf.CompressWeightsMode.INT4_ASYM, torch.bfloat16, 0.01),
+        (nncf.CompressWeightsMode.INT4_SYM, torch.float32, 0.0005),
+        (nncf.CompressWeightsMode.INT4_SYM, torch.float16, 0.0005),
+        (nncf.CompressWeightsMode.INT4_SYM, torch.bfloat16, 0.01),
+    ),
+)
+def test_nncf_strip_lora_model(mode, torch_dtype, atol):
+    input_shape = [1, 16]
+    model = LinearModel(input_shape=input_shape)
+    model = model.to(torch_dtype)
+    with torch.no_grad():
+        example = torch.ones(input_shape).to(torch_dtype)
+        dataset = [example]
+
+        compressed_model = nncf.compress_weights(
+            model,
+            ratio=1,
+            group_size=4,
+            mode=mode,
+            backup_mode=None,
+            dataset=nncf.Dataset(dataset),
+            all_layers=True,
+            compression_format=nncf.CompressionFormat.FQ_LORA,
+        )
+
+        compressed_output = compressed_model(example)
+
+        strip_compressed_model = nncf.strip(compressed_model, do_copy=True)
+        stripped_output = strip_compressed_model(example)
+
+        assert torch.allclose(compressed_output, stripped_output, atol=atol)
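A side note on the tolerances in the parametrization (our reading, not stated in the commit): bfloat16 keeps only 7 explicit mantissa bits, so a single round trip through quantize/dequantize loses far more precision than under float16 or float32, which plausibly motivates the looser 0.01 tolerance:

import torch

# One ulp near 1.0 for each dtype:
print(torch.finfo(torch.float16).eps)   # ~0.00098
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125, eight times coarser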

tests/torch/requirements.txt (+4)

@@ -24,3 +24,7 @@ timm==0.9.2
 # Required for torch/fx tests
 torchvision
 fastdownload==0.0.7
+
+sentence-transformers>=2.2.2
+optimum-intel==1.22.0
+optimum==1.24.0
