NF4 per-channel support for AWQ and Scale Estimation #2898

Merged: 6 commits, merged on Sep 23, 2024
12 changes: 2 additions & 10 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -365,11 +365,7 @@ def apply(
self._set_weight_compression_config(ratio_defining_params, model, graph, activations)
nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params))

if (
self._awq
and activations is not None
and self._mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]
):
if self._awq and activations is not None and self._mode != CompressWeightsMode.E2M1:
awq_params = self._advanced_parameters.awq_params
awq_algo = AWQ(
model,
@@ -399,11 +395,7 @@ def apply(
backend_entity=self._backend_entity,
)
else:
if (
self._scale_estimation
and activations is not None
and self._mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]
):
if self._scale_estimation and activations is not None and self._mode != CompressWeightsMode.E2M1:
scale_estimation_params = self._advanced_parameters.scale_estimation_params
scale_algo = ScaleEstimation(
model,
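With the two gating changes above, AWQ and Scale Estimation are no longer skipped for NF4; only E2M1 remains excluded. A minimal usage sketch of the newly allowed combination follows. The model and calibration data are hypothetical placeholders, and group_size=-1 for per-channel quantization is an assumption based on the per-channel test configs referenced later in this PR.

import numpy as np

import nncf

# Hypothetical inputs: an openvino.Model with MatMul weights and a few calibration samples.
ov_model = ...  # placeholder, not a real model
calibration_dataset = nncf.Dataset([np.ones([8, 8], dtype=np.float32)])

# NF4 can now be combined with AWQ and Scale Estimation when a dataset is provided
# (E2M1 is still rejected for these algorithms).
compressed_model = nncf.compress_weights(
    ov_model,
    mode=nncf.CompressWeightsMode.NF4,
    ratio=1.0,
    group_size=-1,  # assumption: -1 selects per-channel quantization
    dataset=calibration_dataset,
    awq=True,
    scale_estimation=True,
)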
19 changes: 14 additions & 5 deletions nncf/quantization/algorithms/weight_compression/awq.py
@@ -24,11 +24,15 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
from nncf.quantization.passes import transform_to_inference_graph
from nncf.tensor import functions as fns

@@ -244,11 +248,16 @@ def apply(
alpha = self._alpha_min
for _ in range(self._steps):
cur_scale = gscale**alpha

g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
gweight * cur_scale, reduction_axis, awq_config
)
g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
weights_to_fake_quantize = gweight * cur_scale
if config.mode == CompressWeightsMode.NF4:
g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
else:
g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
weights_to_fake_quantize, reduction_axis, awq_config
)
g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
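The hunk above makes AWQ's alpha grid search fake-quantize the scaled weights with an NF4 round trip (scale, quantize to the look-up table, dequantize) when config.mode is NF4, instead of the int4 path. The numpy sketch below illustrates one search step; the 16-entry grid, the shapes, and the statistics are simplified stand-ins, not the real NF4_QUANTILES table or NNCF tensors.

import numpy as np

# Illustrative 16-entry grid standing in for the NF4 look-up table.
GRID = np.linspace(-1.0, 1.0, 16)

def fake_quantize_lut(weight):
    # Per-output-channel abs-max scale, snap to the nearest grid value, dequantize.
    scale = np.max(np.abs(weight), axis=-1, keepdims=True)
    scale[scale == 0] = 1.0
    normed = weight / scale
    idx = np.abs(normed[..., None] - GRID).argmin(axis=-1)
    return GRID[idx] * scale

rng = np.random.default_rng(0)
gweight = rng.standard_normal((4, 8))  # [out_channels, in_channels]
gacts = rng.standard_normal((8, 16))   # [in_channels, samples]
gscale = np.abs(gweight).max(axis=0)   # toy per-input-channel statistic

fp_out = gweight @ gacts
for alpha in (0.0, 0.25, 0.5):
    cur_scale = gscale**alpha
    g_decompressed = fake_quantize_lut(gweight * cur_scale)  # NF4-style round trip
    sacts = gacts / cur_scale[:, None]                       # compensate on activations
    err = np.mean((fp_out - g_decompressed @ sacts) ** 2)
    print(f"alpha={alpha:.2f}  mse={err:.6f}")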
nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -337,7 +337,10 @@ def get_compress_decompress_pipeline(config: WeightCompressionConfig, w_shape, s
@staticmethod
def get_compress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None, return_nodes=False):
mode = config.mode
assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]
assert mode in [
CompressWeightsMode.INT4_SYM,
CompressWeightsMode.INT4_ASYM,
], f"Only int4 supported, but given={mode}"
num_bits = config.num_bits

asym_quant = mode in [CompressWeightsMode.INT4_ASYM]
nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -19,12 +19,16 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_normalized_weight_and_fp4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
@@ -206,11 +210,18 @@ def calculate_quantization_params(
cur_config.group_size = group_size

original_weight = fns.zeros_like(weight) + weight

compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config)
if zp is not None:
zp = zp.astype(scale.dtype)
q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis)
if config.mode == CompressWeightsMode.NF4:
norm_weight, scale = calculate_normalized_weight_and_fp4_scale(
original_weight, reduction_axis, cur_config.group_size
)
compressed_weights = do_nf4_quantization(norm_weight, scale, is_normalized_weight=True)
q_weights = do_nf4_dequantization(compressed_weights, scale, reduction_axis)
zp = None
else:
compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config)
if zp is not None:
zp = zp.astype(scale.dtype)
q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis)

s = fns.unsqueeze(s, 0)
s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size)
Expand Down Expand Up @@ -246,18 +257,19 @@ def calculate_quantization_params(
key = (config.mode, config.num_bits) + q_weights.shape + scale.shape
if zp is not None:
key += zp_shape
if key in ScaleEstimation.compress_decompress_cache:
compress_decompress_model = ScaleEstimation.compress_decompress_cache[key]["compress_decompress_model"]
compress_model = ScaleEstimation.compress_decompress_cache[key]["compress_model"]
else:
compress_decompress_model = backend_entity.get_compress_decompress_pipeline(
config, q_weights.shape, scale.shape, zp_shape
)
compress_model = backend_entity.get_compress_pipeline(config, q_weights.shape, scale.shape, zp_shape)
ScaleEstimation.compress_decompress_cache[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}
if config.mode != CompressWeightsMode.NF4:
if key in ScaleEstimation.compress_decompress_cache:
compress_decompress_model = ScaleEstimation.compress_decompress_cache[key]["compress_decompress_model"]
compress_model = ScaleEstimation.compress_decompress_cache[key]["compress_model"]
else:
compress_decompress_model = backend_entity.get_compress_decompress_pipeline(
config, q_weights.shape, scale.shape, zp_shape
)
compress_model = backend_entity.get_compress_pipeline(config, q_weights.shape, scale.shape, zp_shape)
ScaleEstimation.compress_decompress_cache[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}
scale_sign = scale / fns.abs(scale)
zero_scale = 0.001
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
@@ -271,7 +283,11 @@ def calculate_quantization_params(
near_to_ideal_scale = near_to_ideal_scale * scale_sign
input_tensors[1] = near_to_ideal_scale.data

out = compress_decompress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale)
out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale)
else:
out = compress_decompress_model(input_tensors)
q_weights_ = fns.zeros_like(original_weight) + out
q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)

@@ -297,7 +313,10 @@ def calculate_quantization_params(
input_tensors[1] = near_to_ideal_scale.data

if i < initial_steps - 1:
out = compress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
out = do_nf4_quantization(original_weight, near_to_ideal_scale)
else:
out = compress_model(input_tensors)
compressed_weights = fns.zeros_like(original_weight) + out
target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
@@ -308,7 +327,10 @@ def calculate_quantization_params(
scaled_scale = factor * scale

input_tensors[1] = scaled_scale.data
out = compress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
out = do_nf4_quantization(original_weight, scaled_scale)
else:
out = compress_model(input_tensors)
compressed_weights = fns.zeros_like(original_weight) + out

target, zero_mask = get_target_zero_mask(compressed_weights, zp)
@@ -317,7 +339,11 @@ def calculate_quantization_params(
near_to_ideal_scale = near_to_ideal_scale * scale_sign

input_tensors[1] = near_to_ideal_scale.data
out = compress_decompress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale)
out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale)
else:
out = compress_decompress_model(input_tensors)
q_weights_ = fns.zeros_like(original_weight) + out

q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
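For NF4, the changes above bypass the cached OpenVINO compress/decompress models and evaluate candidate scales with direct do_nf4_quantization / do_nf4_dequantization calls. The sketch below captures the general idea of that inner loop: try rescaled per-channel scales, round-trip the weights through the quantizer, and keep the scale that minimizes the output error on activations. The grid, the candidate-scale rule, and the shapes are simplified stand-ins for the actual NNCF logic.

import numpy as np

GRID = np.linspace(-1.0, 1.0, 16)  # stand-in for the NF4 look-up table

def nf4_roundtrip(weight, scale):
    # Normalize by the scale, snap to the nearest grid value, then dequantize.
    normed = np.clip(weight / scale, -1.0, 1.0)
    idx = np.abs(normed[..., None] - GRID).argmin(axis=-1)
    return GRID[idx] * scale

rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 8))                # [out_channels, in_channels]
X = rng.standard_normal((8, 16))                    # activations
scale = np.abs(weight).max(axis=-1, keepdims=True)  # initial per-channel scale
fp_out = weight @ X

best_err, best_scale = np.inf, scale
for factor in np.linspace(0.7, 1.0, 7):             # candidate rescaling factors
    candidate = scale * factor
    err = np.mean((fp_out - nf4_roundtrip(weight, candidate) @ X) ** 2)
    if err < best_err:
        best_err, best_scale = err, candidate
print("best error:", best_err)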
nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -178,39 +178,37 @@ def calculate_normalized_weight(weight: Tensor, scale: Tensor) -> Tensor:

def do_nf4_quantization(weight: Tensor, scale: Tensor, is_normalized_weight: bool = False) -> Tensor:
"""
Performs NF4 quantization - the floating point value is represented by floating point scale, look-up table of
16 NF4 values Quantizes the weight tensor to NF4 format.
Performs NF4 quantization. The floating-point values are represented by a floating-point scale and a look-up table
of 16 floating-point values on [-1, 1]. The scale normalizes the original values to the [-1, 1] interval, and the
look-up table "rounds" (quantizes) each value to the closest quant.

:param weight: Weight tensor to quantize.
:param scale: Scale tensor used for normalization.
:param is_normalized_weight: Whether weight was scaled to [-1, 1] interval. Defaults to False.
:return: Tensor of indexes from 0 to 15 that represents the position in look-up table with the corresponding
NF4 values from -1 to 1.
:return: Tensor of floating-point values, each corresponding to one of the 16 NF4 quants on [-1, 1].
"""
norm_weight = weight if is_normalized_weight else calculate_normalized_weight(weight, scale)
center_nf4_quantiles = fns.from_numpy(CENTER_OF_NF4_QUANTILES, backend=norm_weight.backend)
indexes = fns.searchsorted(center_nf4_quantiles, norm_weight)
return indexes
nf4_quantiles = fns.from_numpy(NF4_QUANTILES, backend=indexes.backend)
nf4_weight = nf4_quantiles[indexes]
return nf4_weight


def do_nf4_dequantization(indexes: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor:
def do_nf4_dequantization(nf4_weight: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor:
"""
Decompresses the NF4 quantized weight tensor.

:param indexes: Tensor of indexes from 0 to 15 that represents the position in look-up table with the corresponding
NF4 values from -1 to 1.
:param nf4_weight: Tensor of floating-point values, each corresponding to one of the 16 NF4 quants on [-1, 1].
:param scale: Scale tensor used for decompression.
:param reduction_axis: Axis along which the weights were reshaped for group quantization and will be reshaped back
to their original shape. If equal to -1, the weights are not reshaped (no group quantization is assumed). Defaults to -1.
:return: Decompressed weight tensor.
"""
nf4_quantiles = fns.from_numpy(NF4_QUANTILES, backend=indexes.backend)
nf4_weight = nf4_quantiles[indexes]

decompressed_weight = nf4_weight * scale
if reduction_axis != -1:
decompressed_weight = ungroup_weights(decompressed_weight, reduction_axis)

return decompressed_weight


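With this change do_nf4_quantization returns the quantized floating-point values themselves (already mapped through the look-up table) and do_nf4_dequantization simply multiplies them back by the scale. A small numpy sketch of that mechanism follows; the symmetric 16-value grid is an illustrative placeholder for the real NF4_QUANTILES / CENTER_OF_NF4_QUANTILES constants.

import numpy as np

QUANTILES = np.linspace(-1.0, 1.0, 16)            # placeholder quantile grid
CENTERS = (QUANTILES[1:] + QUANTILES[:-1]) / 2.0  # boundaries used for rounding

def quantize(weight, scale):
    # Normalize to [-1, 1], then snap each value to the closest grid entry.
    normed = np.clip(weight / scale, -1.0, 1.0)
    indexes = np.searchsorted(CENTERS, normed)    # index of the nearest quant
    return QUANTILES[indexes]                     # quantized values, not indexes

def dequantize(q_weight, scale):
    # Rescale the quantized values back to the original range.
    return q_weight * scale

w = np.array([[0.5, -1.2, 0.03], [2.0, -0.4, 0.9]])
s = np.abs(w).max(axis=-1, keepdims=True)
restored = dequantize(quantize(w, s), s)
print(np.abs(w - restored).max())                 # small round-trip error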
8 changes: 2 additions & 6 deletions nncf/quantization/quantize_model.py
@@ -493,12 +493,8 @@ def compress_weights(
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl

if any((awq, scale_estimation)) and (
dataset is None or mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]
):
raise AttributeError(
"Scale estimation or AWQ algorithm is defined, but dataset is None or mode is (NF4 or E2M1)."
)
if any((awq, scale_estimation)) and (dataset is None or mode == CompressWeightsMode.E2M1):
raise AttributeError("Scale estimation or AWQ algorithm is defined, but dataset is None or mode is E2M1.")
if any((gptq, lora_correction)) and (dataset is None or mode == CompressWeightsMode.E2M1):
raise AttributeError("GPTQ or Lora Correction algorithm is defined, but dataset is None or mode is E2M1.")

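After this validation change, AWQ and Scale Estimation are rejected only when the dataset is missing or the mode is E2M1; NF4 passes the check. A short sketch of the expected behavior, with a hypothetical model placeholder:

import numpy as np

import nncf

ov_model = ...  # hypothetical openvino.Model placeholder
data = nncf.Dataset([np.ones([8, 8], dtype=np.float32)])

# Still rejected: AWQ/Scale Estimation without a dataset (or with E2M1 mode).
try:
    nncf.compress_weights(ov_model, mode=nncf.CompressWeightsMode.NF4, awq=True)  # no dataset
except AttributeError as err:
    print("rejected:", err)

# Accepted after this PR: NF4 together with AWQ / Scale Estimation and a dataset.
nncf.compress_weights(
    ov_model, mode=nncf.CompressWeightsMode.NF4, dataset=data, awq=True, scale_estimation=True
)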
32 changes: 24 additions & 8 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -40,6 +40,7 @@
from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
from nncf.scopes import IgnoredScope
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from tests.cross_fw.shared.helpers import compare_stats
from tests.cross_fw.shared.helpers import dump_to_json
from tests.cross_fw.shared.helpers import load_json
@@ -710,7 +711,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
compress_weights(ov.Model([], []), mode=mode, **params)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
@pytest.mark.parametrize(
"params",
({"dataset": "anything", "lora_correction": True, "gptq": True},),
@@ -748,7 +749,7 @@ def test_call_max_var_criterion_with_dataset_by_default_awq(mode):
compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
@pytest.mark.parametrize("with_multiply", (True, False))
def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(mode, with_multiply):
n_layers = 8
@@ -765,15 +766,15 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(mode, wit
assert awq_num == n_awq_target


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_awq_for_compressed_model(mode):
model = AWQMatmulModel(is_int8=True).ov_model
dataset = Dataset([np.ones([8, 8])])

compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_awq_neg_group_size(mode):
model = AWQMatmulModel().ov_model
dataset = Dataset([np.ones([8, 8])])
@@ -875,23 +876,38 @@ def test_duplicate_names_generation():
op_names.add(name)


@pytest.mark.parametrize("mode", INT4_MODES)
def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode):
@pytest.mark.parametrize(
("mode", "compressed_weight_dtype"),
(
(CompressWeightsMode.INT4_SYM, TensorDataType.int8),
(CompressWeightsMode.INT4_ASYM, TensorDataType.uint8),
(CompressWeightsMode.NF4, TensorDataType.float32),
),
)
def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode, compressed_weight_dtype, mocker):
model = AWQMatmulModel().ov_model
dataset = Dataset([np.ones([8, 8])])
from nncf.quantization.algorithms.weight_compression import scale_estimation
from nncf.quantization.algorithms.weight_compression.algorithm import ScaleEstimation

se_spy = mocker.spy(ScaleEstimation, "apply")
tzm_spy = mocker.spy(scale_estimation, "get_target_zero_mask")

compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True)

assert se_spy.call_count == 1
assert tzm_spy.call_args_list[0][0][0].dtype == compressed_weight_dtype

@pytest.mark.parametrize("mode", INT4_MODES)

@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_scale_estimation_for_compressed_model(mode):
model = AWQMatmulModel(is_int8=True).ov_model
dataset = Dataset([np.ones([8, 8])])

compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_scale_estimation_neg_group_size(mode):
model = AWQMatmulModel().ov_model
dataset = Dataset([np.ones([8, 8])])
6 changes: 5 additions & 1 deletion tests/post_training/data/wc_reference_data.yaml
@@ -35,4 +35,8 @@ tinyllama_scale_estimation_per_channel_backend_OV:
tinyllama_data_aware_lora_stateful_backend_OV:
metric_value: 0.83446
num_int4: 94
num_int8: 500
num_int8: 500
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.88663
num_int4: 11
num_int8: 290
5 changes: 5 additions & 0 deletions tests/post_training/data/wc_reference_data_2024.4.yaml
@@ -19,4 +19,9 @@ tinyllama_scale_estimation_per_channel_backend_OV:
metric_value: 0.80853
num_int4: 188
num_int8: 124
metrics_xfail_reason: "Issue-148819"
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.87942
num_int4: 11
num_int8: 290
metrics_xfail_reason: "Issue-148819"
6 changes: 5 additions & 1 deletion tests/post_training/data/wc_reference_data_2024.5.yaml
@@ -7,4 +7,8 @@ tinyllama_scale_estimation_per_channel_backend_OV:
metric_value: 0.80798
num_int4: 188
num_int8: 124
metrics_xfail_reason: "Issue-148819"
metrics_xfail_reason: "Issue-148819"
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.87132
num_int4: 11
num_int8: 290