From fb70c82726d0abd87942f4931f0fcc4f1110e92e Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Thu, 20 Mar 2025 14:39:34 +0100 Subject: [PATCH 01/11] Minor refactoring --- .../algorithms/weight_compression/gptq.py | 13 ++-- .../weight_compression/mixed_precision.py | 6 +- .../weight_compression/weight_lowering.py | 72 +++++++++---------- 3 files changed, 46 insertions(+), 45 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index ddb5b83b1ae..f6d7bccf5b1 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -28,10 +28,9 @@ from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_weight -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight from nncf.tensor import Tensor from nncf.tensor import functions as fns from nncf.tensor.definitions import TensorDataType @@ -290,10 +289,14 @@ def _quantize_weights( ) quantized_col = do_nf4_dequantization(compressed_weights, scales[-1], reduction_axis=-1) else: - compressed_weights = calculate_quantized_weight( - fns.unsqueeze(weight_col, 1), block_compression_config, scales[-1], zero_points[-1] + quantized_col, compressed_weights, _, _ = quantize_dequantize_weight( + fns.unsqueeze(weight_col, 1), + block_compression_config, + reduction_axes=None, + precomputed_scale=scales[-1], + precomputed_zero_point=zero_points[-1], + return_compressed_weight=True, ) - quantized_col = do_int_dequantization(compressed_weights, scales[-1], zero_points[-1]) quantized_col = fns.flatten(quantized_col) quantized_block[:, i] = quantized_col loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2 diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index c8f5f175d6f..1d9c450d7f0 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -28,9 +28,8 @@ from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error +from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight from nncf.tensor import Tensor from nncf.tensor import functions as fns from nncf.tensor.definitions import TensorDataType @@ 
-354,8 +353,7 @@ def _calc_weight_sensitivity( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - compressed_weights, scale, zero_point = do_int_quantization(weight, backup_config, reduction_axes) - decompressed_weight = do_int_dequantization(compressed_weights, scale, zero_point) + decompressed_weight = quantize_dequantize_weight(weight, backup_config, reduction_axes) decompressed_weight = decompressed_weight.reshape(orig_shape) return fns.linalg.norm(decompressed_weight - weight, ord="fro").item() diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index cf4f110c745..b775b6081f2 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -286,41 +286,6 @@ def calculate_integer_quantization_params( return scale, None -def calculate_quantized_weight( - weight: Tensor, - config: WeightCompressionConfig, - scale: Tensor, - zero_point: Optional[Tensor] = None, -) -> Tensor: - """ - Quantizes the weight tensor using the provided scale and zero point. - - :param weight: Weight tensor to quantize. - :param config: Weight compression configuration. - :param scale: Scale tensor used for quantization. - :param zero_point: Zero point tensor used for quantization. - :return: Quantized weight tensor of uint8 or int8 type. - """ - if weight.dtype != TensorDataType.float32: - weight = weight.astype(TensorDataType.float32) - if scale.dtype != TensorDataType.float32: - scale = scale.astype(TensorDataType.float32) - - num_bits = config.num_bits - asym_quant = config.is_asym_mode - dtype = TensorDataType.uint8 if asym_quant else TensorDataType.int8 - level_low = 0 if asym_quant else -(2 ** (num_bits - 1)) - level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1 - - compressed_weights = weight / scale - if zero_point is not None: - compressed_weights += zero_point.astype(weight.dtype) - compressed_weights = fns.round(compressed_weights) - compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype) - - return compressed_weights - - def get_integer_quantization_error( weight: Tensor, reduction_axes: ReductionAxes, @@ -493,7 +458,7 @@ def do_int_quantization( if precomputed_zero_point is not None: zero_point = precomputed_zero_point - compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point) + compressed_weights = _calculate_quantized_weight(weight, config, scale, zero_point) return compressed_weights, scale, zero_point @@ -542,6 +507,41 @@ def quantize_dequantize_weight( return decompressed_weight +def _calculate_quantized_weight( + weight: Tensor, + config: WeightCompressionConfig, + scale: Tensor, + zero_point: Optional[Tensor] = None, +) -> Tensor: + """ + Quantizes the weight tensor using the provided scale and zero point. + + :param weight: Weight tensor to quantize. + :param config: Weight compression configuration. + :param scale: Scale tensor used for quantization. + :param zero_point: Zero point tensor used for quantization. + :return: Quantized weight tensor of uint8 or int8 type. 
+ """ + if weight.dtype != TensorDataType.float32: + weight = weight.astype(TensorDataType.float32) + if scale.dtype != TensorDataType.float32: + scale = scale.astype(TensorDataType.float32) + + num_bits = config.num_bits + asym_quant = config.is_asym_mode + dtype = TensorDataType.uint8 if asym_quant else TensorDataType.int8 + level_low = 0 if asym_quant else -(2 ** (num_bits - 1)) + level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1 + + compressed_weights = weight / scale + if zero_point is not None: + compressed_weights += zero_point.astype(weight.dtype) + compressed_weights = fns.round(compressed_weights) + compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype) + + return compressed_weights + + def _can_run_optimized(input_backend: TensorBackend) -> bool: if input_backend in [TensorBackend.ov, TensorBackend.numpy]: if is_openvino_available(): From dd94ac5ef9dd71c97a71f4fd09679697460861c4 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Fri, 21 Mar 2025 10:54:37 +0100 Subject: [PATCH 02/11] Function naming refactor --- nncf/openvino/optimized_functions/__init__.py | 6 ++- .../openvino/optimized_functions/functions.py | 16 ++++---- nncf/openvino/optimized_functions/models.py | 30 +++++++------- .../algorithms/weight_compression/awq.py | 12 +++--- .../algorithms/weight_compression/gptq.py | 12 +++--- .../weight_compression/lora_correction.py | 14 ++++--- .../weight_compression/mixed_precision.py | 4 +- .../weight_compression/scale_estimation.py | 40 +++++++++---------- .../weight_compression/weight_lowering.py | 38 +++++++++--------- nncf/version.py | 2 +- 10 files changed, 89 insertions(+), 85 deletions(-) diff --git a/nncf/openvino/optimized_functions/__init__.py b/nncf/openvino/optimized_functions/__init__.py index f737b984026..0f50f2a41a3 100644 --- a/nncf/openvino/optimized_functions/__init__.py +++ b/nncf/openvino/optimized_functions/__init__.py @@ -10,8 +10,10 @@ # limitations under the License. 
from nncf.openvino.optimized_functions.functions import astype as astype -from nncf.openvino.optimized_functions.functions import do_int_quantization as do_int_quantization +from nncf.openvino.optimized_functions.functions import do_integer_quantization as do_integer_quantization from nncf.openvino.optimized_functions.functions import get_integer_quantization_error as get_integer_quantization_error -from nncf.openvino.optimized_functions.functions import quantize_dequantize_weight as quantize_dequantize_weight +from nncf.openvino.optimized_functions.functions import ( + integer_quantize_dequantize_weight as integer_quantize_dequantize_weight, +) from nncf.openvino.optimized_functions.models import OVModelParameters as OVModelParameters from nncf.openvino.optimized_functions.models import clear_ov_model_cache as clear_ov_model_cache diff --git a/nncf/openvino/optimized_functions/functions.py b/nncf/openvino/optimized_functions/functions.py index 56851ef6e54..02b390896f6 100644 --- a/nncf/openvino/optimized_functions/functions.py +++ b/nncf/openvino/optimized_functions/functions.py @@ -15,9 +15,9 @@ from nncf.openvino.optimized_functions.models import OV_MODEL_CACHE from nncf.openvino.optimized_functions.models import OVModelParameters from nncf.openvino.optimized_functions.models import get_astype_model -from nncf.openvino.optimized_functions.models import get_compress_decompress_weight_model -from nncf.openvino.optimized_functions.models import get_compress_weight_model -from nncf.openvino.optimized_functions.models import get_quantization_error_model +from nncf.openvino.optimized_functions.models import get_integer_quantization_error_model +from nncf.openvino.optimized_functions.models import get_integer_quantization_model +from nncf.openvino.optimized_functions.models import get_integer_quantize_dequantize_weight_model from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.tensor import Tensor @@ -27,7 +27,7 @@ ReductionAxes = Union[int, Tuple[int, ...]] -def do_int_quantization( +def do_integer_quantization( weight: Tensor, config: WeightCompressionConfig, reduction_axes: Optional[ReductionAxes] = None, @@ -63,7 +63,7 @@ def do_int_quantization( {"compressed_weight": compressed_weight_dtype, "zero_point": compressed_weight_dtype} ) - model = get_compress_weight_model( + model = get_integer_quantization_model( ov_model_params, config, weight_shape, @@ -97,7 +97,7 @@ def do_int_quantization( return compressed_weight, scale, zero_point -def quantize_dequantize_weight( +def integer_quantize_dequantize_weight( weight: Tensor, config: WeightCompressionConfig, reduction_axes: Optional[ReductionAxes] = None, @@ -135,7 +135,7 @@ def quantize_dequantize_weight( if precomputed_zero_point is not None: ov_model_params.input_dtypes["zero_point"] = precomputed_zero_point.dtype - model = get_compress_decompress_weight_model( + model = get_integer_quantize_dequantize_weight_model( ov_model_params, config, weight_shape, scale_shape, zero_point_shape, reduction_axes, return_compressed_weight ) @@ -188,7 +188,7 @@ def get_integer_quantization_error( ov_model_params = OVModelParameters() ov_model_params.input_dtypes["weight"] = weight.dtype - model = get_quantization_error_model( + model = get_integer_quantization_error_model( ov_model_params, config, original_weight_shape, weight.shape, original_reduction_axes, reduction_axes ) diff --git 
a/nncf/openvino/optimized_functions/models.py b/nncf/openvino/optimized_functions/models.py index 6036dd8a438..66fb6034aea 100644 --- a/nncf/openvino/optimized_functions/models.py +++ b/nncf/openvino/optimized_functions/models.py @@ -168,7 +168,7 @@ def _infer_ov_model( return outputs -def _prepare_compression_model_inputs( +def _prepare_quantization_model_inputs( ov_model_params, weight_shape: Tuple, scale_shape: Optional[Tuple], @@ -196,7 +196,7 @@ def _prepare_compression_model_inputs( return weight_shape, scale_shape, zero_point_shape -def get_compress_weight_model( +def get_integer_quantization_model( ov_model_params: OVModelParameters, config: WeightCompressionConfig, weight_shape: Tuple, @@ -219,11 +219,11 @@ def get_compress_weight_model( :return: A model callable that compresses weights using the given configuration. Or a model as nodes, if `return_nodes` is True. """ - weight_shape, scale_shape, zero_point_shape = _prepare_compression_model_inputs( + weight_shape, scale_shape, zero_point_shape = _prepare_quantization_model_inputs( ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes ) - return _build_compress_model( + return _build_integer_quantization_model( config, ov_model_params, weight_shape, @@ -233,7 +233,7 @@ def get_compress_weight_model( ) -def get_compress_decompress_weight_model( +def get_integer_quantize_dequantize_weight_model( ov_model_params: OVModelParameters, config: WeightCompressionConfig, weight_shape: Tuple, @@ -259,11 +259,11 @@ def get_compress_decompress_weight_model( :return: A model callable that returns a decompressed weight, and optionally compressed weight, scale, (and zero point) if `return_compressed_weight` is True. """ - weight_shape, scale_shape, zero_point_shape = _prepare_compression_model_inputs( + weight_shape, scale_shape, zero_point_shape = _prepare_quantization_model_inputs( ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes ) - return _build_compress_decompress_model( + return _build_integer_quantize_dequantize_weight_model( config, ov_model_params, weight_shape, @@ -274,7 +274,7 @@ def get_compress_decompress_weight_model( ) -def get_quantization_error_model( +def get_integer_quantization_error_model( ov_model_params: OVModelParameters, config: WeightCompressionConfig, original_weight_shape: Tuple, @@ -296,15 +296,15 @@ def get_quantization_error_model( :param reduction_axes: Axes to reduce the weight tensor. :return: A model callable that returns the quantization error. 
""" - weight_shape, _, _ = _prepare_compression_model_inputs(ov_model_params, weight_shape, None, None, reduction_axes) + weight_shape, _, _ = _prepare_quantization_model_inputs(ov_model_params, weight_shape, None, None, reduction_axes) - return _build_quantization_error_model( + return _build_integer_quantization_error_model( config, ov_model_params, original_weight_shape, weight_shape, original_reduction_axes, reduction_axes ) @cache_results(OV_MODEL_CACHE) -def _build_compress_model( +def _build_integer_quantization_model( config: WeightCompressionConfig, ov_model_params: OVModelParameters, weight_shape: Tuple, @@ -454,7 +454,7 @@ def _build_compress_model( @cache_results(OV_MODEL_CACHE) -def _build_compress_decompress_model( +def _build_integer_quantize_dequantize_weight_model( config: WeightCompressionConfig, ov_model_params: OVModelParameters, weight_shape: Tuple, @@ -477,7 +477,7 @@ def _build_compress_decompress_model( raise ValueError(msg) # Get compression model as input/result nodes and potentially modified ov model parameters - ov_parameters, ov_results, ov_model_params = _build_compress_model( + ov_parameters, ov_results, ov_model_params = _build_integer_quantization_model( config, ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes, return_nodes=True ) @@ -514,7 +514,7 @@ def _build_compress_decompress_model( @cache_results(OV_MODEL_CACHE) -def _build_quantization_error_model( +def _build_integer_quantization_error_model( config: WeightCompressionConfig, ov_model_params: OVModelParameters, original_weight_shape: Tuple, @@ -522,7 +522,7 @@ def _build_quantization_error_model( original_reduction_axes: ReductionAxes, reduction_axes: ReductionAxes, ) -> ModelCallable: - ov_parameters, ov_results, ov_model_params = _build_compress_decompress_model( + ov_parameters, ov_results, ov_model_params = _build_integer_quantize_dequantize_weight_model( config, ov_model_params, weight_shape, diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 36052ddf99d..f3b40bc3f15 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -30,10 +30,10 @@ from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.passes import transform_to_inference_graph from nncf.tensor import TensorDataType from nncf.tensor import functions as fns @@ -256,10 +256,10 @@ def apply( weights_to_fake_quantize = gweight * cur_scale if config.mode == CompressWeightsMode.NF4: g_c_scale = 
calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) - g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) - g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) + g_compressed_weighs = calculate_nf4_quantized_weight(weights_to_fake_quantize, g_c_scale) + g_decompressed_weighs = do_float_dequantization(g_compressed_weighs, g_c_scale) else: - g_decompressed_weighs = quantize_dequantize_weight( + g_decompressed_weighs = integer_quantize_dequantize_weight( weights_to_fake_quantize, awq_config, reduction_axis ) sacts = gacts / fns.unsqueeze(cur_scale, 1) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index f6d7bccf5b1..e38bd868306 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -27,10 +27,10 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params +from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.tensor import Tensor from nncf.tensor import functions as fns from nncf.tensor.definitions import TensorDataType @@ -284,12 +284,12 @@ def _quantize_weights( zero_points.append(zero_point) if block_compression_config.mode == CompressWeightsMode.NF4: - compressed_weights = do_nf4_quantization( + compressed_weights = calculate_nf4_quantized_weight( fns.unsqueeze(weight_col, 1), scales[-1], is_normalized_weight=False ) - quantized_col = do_nf4_dequantization(compressed_weights, scales[-1], reduction_axis=-1) + quantized_col = do_float_dequantization(compressed_weights, scales[-1], reduction_axis=-1) else: - quantized_col, compressed_weights, _, _ = quantize_dequantize_weight( + quantized_col, compressed_weights, _, _ = integer_quantize_dequantize_weight( fns.unsqueeze(weight_col, 1), block_compression_config, reduction_axes=None, diff --git a/nncf/quantization/algorithms/weight_compression/lora_correction.py b/nncf/quantization/algorithms/weight_compression/lora_correction.py index e7b29981fa4..f456e5c6904 100644 --- a/nncf/quantization/algorithms/weight_compression/lora_correction.py +++ b/nncf/quantization/algorithms/weight_compression/lora_correction.py @@ -25,9 +25,9 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.weight_lowering import CompressedWeight -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization -from 
nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_dequantization from nncf.tensor import Tensor from nncf.tensor import functions as fns from nncf.tensor.definitions import TensorDataType @@ -170,15 +170,17 @@ def calculate_low_rank_matrices( assert len(reduction_axes) == 1, "Assumed a single reduction axis" reduction_axis = reduction_axes[0] if compression_config.group_size != -1 else -1 if mode in (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM): - fq_weights = do_int_dequantization( + fq_weights = do_integer_dequantization( compressed_weight.tensor, compressed_weight.scale, compressed_weight.zero_point, reduction_axis, ) elif mode == CompressWeightsMode.NF4: - indexes = do_nf4_quantization(compressed_weight.tensor, compressed_weight.scale, is_normalized_weight=True) - fq_weights = do_nf4_dequantization(indexes, compressed_weight.scale, reduction_axis) + indexes = calculate_nf4_quantized_weight( + compressed_weight.tensor, compressed_weight.scale, is_normalized_weight=True + ) + fq_weights = do_float_dequantization(indexes, compressed_weight.scale, reduction_axis) else: msg = ( f"{mode.value} mode is invalid for Lora Correction algorithm. Supported modes: INT4_SYM, INT4_ASYM, NF4" diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index 1d9c450d7f0..9323f8579b6 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -29,7 +29,7 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error -from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.tensor import Tensor from nncf.tensor import functions as fns from nncf.tensor.definitions import TensorDataType @@ -353,7 +353,7 @@ def _calc_weight_sensitivity( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - decompressed_weight = quantize_dequantize_weight(weight, backup_config, reduction_axes) + decompressed_weight = integer_quantize_dequantize_weight(weight, backup_config, reduction_axes) decompressed_weight = decompressed_weight.reshape(orig_shape) return fns.linalg.norm(decompressed_weight - weight, ord="fro").item() diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 9f3b4c098e8..0720bf6561d 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -23,11 +23,11 @@ from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from 
nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_normalized_weight_and_fp4_scale -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.tensor import Tensor from nncf.tensor import TensorDataType @@ -199,15 +199,13 @@ def calculate_quantization_params( original_weight = fns.zeros_like(weight) + weight if config.mode == CompressWeightsMode.NF4: - norm_weight, scale = calculate_normalized_weight_and_fp4_scale( - original_weight, reduction_axis, cur_config.group_size - ) - compressed_weights = do_nf4_quantization(norm_weight, scale, is_normalized_weight=True) - q_weights = do_nf4_dequantization(compressed_weights, scale, reduction_axis) + norm_weight, scale = do_float_quantization(original_weight, reduction_axis, cur_config.group_size) + compressed_weights = calculate_nf4_quantized_weight(norm_weight, scale, is_normalized_weight=True) + q_weights = do_float_dequantization(compressed_weights, scale, reduction_axis) q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) zp = None else: - q_weights, compressed_weights, scale, zp = quantize_dequantize_weight( + q_weights, compressed_weights, scale, zp = integer_quantize_dequantize_weight( original_weight, cur_config, reduction_axis, return_compressed_weight=True ) if zp is not None: @@ -251,10 +249,10 @@ def calculate_quantization_params( near_to_ideal_scale = near_to_ideal_scale * scale_sign if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale) - out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale) + g_compressed_weighs = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale) + out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale) else: - out = quantize_dequantize_weight( + out = integer_quantize_dequantize_weight( original_weight, config, precomputed_scale=near_to_ideal_scale, @@ -286,9 +284,9 @@ def calculate_quantization_params( if i < initial_steps - 1: if config.mode == CompressWeightsMode.NF4: - out = do_nf4_quantization(original_weight, near_to_ideal_scale) + out = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale) else: - out, _, _ = do_int_quantization( + out, _, _ = do_integer_quantization( original_weight, config, precomputed_scale=near_to_ideal_scale, @@ -304,9 +302,9 @@ def 
calculate_quantization_params( scaled_scale = factor * scale if config.mode == CompressWeightsMode.NF4: - out = do_nf4_quantization(original_weight, scaled_scale) + out = calculate_nf4_quantized_weight(original_weight, scaled_scale) else: - out, _, _ = do_int_quantization( + out, _, _ = do_integer_quantization( original_weight, config, precomputed_scale=scaled_scale, @@ -320,10 +318,10 @@ def calculate_quantization_params( near_to_ideal_scale = near_to_ideal_scale * scale_sign if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale) - out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale) + g_compressed_weighs = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale) + out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale) else: - out = quantize_dequantize_weight( + out = integer_quantize_dequantize_weight( original_weight, config, precomputed_scale=near_to_ideal_scale, diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index b775b6081f2..83599867167 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -179,7 +179,7 @@ def calculate_normalized_weight(weight: Tensor, scale: Tensor) -> Tensor: return weight / scale -def do_nf4_quantization(weight: Tensor, scale: Tensor, is_normalized_weight: bool = False) -> Tensor: +def calculate_nf4_quantized_weight(weight: Tensor, scale: Tensor, is_normalized_weight: bool = False) -> Tensor: """ Performs NF4 quantization. The floating point values are represented by floating point scale and look-up with 16 floating-point values on [-1, 1]. Scale normalizes original values to [-1, 1] interval and look-up table @@ -198,24 +198,24 @@ def do_nf4_quantization(weight: Tensor, scale: Tensor, is_normalized_weight: boo return nf4_weight -def do_nf4_dequantization(nf4_weight: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor: +def do_float_dequantization(compressed_weight: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor: """ - Decompresses the NF4 quantized weight tensor. + Decompresses the float-quantized weight tensor. - :param nf4_weight: Tensor with floating-point values, + :param compressed_weight: Tensor with floating-point values, where each of them corresponds to 1 out of 16 quants on [-1, 1]. :param scale: Scale tensor used for decompression. :param reduction_axis: axis along which weights were reshaped for group quantization and will be reshaped back to original shapes. If equals to -1, weights are not reshaped, assumed not a group quantization. Defaults to -1. :return: Decompressed weight tensor. 
""" - decompressed_weight = nf4_weight * scale + decompressed_weight = compressed_weight * scale if reduction_axis != -1: decompressed_weight = ungroup_weights(decompressed_weight, reduction_axis) return decompressed_weight -def calculate_normalized_weight_and_fp4_scale( +def do_float_quantization( weight: Tensor, reduction_axes: ReductionAxes, group_size: int = -1, @@ -318,7 +318,7 @@ def get_integer_quantization_error( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - decompressed_weight = quantize_dequantize_weight(weight, config, reduction_axes) + decompressed_weight = integer_quantize_dequantize_weight(weight, config, reduction_axes) decompressed_weight = decompressed_weight.reshape(orig_shape) diff = (decompressed_weight - weight) ** 2 @@ -348,11 +348,11 @@ def compress_weight( if weight.backend == TensorBackend.ov: weight = weight.as_numpy_tensor() - compressed_weight, scale = calculate_normalized_weight_and_fp4_scale( + compressed_weight, scale = do_float_quantization( weight, reduction_axes, config.group_size, precomputed_scale, config.mode ) return CompressedWeight(compressed_weight, scale) - compressed_weight, scale, zero_point = do_int_quantization( + compressed_weight, scale, zero_point = do_integer_quantization( weight, config, reduction_axes, precomputed_scale, precomputed_zero_point ) @@ -377,7 +377,7 @@ def ungroup_weights(weights: Tensor, reduction_axis: int) -> Tensor: return weights -def do_int_dequantization( +def do_integer_dequantization( compressed_weights: Tensor, scale: Tensor, zero_point: Optional[Tensor] = None, reduction_axis: int = -1 ) -> Tensor: """ @@ -402,7 +402,7 @@ def do_int_dequantization( return decompressed_weight -def do_int_quantization( +def do_integer_quantization( weight: Tensor, config: WeightCompressionConfig, reduction_axes: Optional[ReductionAxes] = None, @@ -437,7 +437,7 @@ def do_int_quantization( # Optimized implementation if _can_run_optimized(weight.backend): - from nncf.openvino.optimized_functions import do_int_quantization as do_int_quantization_ov + from nncf.openvino.optimized_functions import do_integer_quantization as do_int_quantization_ov return do_int_quantization_ov(weight, config, reduction_axes, precomputed_scale, precomputed_zero_point) @@ -458,11 +458,11 @@ def do_int_quantization( if precomputed_zero_point is not None: zero_point = precomputed_zero_point - compressed_weights = _calculate_quantized_weight(weight, config, scale, zero_point) + compressed_weights = _calculate_integer_quantized_weight(weight, config, scale, zero_point) return compressed_weights, scale, zero_point -def quantize_dequantize_weight( +def integer_quantize_dequantize_weight( weight: Tensor, config: WeightCompressionConfig, reduction_axes: Optional[ReductionAxes] = None, @@ -485,7 +485,9 @@ def quantize_dequantize_weight( """ # Optimized implementation if _can_run_optimized(weight.backend): - from nncf.openvino.optimized_functions import quantize_dequantize_weight as quantize_dequantize_weight_ov + from nncf.openvino.optimized_functions import ( + integer_quantize_dequantize_weight as quantize_dequantize_weight_ov, + ) return quantize_dequantize_weight_ov( weight, @@ -497,17 +499,17 @@ def quantize_dequantize_weight( ) # Reference implementation - compressed_weight, scale, zero_point = do_int_quantization( + compressed_weight, scale, zero_point = do_integer_quantization( weight, config, reduction_axes, precomputed_scale, precomputed_zero_point ) - decompressed_weight = 
do_int_dequantization(compressed_weight, scale, zero_point) + decompressed_weight = do_integer_dequantization(compressed_weight, scale, zero_point) if return_compressed_weight: return decompressed_weight, compressed_weight, scale, zero_point else: return decompressed_weight -def _calculate_quantized_weight( +def _calculate_integer_quantized_weight( weight: Tensor, config: WeightCompressionConfig, scale: Tensor, diff --git a/nncf/version.py b/nncf/version.py index f7b5b2206e3..95a9f571d48 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.16.0" +__version__ = "2.16.0.dev0+fb70c8272dirty" BKC_TORCH_SPEC = "==2.6.*" From 7dc55a247193b58c0344a58c7ebc1cee94b757f8 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Fri, 21 Mar 2025 18:39:11 +0100 Subject: [PATCH 03/11] Added optimized nf4 quantization --- nncf/openvino/optimized_functions/__init__.py | 1 + .../openvino/optimized_functions/functions.py | 44 ++++ nncf/openvino/optimized_functions/models.py | 112 ++++++++++ .../algorithms/weight_compression/awq.py | 8 +- .../algorithms/weight_compression/gptq.py | 12 +- .../weight_compression/lora_correction.py | 6 +- .../weight_compression/scale_estimation.py | 16 +- .../weight_compression/weight_lowering.py | 204 +++++++++--------- nncf/tensor/functions/__init__.py | 1 + nncf/tensor/functions/numeric.py | 10 + nncf/tensor/functions/numpy_numeric.py | 5 + nncf/tensor/functions/openvino_numeric.py | 2 +- nncf/tensor/functions/torch_numeric.py | 5 + .../template_test_weights_compression.py | 4 +- .../quantization/test_weights_compression.py | 10 +- .../test_compression_functions.py | 59 +++-- .../test_ov_model_parameters.py | 22 +- 17 files changed, 365 insertions(+), 156 deletions(-) diff --git a/nncf/openvino/optimized_functions/__init__.py b/nncf/openvino/optimized_functions/__init__.py index 0f50f2a41a3..82117571732 100644 --- a/nncf/openvino/optimized_functions/__init__.py +++ b/nncf/openvino/optimized_functions/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. 
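
[Editorial note] Patch 03 below moves NF4 quantization onto the optimized OpenVINO path. For orientation: NF4 quantization normalizes each group of weights into [-1, 1] by a max-abs scale and snaps every value to the nearest entry of a fixed 16-value look-up table, e.g. via searchsorted against the midpoints between adjacent quantiles. The sketch below is a hedged NumPy illustration; the quantile values shown are approximate (from the common NF4 definition) and the function names are hypothetical, not the NNCF API.

import numpy as np

# Approximate NF4 quantile grid on [-1, 1]; NNCF keeps the exact table in NF4_QUANTILES.
QUANTILES = np.array([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
                      0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0], dtype=np.float32)
CENTERS = (QUANTILES[1:] + QUANTILES[:-1]) / 2  # decision boundaries between adjacent quants

def nf4_quantize(weight: np.ndarray, axis: int = -1):
    scale = np.max(np.abs(weight), axis=axis, keepdims=True)
    scale = np.maximum(scale, np.finfo(np.float32).eps)
    normalized = weight / scale                      # now within [-1, 1]
    indexes = np.searchsorted(CENTERS, normalized)   # nearest-quantile index per element
    return QUANTILES[indexes], scale                 # "quantized" values are still float32

def nf4_dequantize(nf4_weight: np.ndarray, scale: np.ndarray):
    return nf4_weight * scale

w = np.random.randn(4, 32).astype(np.float32)
q, s = nf4_quantize(w)
print(np.abs(nf4_dequantize(q, s) - w).max())
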
from nncf.openvino.optimized_functions.functions import astype as astype +from nncf.openvino.optimized_functions.functions import do_float_quantization as do_float_quantization from nncf.openvino.optimized_functions.functions import do_integer_quantization as do_integer_quantization from nncf.openvino.optimized_functions.functions import get_integer_quantization_error as get_integer_quantization_error from nncf.openvino.optimized_functions.functions import ( diff --git a/nncf/openvino/optimized_functions/functions.py b/nncf/openvino/optimized_functions/functions.py index 02b390896f6..c7a65e67655 100644 --- a/nncf/openvino/optimized_functions/functions.py +++ b/nncf/openvino/optimized_functions/functions.py @@ -11,10 +11,12 @@ from typing import Optional, Tuple, Union +from nncf import CompressWeightsMode from nncf.common.utils.caching import disable_results_caching from nncf.openvino.optimized_functions.models import OV_MODEL_CACHE from nncf.openvino.optimized_functions.models import OVModelParameters from nncf.openvino.optimized_functions.models import get_astype_model +from nncf.openvino.optimized_functions.models import get_float_quantization_model from nncf.openvino.optimized_functions.models import get_integer_quantization_error_model from nncf.openvino.optimized_functions.models import get_integer_quantization_model from nncf.openvino.optimized_functions.models import get_integer_quantize_dequantize_weight_model @@ -97,6 +99,48 @@ def do_integer_quantization( return compressed_weight, scale, zero_point +def do_float_quantization( + weight: Tensor, + config: WeightCompressionConfig, + reduction_axes: Optional[ReductionAxes] = None, + precomputed_scale: Tensor = None, +) -> Tuple[Tensor, Tensor]: + weight_shape = weight.shape + scale_shape = None if precomputed_scale is None else precomputed_scale.shape + + ov_model_params = OVModelParameters() + ov_model_params.input_dtypes["weight"] = weight.dtype + if precomputed_scale is not None: + ov_model_params.input_dtypes["scale"] = precomputed_scale.dtype + if config.num_bits == 4 and weight.backend == TensorBackend.ov: + # Return ov tensors in target precision to seamlessly insert them into openvino model later + ov_model_params.return_ov_tensors = True + dtype = TensorDataType.nf4 if config.mode == CompressWeightsMode.NF4 else TensorDataType.f4e2m1 + ov_model_params.output_dtypes.update({"compressed_weight": dtype}) + + model = get_float_quantization_model( + ov_model_params, + config, + weight_shape, + scale_shape, + reduction_axes, + ) + + if precomputed_scale is None: + # weight -> compressed_weight, scale + compressed_weight, scale = model([weight]) + + # Scale is always in fp32 so there is no need to store it in ov.Tensor + if scale.backend == TensorBackend.ov: + scale = scale.as_numpy_tensor() + else: + # weight, scale -> compressed_weight + compressed_weight = model([weight, precomputed_scale])[0] + scale = precomputed_scale + + return compressed_weight, scale + + def integer_quantize_dequantize_weight( weight: Tensor, config: WeightCompressionConfig, diff --git a/nncf/openvino/optimized_functions/models.py b/nncf/openvino/optimized_functions/models.py index 66fb6034aea..df4a95eb443 100644 --- a/nncf/openvino/optimized_functions/models.py +++ b/nncf/openvino/optimized_functions/models.py @@ -22,6 +22,7 @@ from openvino.runtime import Node from openvino.runtime import opset13 as opset +from nncf import CompressWeightsMode from nncf.common.utils.backend import is_openvino_at_least from nncf.common.utils.caching import ResultsCache 
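
[Editorial note] The model builders added in this file (including the new _build_float_quantization_model) are wrapped with @cache_results(OV_MODEL_CACHE), using the ResultsCache machinery imported just above, so an OpenVINO graph is compiled only once per unique combination of config, shapes, and dtypes. The following is a hypothetical, stripped-down sketch of that caching pattern, not the actual NNCF decorator.

import functools

class ResultsCacheSketch:
    """Hypothetical stand-in for nncf.common.utils.caching.ResultsCache."""
    def __init__(self):
        self._data = {}
    def get(self, key):
        return self._data.get(key)
    def put(self, key, value):
        self._data[key] = value

def cache_results_sketch(cache: ResultsCacheSketch):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Key on the function plus its (hashable) arguments, e.g. config, shapes, dtypes.
            key = (fn.__name__, args, tuple(sorted(kwargs.items())))
            hit = cache.get(key)
            if hit is not None:
                return hit
            result = fn(*args, **kwargs)
            cache.put(key, result)
            return result
        return wrapper
    return decorator

MODEL_CACHE = ResultsCacheSketch()

@cache_results_sketch(MODEL_CACHE)
def build_model(weight_shape, num_bits):
    print("compiling model for", weight_shape, num_bits)  # happens once per unique key
    return object()

build_model((1024, 1024), 4)
build_model((1024, 1024), 4)  # second call is served from the cache
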
from nncf.common.utils.caching import cache_results @@ -233,6 +234,26 @@ def get_integer_quantization_model( ) +def get_float_quantization_model( + ov_model_params: OVModelParameters, + config: WeightCompressionConfig, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, +) -> Union[ModelCallable, ModelAsNodes]: + weight_shape, scale_shape, _ = _prepare_quantization_model_inputs( + ov_model_params, weight_shape, scale_shape, zero_point_shape=None, reduction_axes=reduction_axes + ) + + return _build_float_quantization_model( + config, + ov_model_params, + weight_shape, + scale_shape, + reduction_axes, + ) + + def get_integer_quantize_dequantize_weight_model( ov_model_params: OVModelParameters, config: WeightCompressionConfig, @@ -453,6 +474,97 @@ def _build_integer_quantization_model( return partial(_infer_ov_model, ov_model_params, compiled_model) +@cache_results(OV_MODEL_CACHE) +def _build_float_quantization_model( + config: WeightCompressionConfig, + ov_model_params: OVModelParameters, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_nodes: bool = False, +) -> Union[ModelCallable, ModelAsNodes]: + default_input_dtypes = {"scale": TensorDataType.float32} + default_output_dtypes = {"compressed_weight": TensorDataType.float32, "scale": TensorDataType.float32} + + # Update input and output dtypes with the default values + ov_model_params = copy.deepcopy(ov_model_params) + ov_model_params.input_dtypes = {**default_input_dtypes, **ov_model_params.input_dtypes} + ov_model_params.output_dtypes = {**default_output_dtypes, **ov_model_params.output_dtypes} + + if "weight" not in ov_model_params.input_dtypes: + msg = "Input weight dtype is required!" + raise ValueError(msg) + + weight_dtype = ov_model_params.input_dtypes["weight"] + input_scale_dtype = ov_model_params.input_dtypes["scale"] + compressed_weight_dtype = ov_model_params.output_dtypes["compressed_weight"] + output_scale_dtype = ov_model_params.output_dtypes["scale"] + + # Validate input dtypes + valid_weight_dtypes = [TensorDataType.float32, TensorDataType.float16, TensorDataType.bfloat16] + if weight_dtype not in valid_weight_dtypes: + msg = f"Weight must be one of the following data types: {valid_weight_dtypes}. But found: {weight_dtype}." + raise ValueError(msg) + if scale_shape is not None and input_scale_dtype != TensorDataType.float32: + msg = f"Input scale must be of float32 data type. But found: {input_scale_dtype}." + raise ValueError(msg) + + # Validate output dtypes + # TODO: Enable f4e2m1 + valid_compressed_weight_dtypes = [TensorDataType.float32, TensorDataType.nf4] + if compressed_weight_dtype not in valid_compressed_weight_dtypes: + msg = ( + f"Compressed weight must be one of the following data types: {valid_compressed_weight_dtypes}. " + f"But found: {compressed_weight_dtype}." + ) + raise ValueError(msg) + if scale_shape is None and output_scale_dtype != TensorDataType.float32: + msg = f"Output scale must be of float32 data type. But found: {output_scale_dtype}." 
+ raise ValueError(msg) + + # Build OV model + weight = opset.parameter(weight_shape, name="weight", dtype=DTYPE_MAP_OV[weight_dtype]) + ov_parameters = [weight] + weight = convert_op(weight, ov.Type.f32) + + divide_op = opset.divide if ov_model_params.convertable_division else non_convertable_divide_op + if scale_shape is not None: + # Scale is given as an input + scale = opset.parameter(scale_shape, name="scale", dtype=DTYPE_MAP_OV[input_scale_dtype]) + ov_parameters.append(scale) + else: + # Compute scale + scale = opset.reduce_max(opset.abs(weight), reduction_axes=reduction_axes, keep_dims=True) + # NOTE: adding machine epsilon to avoid division by zero + eps = np.finfo(np.float32).eps + scale = opset.select(opset.less(opset.abs(scale), eps), eps, scale) + + if config.mode == CompressWeightsMode.E2M1: + max_val = opset.constant(6, ov.Type.f32) # Maximal value of e2m1 type. + constant_2 = opset.constant(2, ov.Type.f32) + scale = divide_op(scale, max_val) + scale = opset.log(scale) / opset.log(constant_2) + scale = opset.ceil(scale) + scale = opset.clamp(scale, -127, 127) + scale = opset.power(constant_2, scale) + + compressed_weight = divide_op(weight, scale) + compressed_weight = convert_op(compressed_weight, ov.Type.nf4) + compressed_weight = convert_op(compressed_weight, DTYPE_MAP_OV[compressed_weight_dtype]) + + ov_results = [compressed_weight] + if len(ov_parameters) == 1: + ov_results.append(scale) + + if return_nodes: + return ov_parameters, ov_results, ov_model_params + + model = ov.Model(ov_results, ov_parameters) + compiled_model = _compile_ov_model(model, device_name="CPU", config={inference_precision(): ov.Type.f32}) + + return partial(_infer_ov_model, ov_model_params, compiled_model) + + @cache_results(OV_MODEL_CACHE) def _build_integer_quantize_dequantize_weight_model( config: WeightCompressionConfig, diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index f3b40bc3f15..8c8b5bbb6b0 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -30,9 +30,8 @@ from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.passes import transform_to_inference_graph from nncf.tensor import TensorDataType @@ -255,8 +254,9 @@ def apply( cur_scale = gscale**alpha weights_to_fake_quantize = gweight * cur_scale if config.mode == CompressWeightsMode.NF4: - g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) - g_compressed_weighs = calculate_nf4_quantized_weight(weights_to_fake_quantize, g_c_scale) + g_compressed_weighs, g_c_scale = do_float_quantization( + weights_to_fake_quantize, config, reduction_axis + ) g_decompressed_weighs = do_float_dequantization(g_compressed_weighs, 
g_c_scale) else: g_decompressed_weighs = integer_quantize_dequantize_weight( diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index e38bd868306..48bf458ce2f 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -26,10 +26,10 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation +from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -262,7 +262,9 @@ def _quantize_weights( if (i1 + i) % group_size == 0: if block_compression_config.mode == CompressWeightsMode.NF4: - scale = calculate_nf4_scale(weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes) + scale = calculate_float_quantization_params( + weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config + ) scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: @@ -284,8 +286,8 @@ def _quantize_weights( zero_points.append(zero_point) if block_compression_config.mode == CompressWeightsMode.NF4: - compressed_weights = calculate_nf4_quantized_weight( - fns.unsqueeze(weight_col, 1), scales[-1], is_normalized_weight=False + compressed_weights, _ = do_float_quantization( + fns.unsqueeze(weight_col, 1), block_compression_config, precomputed_scale=scales[-1] ) quantized_col = do_float_dequantization(compressed_weights, scales[-1], reduction_axis=-1) else: diff --git a/nncf/quantization/algorithms/weight_compression/lora_correction.py b/nncf/quantization/algorithms/weight_compression/lora_correction.py index f456e5c6904..724e9e2b3df 100644 --- a/nncf/quantization/algorithms/weight_compression/lora_correction.py +++ b/nncf/quantization/algorithms/weight_compression/lora_correction.py @@ -25,7 +25,6 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.weight_lowering import CompressedWeight -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_dequantization from nncf.tensor import Tensor @@ -177,10 +176,7 @@ def calculate_low_rank_matrices( reduction_axis, ) elif mode == CompressWeightsMode.NF4: - indexes = 
calculate_nf4_quantized_weight( - compressed_weight.tensor, compressed_weight.scale, is_normalized_weight=True - ) - fq_weights = do_float_dequantization(indexes, compressed_weight.scale, reduction_axis) + fq_weights = do_float_dequantization(compressed_weight.tensor, compressed_weight.scale, reduction_axis) else: msg = ( f"{mode.value} mode is invalid for Lora Correction algorithm. Supported modes: INT4_SYM, INT4_ASYM, NF4" diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 0720bf6561d..328e917c7c1 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -23,7 +23,6 @@ from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters -from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization @@ -199,8 +198,7 @@ def calculate_quantization_params( original_weight = fns.zeros_like(weight) + weight if config.mode == CompressWeightsMode.NF4: - norm_weight, scale = do_float_quantization(original_weight, reduction_axis, cur_config.group_size) - compressed_weights = calculate_nf4_quantized_weight(norm_weight, scale, is_normalized_weight=True) + compressed_weights, scale = do_float_quantization(original_weight, cur_config, reduction_axis) q_weights = do_float_dequantization(compressed_weights, scale, reduction_axis) q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) zp = None @@ -249,7 +247,9 @@ def calculate_quantization_params( near_to_ideal_scale = near_to_ideal_scale * scale_sign if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale) + g_compressed_weighs, _ = do_float_quantization( + original_weight, config, precomputed_scale=near_to_ideal_scale + ) out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale) else: out = integer_quantize_dequantize_weight( @@ -284,7 +284,7 @@ def calculate_quantization_params( if i < initial_steps - 1: if config.mode == CompressWeightsMode.NF4: - out = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale) + out, _ = do_float_quantization(original_weight, config, precomputed_scale=near_to_ideal_scale) else: out, _, _ = do_integer_quantization( original_weight, @@ -302,7 +302,7 @@ def calculate_quantization_params( scaled_scale = factor * scale if config.mode == CompressWeightsMode.NF4: - out = calculate_nf4_quantized_weight(original_weight, scaled_scale) + out, _ = do_float_quantization(original_weight, config, precomputed_scale=scaled_scale) else: out, _, _ = do_integer_quantization( original_weight, @@ -318,7 +318,9 @@ def calculate_quantization_params( near_to_ideal_scale = near_to_ideal_scale * scale_sign if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale) + 
g_compressed_weighs, _ = do_float_quantization( + original_weight, config, precomputed_scale=near_to_ideal_scale + ) out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale) else: out = integer_quantize_dequantize_weight( diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 83599867167..65ee971e474 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -100,14 +100,19 @@ def reshape_weight_for_grouped_quantization( return reshaped_weight, reduction_axes -def calculate_nf4_scale(weight: Tensor, reduction_axes: ReductionAxes) -> Tensor: +def calculate_float_quantization_params( + weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig, max_val=6.0 +) -> Tensor: """ Calculates the scale for nf4 quantization. :param weight: Weight array to compress. :param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max). + :param max_val: Maximal value of e2m1 type. :return: Scale tensor of float32 type for nf4 quantization. """ + assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] + if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) @@ -117,87 +122,16 @@ def calculate_nf4_scale(weight: Tensor, reduction_axes: ReductionAxes) -> Tensor eps = fns.finfo(weight).eps scale = fns.where(fns.abs(scale) < eps, eps, scale) - return scale - - -def calculate_e2m1_scale(weight: Tensor, reduction_axes: ReductionAxes, max_val=6.0) -> Tensor: - """ - Calculates the scale for e2m1 quantization. - - :param weight: Weight array to compress. - :param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max). - :param max_val: Maximal value of e2m1 type. - :param to_e8m0: Defines convert scale to e8m0 or not. - :return: Scale tensor of float32 type for e2m1 quantization. - """ - scale = calculate_nf4_scale(weight, reduction_axes) / max_val - - scale = fns.log2(scale) - scale = fns.ceil(scale) - scale = fns.clip(scale, -127, 127) - scale = 2**scale - - return scale - - -def calculate_signed_scale(weight: Tensor, reduction_axes: ReductionAxes, num_bits=4) -> Tensor: - """ - Calculates the signed scale for symmetric quantization. - - :param weight: Weight array to compress. - :param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max). - :param num_bits: number of bits in compression. - :return: Scale tensor. - """ - factor = 2 ** (num_bits - 1) - - w_abs_min = fns.abs(fns.min(weight, axis=reduction_axes, keepdims=True)) - w_max = fns.max(weight, axis=reduction_axes, keepdims=True) - - scale = fns.where(w_abs_min >= w_max, w_abs_min, -w_max) - scale /= factor - - eps = fns.finfo(scale).eps - scale = fns.where(fns.abs(scale) < eps, eps, scale) + if config.mode == CompressWeightsMode.E2M1: + scale = scale / max_val + scale = fns.log2(scale) + scale = fns.ceil(scale) + scale = fns.clip(scale, -127, 127) + scale = 2**scale return scale -def calculate_normalized_weight(weight: Tensor, scale: Tensor) -> Tensor: - """ - Normalizes the weight tensor using the provided scale. - - :param weight: Weight tensor to normalize. - :param scale: Scale tensor used for normalization. - :return: Normalized weight tensor. 
- """ - if weight.dtype != TensorDataType.float32: - weight = weight.astype(TensorDataType.float32) - if scale.dtype != TensorDataType.float32: - scale = scale.astype(TensorDataType.float32) - - return weight / scale - - -def calculate_nf4_quantized_weight(weight: Tensor, scale: Tensor, is_normalized_weight: bool = False) -> Tensor: - """ - Performs NF4 quantization. The floating point values are represented by floating point scale and look-up with - 16 floating-point values on [-1, 1]. Scale normalizes original values to [-1, 1] interval and look-up table - "rounds" or "quantize" to the closest quant. - - :param weight: Weight tensor to quantize. - :param scale: Scale tensor used for normalization. - :param is_normalized_weight: Whether weight was scaled to [-1, 1] interval. Defaults to False. - :return: Tensor with floating-point values, where each of them corresponds to 1 out of 16 quants on [-1, 1]. - """ - norm_weight = weight if is_normalized_weight else calculate_normalized_weight(weight, scale) - center_nf4_quantiles = fns.from_numpy(CENTER_OF_NF4_QUANTILES, backend=norm_weight.backend) - indexes = fns.searchsorted(center_nf4_quantiles, norm_weight) - nf4_quantiles = fns.from_numpy(NF4_QUANTILES, backend=indexes.backend) - nf4_weight = nf4_quantiles[indexes] - return nf4_weight - - def do_float_dequantization(compressed_weight: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor: """ Decompresses the float-quantized weight tensor. @@ -217,10 +151,9 @@ def do_float_dequantization(compressed_weight: Tensor, scale: Tensor, reduction_ def do_float_quantization( weight: Tensor, - reduction_axes: ReductionAxes, - group_size: int = -1, + config: WeightCompressionConfig, + reduction_axes: Optional[ReductionAxes] = None, precomputed_scale: Tensor = None, - mode: CompressWeightsMode = CompressWeightsMode.NF4, ) -> Tuple[Tensor, Tensor]: """ Calculates scale for fp4 (nf4, e2m1) quantization and normalizes weights by the scale. @@ -233,20 +166,35 @@ def do_float_quantization( :param precomputed_scale: Precomputed scale. :return: Normalized weight tensor of float32 type and nf4 scale tensor of float32 type. 
""" - assert mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] - if weight.dtype != TensorDataType.float32: - weight = weight.astype(TensorDataType.float32) + assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] - if group_size != -1: + if config.group_size != -1 and reduction_axes is not None: # weights are reshaped: [a1, r, a2] -> [a1, r//gs, gs, a2] - weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, group_size) + weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size) + + # Optimized implementation + if config.mode == CompressWeightsMode.NF4 and _can_run_optimized(weight.backend): + from nncf.openvino.optimized_functions import do_float_quantization as do_float_quantization_ov + + return do_float_quantization_ov(weight, config, reduction_axes, precomputed_scale) + + if weight.backend == TensorBackend.ov: + weight = weight.as_numpy_tensor() + if weight.dtype != TensorDataType.float32: + weight = weight.astype(TensorDataType.float32) - if mode == CompressWeightsMode.NF4: - scale = calculate_nf4_scale(weight, reduction_axes) if precomputed_scale is None else precomputed_scale - if mode == CompressWeightsMode.E2M1: - scale = calculate_e2m1_scale(weight, reduction_axes) if precomputed_scale is None else precomputed_scale - norm_weight = calculate_normalized_weight(weight, scale) - return norm_weight, scale + scale = ( + calculate_float_quantization_params(weight, reduction_axes, config) + if precomputed_scale is None + else precomputed_scale + ) + norm_weight = _calculate_normalized_weight(weight, scale) + if config.mode == CompressWeightsMode.NF4: + compressed_weight = _calculate_nf4_quantized_weight(norm_weight, scale, config.mode, is_normalized_weight=True) + else: + # TODO: Implement proper quantization for E2M1 + compressed_weight = norm_weight + return compressed_weight, scale def calculate_integer_quantization_params( @@ -282,7 +230,7 @@ def calculate_integer_quantization_params( ) return scale, zero_point - scale = calculate_signed_scale(weight, reduction_axes, num_bits) + scale = _calculate_signed_scale(weight, reduction_axes, num_bits) return scale, None @@ -345,17 +293,11 @@ def compress_weight( :return: The compressed weight and decompression parameters as instance of CompressedWeight """ if not config.is_integer: - if weight.backend == TensorBackend.ov: - weight = weight.as_numpy_tensor() - - compressed_weight, scale = do_float_quantization( - weight, reduction_axes, config.group_size, precomputed_scale, config.mode - ) + compressed_weight, scale = do_float_quantization(weight, config, reduction_axes, precomputed_scale) return CompressedWeight(compressed_weight, scale) compressed_weight, scale, zero_point = do_integer_quantization( weight, config, reduction_axes, precomputed_scale, precomputed_zero_point ) - return CompressedWeight(compressed_weight, scale, zero_point) @@ -509,6 +451,68 @@ def integer_quantize_dequantize_weight( return decompressed_weight +def _calculate_nf4_quantized_weight( + weight: Tensor, scale: Tensor, mode: CompressWeightsMode, is_normalized_weight: bool = False +) -> Tensor: + """ + Performs NF4 quantization. The floating point values are represented by floating point scale and look-up with + 16 floating-point values on [-1, 1]. Scale normalizes original values to [-1, 1] interval and look-up table + "rounds" or "quantize" to the closest quant. + + :param weight: Weight tensor to quantize. 
+ :param scale: Scale tensor used for normalization. + :param is_normalized_weight: Whether weight was scaled to [-1, 1] interval. Defaults to False. + :return: Tensor with floating-point values, where each of them corresponds to 1 out of 16 quants on [-1, 1]. + """ + assert mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] + + norm_weight = weight if is_normalized_weight else _calculate_normalized_weight(weight, scale) + center_nf4_quantiles = fns.from_numpy(CENTER_OF_NF4_QUANTILES, backend=norm_weight.backend) + indexes = fns.searchsorted(center_nf4_quantiles, norm_weight) + nf4_quantiles = fns.from_numpy(NF4_QUANTILES, backend=indexes.backend) + quantized_weight = nf4_quantiles[indexes] + return quantized_weight + + +def _calculate_normalized_weight(weight: Tensor, scale: Tensor) -> Tensor: + """ + Normalizes the weight tensor using the provided scale. + + :param weight: Weight tensor to normalize. + :param scale: Scale tensor used for normalization. + :return: Normalized weight tensor. + """ + if weight.dtype != TensorDataType.float32: + weight = weight.astype(TensorDataType.float32) + if scale.dtype != TensorDataType.float32: + scale = scale.astype(TensorDataType.float32) + + return weight / scale + + +def _calculate_signed_scale(weight: Tensor, reduction_axes: ReductionAxes, num_bits=4) -> Tensor: + """ + Calculates the signed scale for symmetric quantization. + + :param weight: Weight array to compress. + :param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max). + :param num_bits: number of bits in compression. + :return: Scale tensor. + """ + factor = 2 ** (num_bits - 1) + + w_abs_min = fns.abs(fns.min(weight, axis=reduction_axes, keepdims=True)) + w_max = fns.max(weight, axis=reduction_axes, keepdims=True) + + scale = fns.where(w_abs_min >= w_max, w_abs_min, -w_max) + scale /= factor + + eps = fns.finfo(scale).eps + scale = fns.where(fns.abs(scale) < eps, eps, scale) + + return scale + + def _calculate_integer_quantized_weight( weight: Tensor, config: WeightCompressionConfig, diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index 568a4444ffc..050f5a39ef1 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -54,6 +54,7 @@ from nncf.tensor.functions.numeric import reshape as reshape from nncf.tensor.functions.numeric import round as round from nncf.tensor.functions.numeric import searchsorted as searchsorted +from nncf.tensor.functions.numeric import sign as sign from nncf.tensor.functions.numeric import squeeze as squeeze from nncf.tensor.functions.numeric import stack as stack from nncf.tensor.functions.numeric import sum as sum diff --git a/nncf/tensor/functions/numeric.py b/nncf/tensor/functions/numeric.py index d1dc51577b5..3eb3941221c 100644 --- a/nncf/tensor/functions/numeric.py +++ b/nncf/tensor/functions/numeric.py @@ -105,6 +105,16 @@ def abs(a: Tensor) -> Tensor: """ +@tensor_dispatcher +def sign(a: Tensor) -> Tensor: + """ + Calculate the sign value element-wise. + + :param a: The input tensor. + :return: A tensor containing the sign value of each element in x. 
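For comparison, the signed symmetric scale used by the integer path picks whichever of |min| or max has the larger magnitude and keeps its sign, so the most extreme weight lands exactly on the end of the integer range. A NumPy sketch under the same conventions (not the exact NNCF function):

import numpy as np

def signed_scale(weight, axis=-1, num_bits=4):
    # Choose the side (min or max) with the larger magnitude; the sign is kept in the scale.
    w_abs_min = np.abs(np.min(weight, axis=axis, keepdims=True))
    w_max = np.max(weight, axis=axis, keepdims=True)
    scale = np.where(w_abs_min >= w_max, w_abs_min, -w_max) / 2 ** (num_bits - 1)
    eps = np.finfo(np.float32).eps
    return np.where(np.abs(scale) < eps, eps, scale).astype(np.float32)
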
+ """ + + @tensor_dispatcher def astype(a: Tensor, dtype: TensorDataType) -> Tensor: """ diff --git a/nncf/tensor/functions/numpy_numeric.py b/nncf/tensor/functions/numpy_numeric.py index f52f4a69470..38b33d8d124 100644 --- a/nncf/tensor/functions/numpy_numeric.py +++ b/nncf/tensor/functions/numpy_numeric.py @@ -87,6 +87,11 @@ def _(a: T_NUMPY) -> T_NUMPY: return np.absolute(a) +@numeric.sign.register +def _(a: T_NUMPY) -> T_NUMPY: + return np.sign(a) + + @numeric.astype.register def _(a: T_NUMPY, dtype: TensorDataType) -> T_NUMPY: return a.astype(DTYPE_MAP[dtype]) diff --git a/nncf/tensor/functions/openvino_numeric.py b/nncf/tensor/functions/openvino_numeric.py index e6eedc3f13a..386ec1a1093 100644 --- a/nncf/tensor/functions/openvino_numeric.py +++ b/nncf/tensor/functions/openvino_numeric.py @@ -83,7 +83,7 @@ def _(a: ov.Tensor, shape: Union[int, Tuple[int, ...]]) -> ov.Tensor: @numeric.as_numpy_tensor.register def _(a: ov.Tensor) -> NDArray[Any]: - # Cannot convert bfloat16, uint4, int4, nf4, f8e4m3, f8e5m2 to numpy directly + # Cannot convert bfloat16, uint4, int4, nf4, f4e2m1, f8e4m3, f8e5m2 to numpy directly a_dtype = DTYPE_MAP_REV[a.get_element_type()] if a_dtype in [ TensorDataType.bfloat16, diff --git a/nncf/tensor/functions/torch_numeric.py b/nncf/tensor/functions/torch_numeric.py index 905fe916014..3c990e7adad 100644 --- a/nncf/tensor/functions/torch_numeric.py +++ b/nncf/tensor/functions/torch_numeric.py @@ -98,6 +98,11 @@ def _(a: torch.Tensor) -> torch.Tensor: return torch.absolute(a) +@numeric.abs.register +def _(a: torch.Tensor) -> torch.Tensor: + return torch.sign(a) + + @numeric.astype.register def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor: return a.type(DTYPE_MAP[dtype]) diff --git a/tests/cross_fw/test_templates/template_test_weights_compression.py b/tests/cross_fw/test_templates/template_test_weights_compression.py index d595ed8c6fc..1f3ff241adf 100644 --- a/tests/cross_fw/test_templates/template_test_weights_compression.py +++ b/tests/cross_fw/test_templates/template_test_weights_compression.py @@ -25,7 +25,7 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation -from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.scopes import IgnoredScope from nncf.tensor import Tensor from nncf.tensor import TensorDataType @@ -220,7 +220,7 @@ def test_scale_estimation_outlier_channel_has_lowest_error(self): dataset=dataset, ) - decompressed_weight_before_se = quantize_dequantize_weight( + decompressed_weight_before_se = integer_quantize_dequantize_weight( original_weight, config=WeightCompressionConfig(CompressWeightsMode.INT4_ASYM, -1), reduction_axes=1 ) decompressed_weight_after_se = self.get_decompressed_weight(compressed_model, input) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index c687d246bf1..f398cb481fe 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -40,7 +40,7 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from 
nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.scopes import IgnoredScope @@ -1041,7 +1041,7 @@ def test_compressed_weighs_range(mode, data): w = Tensor(data) config = WeightCompressionConfig(mode=mode) - compressed_weighs, _, _ = do_int_quantization(w, config, -1) + compressed_weighs, _, _ = do_integer_quantization(w, config, -1) assert np.allclose(np.abs(compressed_weighs.data), np.abs(w.data)) @@ -1073,14 +1073,16 @@ def test_int_quantization_with_precomputed_parameters(config, precompute_scale, if raises: with pytest.raises(ValueError) as exc_info: - _, scale, zero_point = do_int_quantization(weight, config, -1, precomputed_scale, precomputed_zero_point) + _, scale, zero_point = do_integer_quantization( + weight, config, -1, precomputed_scale, precomputed_zero_point + ) assert exc_info.value == ( "If precomputed quantization parameters are provided, both scale and zero point " "are required for asymmetric quantization." ) return else: - _, scale, zero_point = do_int_quantization(weight, config, -1, precomputed_scale, precomputed_zero_point) + _, scale, zero_point = do_integer_quantization(weight, config, -1, precomputed_scale, precomputed_zero_point) if precompute_scale: assert np.allclose(scale.data, precomputed_scale.data) diff --git a/tests/openvino/optimized_functions/test_compression_functions.py b/tests/openvino/optimized_functions/test_compression_functions.py index 63950921024..f047b6dccea 100644 --- a/tests/openvino/optimized_functions/test_compression_functions.py +++ b/tests/openvino/optimized_functions/test_compression_functions.py @@ -23,9 +23,10 @@ from nncf.common.utils.caching import ResultsCache from nncf.common.utils.caching import cache_results from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error -from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.tensor import Tensor from nncf.tensor import TensorDataType @@ -58,7 +59,13 @@ class QuantizationTask(Enum): WeightCompressionConfig(CompressWeightsMode.INT4_SYM, group_size=2), ] -COMPRESSION_CONFIGS = INT8_COMPRESSION_CONFIGS + INT4_COMPRESSION_CONFIGS +FP4_COMPRESSION_CONFIGS = [ + WeightCompressionConfig(CompressWeightsMode.NF4), + WeightCompressionConfig(CompressWeightsMode.NF4, group_size=2), +] + +COMPRESSION_CONFIGS = 
INT8_COMPRESSION_CONFIGS + INT4_COMPRESSION_CONFIGS + FP4_COMPRESSION_CONFIGS +# COMPRESSION_CONFIGS = FP4_COMPRESSION_CONFIGS WEIGHT_SHAPE = (10000, 4) @@ -152,11 +159,17 @@ def test_quantization_alignment(weight_shape, config, quantization_task, tensor_ ) if quantization_task == QuantizationTask.Q: - fn_to_call = do_int_quantization - fn_to_patch = opt_fns.do_int_quantization + if config.is_integer: + fn_to_call = do_integer_quantization + fn_to_patch = opt_fns.do_integer_quantization + else: + fn_to_call = do_float_quantization + fn_to_patch = opt_fns.do_float_quantization else: - fn_to_call = quantize_dequantize_weight - fn_to_patch = opt_fns.quantize_dequantize_weight + if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]: + pytest.skip("Quantize-dequantize is not supported for NF4 and E2M1 modes") + fn_to_call = integer_quantize_dequantize_weight + fn_to_patch = opt_fns.integer_quantize_dequantize_weight patch_path = f"nncf.openvino.optimized_functions.{fn_to_patch.__name__}" with patch(patch_path, side_effect=fn_to_patch) as mock: # When scale (and z.p) are precomputed, all inputs are assumed to be already reshaped and reduction @@ -167,13 +180,17 @@ def test_quantization_alignment(weight_shape, config, quantization_task, tensor_ if quantization_task == QuantizationTask.Q_DQ_RQ: kwargs["return_compressed_weight"] = True - outputs = fn_to_call( - weight, config, reduction_axes, precomputed_scale, precomputed_zero_point, **kwargs - ) + args = (weight, config, reduction_axes, precomputed_scale) + if config.is_integer: + args = args + (precomputed_zero_point,) + outputs = fn_to_call(*args, **kwargs) decompressed_weight, compressed_weight, scale, zero_point = (None,) * 4 if quantization_task == QuantizationTask.Q: - compressed_weight, scale, zero_point = outputs + if config.is_integer: + compressed_weight, scale, zero_point = outputs + else: + compressed_weight, scale = outputs elif quantization_task == QuantizationTask.Q_DQ: decompressed_weight = outputs else: @@ -263,22 +280,30 @@ def _check_backends_and_dtypes( quantization_task == QuantizationTask.Q and cb == ComputationBackend.OV and weight_tensor_backend == TensorBackend.ov - and config.num_bits == 4 + and config.mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM, CompressWeightsMode.NF4] ): # For 4 bit compression in case of ov implementation and ov backend the compressed weight and the computed - # zero point must be in ov backend and have (u)int4 dtype in order to be able to insert them into OV model - # without re-packing + # zero point must be in ov backend and have (u)int4 or nf4 dtypes in order to be able to insert them into OV + # model without re-packing + if config.is_integer: + ref_dtype = TensorDataType.uint4 if config.is_asym_mode else TensorDataType.int4 + else: + ref_dtype = TensorDataType.nf4 assert compressed_weight.backend == TensorBackend.ov - assert compressed_weight.dtype == (TensorDataType.uint4 if config.is_asym_mode else TensorDataType.int4) + assert compressed_weight.dtype == ref_dtype if config.is_asym_mode and not precompute_s_zp: assert zero_point.backend == TensorBackend.ov assert zero_point.dtype == TensorDataType.uint4 else: if quantization_task != QuantizationTask.Q_DQ: # Otherwise compressed weight and zero point must be returned in numpy backend, compressed weight must - # be of (u)int8 data type, zero point -- in int32 + # be of (u)int8 or float32 data type, zero point -- in int32 + if config.is_integer: + ref_dtype = TensorDataType.uint8 if config.is_asym_mode 
else TensorDataType.int8 + else: + ref_dtype = TensorDataType.float32 assert compressed_weight.backend == TensorBackend.numpy - assert compressed_weight.dtype == (TensorDataType.uint8 if config.is_asym_mode else TensorDataType.int8) + assert compressed_weight.dtype == ref_dtype if config.is_asym_mode and not precompute_s_zp: assert zero_point.backend == TensorBackend.numpy assert zero_point.dtype == TensorDataType.int32 diff --git a/tests/openvino/optimized_functions/test_ov_model_parameters.py b/tests/openvino/optimized_functions/test_ov_model_parameters.py index 56476c96270..799b2c3c1f3 100644 --- a/tests/openvino/optimized_functions/test_ov_model_parameters.py +++ b/tests/openvino/optimized_functions/test_ov_model_parameters.py @@ -17,9 +17,9 @@ from nncf.openvino.optimized_functions.models import OVModelParameters from nncf.openvino.optimized_functions.models import _infer_ov_model from nncf.openvino.optimized_functions.models import get_astype_model -from nncf.openvino.optimized_functions.models import get_compress_decompress_weight_model -from nncf.openvino.optimized_functions.models import get_compress_weight_model -from nncf.openvino.optimized_functions.models import get_quantization_error_model +from nncf.openvino.optimized_functions.models import get_integer_quantization_error_model +from nncf.openvino.optimized_functions.models import get_integer_quantization_model +from nncf.openvino.optimized_functions.models import get_integer_quantize_dequantize_weight_model from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.tensor import Tensor from nncf.tensor import TensorDataType @@ -44,7 +44,7 @@ def get(self, ov_model_params_kwargs=None, get_model_kwargs=None): MODEL_GETTERS = [ ModelGetter( - get_model_fn=get_compress_weight_model, + get_model_fn=get_integer_quantization_model, ov_model_params_kwargs=dict( input_dtypes={ "weight": TensorDataType.float32, @@ -61,7 +61,7 @@ def get(self, ov_model_params_kwargs=None, get_model_kwargs=None): ), ), ModelGetter( - get_model_fn=get_compress_weight_model, + get_model_fn=get_integer_quantization_model, ov_model_params_kwargs=dict( input_dtypes={"weight": TensorDataType.float32}, output_dtypes={ @@ -77,7 +77,7 @@ def get(self, ov_model_params_kwargs=None, get_model_kwargs=None): ), ), ModelGetter( - get_model_fn=get_compress_decompress_weight_model, + get_model_fn=get_integer_quantize_dequantize_weight_model, ov_model_params_kwargs=dict( input_dtypes={ "weight": TensorDataType.float32, @@ -96,7 +96,7 @@ def get(self, ov_model_params_kwargs=None, get_model_kwargs=None): ), ), ModelGetter( - get_model_fn=get_compress_decompress_weight_model, + get_model_fn=get_integer_quantize_dequantize_weight_model, ov_model_params_kwargs=dict( input_dtypes={ "weight": TensorDataType.float32, @@ -130,7 +130,7 @@ def get(self, ov_model_params_kwargs=None, get_model_kwargs=None): ), ), ModelGetter( - get_model_fn=get_quantization_error_model, + get_model_fn=get_integer_quantization_error_model, ov_model_params_kwargs=dict( input_dtypes={ "weight": TensorDataType.float32, @@ -228,9 +228,9 @@ def test_recompile(model_getter, recompile): model_getter.get() if recompile: ref_size = 0 - elif model_getter._get_model_fn == get_compress_decompress_weight_model: + elif model_getter._get_model_fn == get_integer_quantize_dequantize_weight_model: ref_size = 2 - elif model_getter._get_model_fn == get_quantization_error_model: + elif model_getter._get_model_fn == get_integer_quantization_error_model: ref_size = 3 else: ref_size 
= 1 @@ -335,6 +335,6 @@ def test_convertable_divison(weight, convertable_division, ref_compressed_weight weight = np.array(weight, np.float32) ref_compressed_weight = np.array(ref_compressed_weight, np.uint8) - model_run_fn = get_compress_weight_model(ov_model_params, config, weight.shape, reduction_axes=(1,)) + model_run_fn = get_integer_quantization_model(ov_model_params, config, weight.shape, reduction_axes=(1,)) compressed_weight = model_run_fn([Tensor(weight)])[0] np.testing.assert_allclose(compressed_weight.data, ref_compressed_weight, atol=0, rtol=0) From 70b7e82b885165ac205794061783cd35a845175b Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Fri, 21 Mar 2025 19:46:51 +0100 Subject: [PATCH 04/11] Add float_quantize_dequantize --- nncf/openvino/optimized_functions/__init__.py | 3 + .../openvino/optimized_functions/functions.py | 43 ++++++++++++ nncf/openvino/optimized_functions/models.py | 70 +++++++++++++++++++ .../algorithms/weight_compression/awq.py | 6 +- .../algorithms/weight_compression/gptq.py | 14 ++-- .../weight_compression/scale_estimation.py | 20 +++--- .../weight_compression/weight_lowering.py | 30 ++++++++ .../test_compression_functions.py | 17 +++-- tests/openvino/requirements.txt | 1 + tests/post_training/requirements.txt | 1 + 10 files changed, 176 insertions(+), 29 deletions(-) diff --git a/nncf/openvino/optimized_functions/__init__.py b/nncf/openvino/optimized_functions/__init__.py index 82117571732..fb7675ee28f 100644 --- a/nncf/openvino/optimized_functions/__init__.py +++ b/nncf/openvino/optimized_functions/__init__.py @@ -12,6 +12,9 @@ from nncf.openvino.optimized_functions.functions import astype as astype from nncf.openvino.optimized_functions.functions import do_float_quantization as do_float_quantization from nncf.openvino.optimized_functions.functions import do_integer_quantization as do_integer_quantization +from nncf.openvino.optimized_functions.functions import ( + float_quantize_dequantize_weight as float_quantize_dequantize_weight, +) from nncf.openvino.optimized_functions.functions import get_integer_quantization_error as get_integer_quantization_error from nncf.openvino.optimized_functions.functions import ( integer_quantize_dequantize_weight as integer_quantize_dequantize_weight, diff --git a/nncf/openvino/optimized_functions/functions.py b/nncf/openvino/optimized_functions/functions.py index c7a65e67655..6c352ffcfbc 100644 --- a/nncf/openvino/optimized_functions/functions.py +++ b/nncf/openvino/optimized_functions/functions.py @@ -17,6 +17,7 @@ from nncf.openvino.optimized_functions.models import OVModelParameters from nncf.openvino.optimized_functions.models import get_astype_model from nncf.openvino.optimized_functions.models import get_float_quantization_model +from nncf.openvino.optimized_functions.models import get_float_quantize_dequantize_weight_model from nncf.openvino.optimized_functions.models import get_integer_quantization_error_model from nncf.openvino.optimized_functions.models import get_integer_quantization_model from nncf.openvino.optimized_functions.models import get_integer_quantize_dequantize_weight_model @@ -205,6 +206,48 @@ def integer_quantize_dequantize_weight( return decompressed_weight +def float_quantize_dequantize_weight( + weight: Tensor, + config: WeightCompressionConfig, + reduction_axes: Optional[ReductionAxes] = None, + precomputed_scale: Optional[Tensor] = None, + return_compressed_weight: Optional[bool] = False, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: + # When reduction axes are not provided, 
assuming that the weights are already reshaped + if config.group_size != -1 and reduction_axes is not None: + # weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2] + weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size) + + weight_shape = weight.shape + scale_shape = precomputed_scale.shape if precomputed_scale is not None else None + + ov_model_params = OVModelParameters() + ov_model_params.input_dtypes["weight"] = weight.dtype + if precomputed_scale is not None: + ov_model_params.input_dtypes["scale"] = precomputed_scale.dtype + + model = get_float_quantize_dequantize_weight_model( + ov_model_params, config, weight_shape, scale_shape, reduction_axes, return_compressed_weight + ) + + inputs = [weight] + if precomputed_scale is not None: + inputs.append(precomputed_scale) + + compressed_weight, scale = None, precomputed_scale + results = model(inputs) + if len(results) == 1: + decompressed_weight = results[0] + elif len(results) == 2: + decompressed_weight, compressed_weight = results + else: + decompressed_weight, compressed_weight, scale = results + if return_compressed_weight: + return decompressed_weight, compressed_weight, scale + else: + return decompressed_weight + + def get_integer_quantization_error( weight: Tensor, reduction_axes: ReductionAxes, diff --git a/nncf/openvino/optimized_functions/models.py b/nncf/openvino/optimized_functions/models.py index df4a95eb443..86a81e32484 100644 --- a/nncf/openvino/optimized_functions/models.py +++ b/nncf/openvino/optimized_functions/models.py @@ -254,6 +254,28 @@ def get_float_quantization_model( ) +def get_float_quantize_dequantize_weight_model( + ov_model_params: OVModelParameters, + config: WeightCompressionConfig, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_compressed_weight: Optional[bool] = False, +) -> ModelCallable: + weight_shape, scale_shape, _ = _prepare_quantization_model_inputs( + ov_model_params, weight_shape, scale_shape, zero_point_shape=None, reduction_axes=reduction_axes + ) + + return _build_float_quantize_dequantize_weight_model( + config, + ov_model_params, + weight_shape, + scale_shape, + reduction_axes, + return_compressed_weight, + ) + + def get_integer_quantize_dequantize_weight_model( ov_model_params: OVModelParameters, config: WeightCompressionConfig, @@ -625,6 +647,54 @@ def _build_integer_quantize_dequantize_weight_model( return partial(_infer_ov_model, ov_model_params, compiled_model) +@cache_results(OV_MODEL_CACHE) +def _build_float_quantize_dequantize_weight_model( + config: WeightCompressionConfig, + ov_model_params: OVModelParameters, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_compressed_weight: Optional[bool] = False, + return_nodes: Optional[bool] = False, +) -> Union[ModelCallable, ModelAsNodes]: + default_output_dtypes = {"decompressed_weight": TensorDataType.float32} + if not return_compressed_weight: + # If compressed weight is not returned to a user, we can keep it in float32 to avoid additional conversion + default_output_dtypes["compressed_weight"] = TensorDataType.float32 + ov_model_params = copy.deepcopy(ov_model_params) + ov_model_params.output_dtypes = {**default_output_dtypes, **ov_model_params.output_dtypes} + + decompressed_weight_dtype = ov_model_params.output_dtypes["decompressed_weight"] + if decompressed_weight_dtype != TensorDataType.float32: + msg = f"Decompressed weight 
must be of float32 data type. But found: {decompressed_weight_dtype}." + raise ValueError(msg) + + # Get compression model as input/result nodes and potentially modified ov model parameters + ov_parameters, ov_results, ov_model_params = _build_float_quantization_model( + config, ov_model_params, weight_shape, scale_shape, reduction_axes, return_nodes=True + ) + + if len(ov_parameters) == 1: + # weight -> compressed_weight, scale + compressed_weight, scale = ov_results + else: + # weight, scale -> compressed_weight + compressed_weight = ov_results[0] + scale = ov_parameters[1] + + decompressed_weight = opset.multiply(scale, convert_op(compressed_weight, ov.Type.f32)) + + ov_results = [decompressed_weight] + ov_results if return_compressed_weight else [decompressed_weight] + + if return_nodes: + return ov_parameters, ov_results, ov_model_params + + model = ov.Model(ov_results, ov_parameters) + compiled_model = _compile_ov_model(model, device_name="CPU", config={inference_precision(): ov.Type.f32}) + + return partial(_infer_ov_model, ov_model_params, compiled_model) + + @cache_results(OV_MODEL_CACHE) def _build_integer_quantization_error_model( config: WeightCompressionConfig, diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 8c8b5bbb6b0..fcd6acff813 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -30,8 +30,7 @@ from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.passes import transform_to_inference_graph from nncf.tensor import TensorDataType @@ -254,10 +253,9 @@ def apply( cur_scale = gscale**alpha weights_to_fake_quantize = gweight * cur_scale if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs, g_c_scale = do_float_quantization( + g_decompressed_weighs = float_quantize_dequantize_weight( weights_to_fake_quantize, config, reduction_axis ) - g_decompressed_weighs = do_float_dequantization(g_compressed_weighs, g_c_scale) else: g_decompressed_weighs = integer_quantize_dequantize_weight( weights_to_fake_quantize, awq_config, reduction_axis diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 48bf458ce2f..d7331d0c2c6 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -28,8 +28,7 @@ from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization 
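At the caller level, the new fake-quantization helper replaces the explicit quantize-then-dequantize pairs. A hedged usage sketch of the Python reference path (import locations and signature as introduced in this patch; the random weight and its shape are assumptions):

import numpy as np
from nncf import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
from nncf.tensor import Tensor

# Hypothetical per-channel NF4 fake quantization of a 4x8 weight.
weight = Tensor(np.random.rand(4, 8).astype(np.float32))
config = WeightCompressionConfig(CompressWeightsMode.NF4)
decompressed = float_quantize_dequantize_weight(weight, config, reduction_axes=1)
fake_quant_error = np.linalg.norm(decompressed.data - weight.data)
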
-from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -286,18 +285,17 @@ def _quantize_weights( zero_points.append(zero_point) if block_compression_config.mode == CompressWeightsMode.NF4: - compressed_weights, _ = do_float_quantization( - fns.unsqueeze(weight_col, 1), block_compression_config, precomputed_scale=scales[-1] + quantized_col = float_quantize_dequantize_weight( + fns.unsqueeze(weight_col, 1), + block_compression_config, + precomputed_scale=scales[-1], ) - quantized_col = do_float_dequantization(compressed_weights, scales[-1], reduction_axis=-1) else: - quantized_col, compressed_weights, _, _ = integer_quantize_dequantize_weight( + quantized_col = integer_quantize_dequantize_weight( fns.unsqueeze(weight_col, 1), block_compression_config, - reduction_axes=None, precomputed_scale=scales[-1], precomputed_zero_point=zero_points[-1], - return_compressed_weight=True, ) quantized_col = fns.flatten(quantized_col) quantized_block[:, i] = quantized_col diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 328e917c7c1..079efe6e7d9 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -23,9 +23,9 @@ from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.tensor import Tensor @@ -198,9 +198,9 @@ def calculate_quantization_params( original_weight = fns.zeros_like(weight) + weight if config.mode == CompressWeightsMode.NF4: - compressed_weights, scale = do_float_quantization(original_weight, cur_config, reduction_axis) - q_weights = do_float_dequantization(compressed_weights, scale, reduction_axis) - q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) + q_weights, compressed_weights, scale = float_quantize_dequantize_weight( + original_weight, cur_config, reduction_axis, return_compressed_weight=True + ) zp = None else: q_weights, compressed_weights, scale, zp = integer_quantize_dequantize_weight( @@ -247,10 +247,11 @@ def calculate_quantization_params( near_to_ideal_scale = near_to_ideal_scale * scale_sign if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs, _ = do_float_quantization( - original_weight, config, 
precomputed_scale=near_to_ideal_scale + out = float_quantize_dequantize_weight( + original_weight, + config, + precomputed_scale=near_to_ideal_scale, ) - out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale) else: out = integer_quantize_dequantize_weight( original_weight, @@ -318,10 +319,7 @@ def calculate_quantization_params( near_to_ideal_scale = near_to_ideal_scale * scale_sign if config.mode == CompressWeightsMode.NF4: - g_compressed_weighs, _ = do_float_quantization( - original_weight, config, precomputed_scale=near_to_ideal_scale - ) - out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale) + out = float_quantize_dequantize_weight(original_weight, config, precomputed_scale=near_to_ideal_scale) else: out = integer_quantize_dequantize_weight( original_weight, diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 65ee971e474..83befcbfcf7 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -149,6 +149,36 @@ def do_float_dequantization(compressed_weight: Tensor, scale: Tensor, reduction_ return decompressed_weight +def float_quantize_dequantize_weight( + weight: Tensor, + config: WeightCompressionConfig, + reduction_axes: Optional[ReductionAxes] = None, + precomputed_scale: Optional[Tensor] = None, + return_compressed_weight: Optional[bool] = False, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: + # Optimized implementation + if _can_run_optimized(weight.backend): + from nncf.openvino.optimized_functions import ( + float_quantize_dequantize_weight as float_quantize_dequantize_weight_ov, + ) + + return float_quantize_dequantize_weight_ov( + weight, + config, + reduction_axes, + precomputed_scale, + return_compressed_weight, + ) + + # Reference implementation + compressed_weight, scale = do_float_quantization(weight, config, reduction_axes, precomputed_scale) + decompressed_weight = do_float_dequantization(compressed_weight, scale) + if return_compressed_weight: + return decompressed_weight, compressed_weight, scale + else: + return decompressed_weight + + def do_float_quantization( weight: Tensor, config: WeightCompressionConfig, diff --git a/tests/openvino/optimized_functions/test_compression_functions.py b/tests/openvino/optimized_functions/test_compression_functions.py index f047b6dccea..c5dbd69c332 100644 --- a/tests/openvino/optimized_functions/test_compression_functions.py +++ b/tests/openvino/optimized_functions/test_compression_functions.py @@ -25,6 +25,7 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization @@ -65,7 +66,6 @@ class QuantizationTask(Enum): ] COMPRESSION_CONFIGS = INT8_COMPRESSION_CONFIGS + INT4_COMPRESSION_CONFIGS + FP4_COMPRESSION_CONFIGS -# 
COMPRESSION_CONFIGS = FP4_COMPRESSION_CONFIGS WEIGHT_SHAPE = (10000, 4) @@ -166,10 +166,12 @@ def test_quantization_alignment(weight_shape, config, quantization_task, tensor_ fn_to_call = do_float_quantization fn_to_patch = opt_fns.do_float_quantization else: - if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]: - pytest.skip("Quantize-dequantize is not supported for NF4 and E2M1 modes") - fn_to_call = integer_quantize_dequantize_weight - fn_to_patch = opt_fns.integer_quantize_dequantize_weight + if config.is_integer: + fn_to_call = integer_quantize_dequantize_weight + fn_to_patch = opt_fns.integer_quantize_dequantize_weight + else: + fn_to_call = float_quantize_dequantize_weight + fn_to_patch = opt_fns.float_quantize_dequantize_weight patch_path = f"nncf.openvino.optimized_functions.{fn_to_patch.__name__}" with patch(patch_path, side_effect=fn_to_patch) as mock: # When scale (and z.p) are precomputed, all inputs are assumed to be already reshaped and reduction @@ -194,7 +196,10 @@ def test_quantization_alignment(weight_shape, config, quantization_task, tensor_ elif quantization_task == QuantizationTask.Q_DQ: decompressed_weight = outputs else: - decompressed_weight, compressed_weight, scale, zero_point = outputs + if config.is_integer: + decompressed_weight, compressed_weight, scale, zero_point = outputs + else: + decompressed_weight, compressed_weight, scale = outputs if cb == ComputationBackend.NumPy: mock.assert_not_called() diff --git a/tests/openvino/requirements.txt b/tests/openvino/requirements.txt index dd2a938bb4f..c3141b2fc42 100644 --- a/tests/openvino/requirements.txt +++ b/tests/openvino/requirements.txt @@ -1,3 +1,4 @@ +-i https://storage.openvinotoolkit.org/simple/wheels/pre-release -c ../../constraints.txt fastdownload==0.0.7 onnx diff --git a/tests/post_training/requirements.txt b/tests/post_training/requirements.txt index 2e67b7e1c55..45cee4e2845 100644 --- a/tests/post_training/requirements.txt +++ b/tests/post_training/requirements.txt @@ -1,3 +1,4 @@ +-i https://storage.openvinotoolkit.org/simple/wheels/pre-release -c ../../constraints.txt torch torchvision From 6598ed2df707822ca631d52f56c48df83eb946a6 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Fri, 21 Mar 2025 19:48:49 +0100 Subject: [PATCH 05/11] Undo irrelevant changes --- nncf/tensor/functions/__init__.py | 1 - nncf/tensor/functions/numeric.py | 10 ---------- nncf/tensor/functions/numpy_numeric.py | 5 ----- nncf/tensor/functions/openvino_numeric.py | 2 +- nncf/tensor/functions/torch_numeric.py | 5 ----- nncf/version.py | 2 +- 6 files changed, 2 insertions(+), 23 deletions(-) diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index 050f5a39ef1..568a4444ffc 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -54,7 +54,6 @@ from nncf.tensor.functions.numeric import reshape as reshape from nncf.tensor.functions.numeric import round as round from nncf.tensor.functions.numeric import searchsorted as searchsorted -from nncf.tensor.functions.numeric import sign as sign from nncf.tensor.functions.numeric import squeeze as squeeze from nncf.tensor.functions.numeric import stack as stack from nncf.tensor.functions.numeric import sum as sum diff --git a/nncf/tensor/functions/numeric.py b/nncf/tensor/functions/numeric.py index 3eb3941221c..d1dc51577b5 100644 --- a/nncf/tensor/functions/numeric.py +++ b/nncf/tensor/functions/numeric.py @@ -105,16 +105,6 @@ def abs(a: Tensor) -> Tensor: """ -@tensor_dispatcher -def sign(a: Tensor) 
-> Tensor: - """ - Calculate the sign value element-wise. - - :param a: The input tensor. - :return: A tensor containing the sign value of each element in x. - """ - - @tensor_dispatcher def astype(a: Tensor, dtype: TensorDataType) -> Tensor: """ diff --git a/nncf/tensor/functions/numpy_numeric.py b/nncf/tensor/functions/numpy_numeric.py index 38b33d8d124..f52f4a69470 100644 --- a/nncf/tensor/functions/numpy_numeric.py +++ b/nncf/tensor/functions/numpy_numeric.py @@ -87,11 +87,6 @@ def _(a: T_NUMPY) -> T_NUMPY: return np.absolute(a) -@numeric.sign.register -def _(a: T_NUMPY) -> T_NUMPY: - return np.sign(a) - - @numeric.astype.register def _(a: T_NUMPY, dtype: TensorDataType) -> T_NUMPY: return a.astype(DTYPE_MAP[dtype]) diff --git a/nncf/tensor/functions/openvino_numeric.py b/nncf/tensor/functions/openvino_numeric.py index 386ec1a1093..e6eedc3f13a 100644 --- a/nncf/tensor/functions/openvino_numeric.py +++ b/nncf/tensor/functions/openvino_numeric.py @@ -83,7 +83,7 @@ def _(a: ov.Tensor, shape: Union[int, Tuple[int, ...]]) -> ov.Tensor: @numeric.as_numpy_tensor.register def _(a: ov.Tensor) -> NDArray[Any]: - # Cannot convert bfloat16, uint4, int4, nf4, f4e2m1, f8e4m3, f8e5m2 to numpy directly + # Cannot convert bfloat16, uint4, int4, nf4, f8e4m3, f8e5m2 to numpy directly a_dtype = DTYPE_MAP_REV[a.get_element_type()] if a_dtype in [ TensorDataType.bfloat16, diff --git a/nncf/tensor/functions/torch_numeric.py b/nncf/tensor/functions/torch_numeric.py index 3c990e7adad..905fe916014 100644 --- a/nncf/tensor/functions/torch_numeric.py +++ b/nncf/tensor/functions/torch_numeric.py @@ -98,11 +98,6 @@ def _(a: torch.Tensor) -> torch.Tensor: return torch.absolute(a) -@numeric.abs.register -def _(a: torch.Tensor) -> torch.Tensor: - return torch.sign(a) - - @numeric.astype.register def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor: return a.type(DTYPE_MAP[dtype]) diff --git a/nncf/version.py b/nncf/version.py index 95a9f571d48..f7b5b2206e3 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.16.0.dev0+fb70c8272dirty" +__version__ = "2.16.0" BKC_TORCH_SPEC = "==2.6.*" From 7b987a4aa4a49c443b4b784ce6acde85196ad122 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Fri, 21 Mar 2025 20:02:24 +0100 Subject: [PATCH 06/11] Tweak install instructions --- .github/workflows/conformance_weight_compression.yml | 2 +- Makefile | 2 +- nncf/openvino/optimized_functions/models.py | 10 ---------- nncf/version.py | 2 +- tests/openvino/requirements.txt | 1 - tests/post_training/requirements.txt | 1 - 6 files changed, 3 insertions(+), 15 deletions(-) diff --git a/.github/workflows/conformance_weight_compression.yml b/.github/workflows/conformance_weight_compression.yml index 4e443a8a525..dd89693e62c 100644 --- a/.github/workflows/conformance_weight_compression.yml +++ b/.github/workflows/conformance_weight_compression.yml @@ -40,7 +40,7 @@ jobs: - name: cpuinfo run: cat /proc/cpuinfo - name: Install NNCF and test requirements - run: pip install -e . -r tests/post_training/requirements.txt + run: pip install -e . 
-r tests/post_training/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release - name: Print installed modules run: pip list - name: Run examples test scope diff --git a/Makefile b/Makefile index 05e9b70b841..52ec4805643 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ install-openvino-test: pip install -U pip pip install -e . pip install "git+https://github.com/openvinotoolkit/open_model_zoo.git@e7df86da686d2e1600282422e54f66c2fecea160#egg=accuracy_checker&subdirectory=tools/accuracy_checker" - pip install -r tests/openvino/requirements.txt + pip install -r tests/openvino/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release pip install -r tests/cross_fw/install/requirements.txt pip install -r tests/cross_fw/examples/requirements.txt diff --git a/nncf/openvino/optimized_functions/models.py b/nncf/openvino/optimized_functions/models.py index 86a81e32484..95310631d1d 100644 --- a/nncf/openvino/optimized_functions/models.py +++ b/nncf/openvino/optimized_functions/models.py @@ -22,7 +22,6 @@ from openvino.runtime import Node from openvino.runtime import opset13 as opset -from nncf import CompressWeightsMode from nncf.common.utils.backend import is_openvino_at_least from nncf.common.utils.caching import ResultsCache from nncf.common.utils.caching import cache_results @@ -561,15 +560,6 @@ def _build_float_quantization_model( eps = np.finfo(np.float32).eps scale = opset.select(opset.less(opset.abs(scale), eps), eps, scale) - if config.mode == CompressWeightsMode.E2M1: - max_val = opset.constant(6, ov.Type.f32) # Maximal value of e2m1 type. - constant_2 = opset.constant(2, ov.Type.f32) - scale = divide_op(scale, max_val) - scale = opset.log(scale) / opset.log(constant_2) - scale = opset.ceil(scale) - scale = opset.clamp(scale, -127, 127) - scale = opset.power(constant_2, scale) - compressed_weight = divide_op(weight, scale) compressed_weight = convert_op(compressed_weight, ov.Type.nf4) compressed_weight = convert_op(compressed_weight, DTYPE_MAP_OV[compressed_weight_dtype]) diff --git a/nncf/version.py b/nncf/version.py index f7b5b2206e3..c6b875d4f9b 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2.16.0" +__version__ = "2.16.0.dev0+6598ed2dfdirty" BKC_TORCH_SPEC = "==2.6.*" diff --git a/tests/openvino/requirements.txt b/tests/openvino/requirements.txt index c3141b2fc42..dd2a938bb4f 100644 --- a/tests/openvino/requirements.txt +++ b/tests/openvino/requirements.txt @@ -1,4 +1,3 @@ --i https://storage.openvinotoolkit.org/simple/wheels/pre-release -c ../../constraints.txt fastdownload==0.0.7 onnx diff --git a/tests/post_training/requirements.txt b/tests/post_training/requirements.txt index 45cee4e2845..2e67b7e1c55 100644 --- a/tests/post_training/requirements.txt +++ b/tests/post_training/requirements.txt @@ -1,4 +1,3 @@ --i https://storage.openvinotoolkit.org/simple/wheels/pre-release -c ../../constraints.txt torch torchvision From fb04eeaa61e5619f715c1efed48ea47fc661f403 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Fri, 21 Mar 2025 20:05:10 +0100 Subject: [PATCH 07/11] Undo version changes --- nncf/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/version.py b/nncf/version.py index c6b875d4f9b..f7b5b2206e3 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.16.0.dev0+6598ed2dfdirty" +__version__ = "2.16.0" BKC_TORCH_SPEC = "==2.6.*" From 8cbd76979007955277d2835abcf0000286c5b373 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 24 Mar 2025 11:31:09 +0100 Subject: [PATCH 08/11] Temporarily update yml files --- .github/workflows/call_precommit.yml | 2 +- .github/workflows/conformance_weight_compression.yml | 1 + Makefile | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/call_precommit.yml b/.github/workflows/call_precommit.yml index 4c025bf83f7..59bc2eb90df 100644 --- a/.github/workflows/call_precommit.yml +++ b/.github/workflows/call_precommit.yml @@ -90,7 +90,7 @@ jobs: run: python .github/scripts/override_constraints.py "${{ inputs.override_requirements }}" shell: bash - name: Install NNCF and test requirements - run: pip install . -r tests/openvino/requirements.txt + run: pip install . -r tests/openvino/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release - name: Print installed modules run: pip list - name: Run OV precommit test scope diff --git a/.github/workflows/conformance_weight_compression.yml b/.github/workflows/conformance_weight_compression.yml index dd89693e62c..37ad8d68f46 100644 --- a/.github/workflows/conformance_weight_compression.yml +++ b/.github/workflows/conformance_weight_compression.yml @@ -2,6 +2,7 @@ name: Weight compression permissions: read-all on: + pull_request: workflow_call: workflow_dispatch: inputs: diff --git a/Makefile b/Makefile index 52ec4805643..05e9b70b841 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ install-openvino-test: pip install -U pip pip install -e . 
pip install "git+https://github.com/openvinotoolkit/open_model_zoo.git@e7df86da686d2e1600282422e54f66c2fecea160#egg=accuracy_checker&subdirectory=tools/accuracy_checker" - pip install -r tests/openvino/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release + pip install -r tests/openvino/requirements.txt pip install -r tests/cross_fw/install/requirements.txt pip install -r tests/cross_fw/examples/requirements.txt From ea0059dda23556ea2b259126498641c8f93d27b3 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 24 Mar 2025 11:43:08 +0100 Subject: [PATCH 09/11] Update yml --- .github/workflows/call_precommit.yml | 4 +++- .github/workflows/conformance_weight_compression.yml | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/call_precommit.yml b/.github/workflows/call_precommit.yml index 59bc2eb90df..5508fd63fd8 100644 --- a/.github/workflows/call_precommit.yml +++ b/.github/workflows/call_precommit.yml @@ -90,7 +90,9 @@ jobs: run: python .github/scripts/override_constraints.py "${{ inputs.override_requirements }}" shell: bash - name: Install NNCF and test requirements - run: pip install . -r tests/openvino/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release + run: pip install . -r tests/openvino/requirements.txt + - name: Install OV RC + run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - name: Print installed modules run: pip list - name: Run OV precommit test scope diff --git a/.github/workflows/conformance_weight_compression.yml b/.github/workflows/conformance_weight_compression.yml index 37ad8d68f46..0e7896d879e 100644 --- a/.github/workflows/conformance_weight_compression.yml +++ b/.github/workflows/conformance_weight_compression.yml @@ -41,7 +41,9 @@ jobs: - name: cpuinfo run: cat /proc/cpuinfo - name: Install NNCF and test requirements - run: pip install -e . -r tests/post_training/requirements.txt --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release + run: pip install -e . -r tests/post_training/requirements.txt + - name: Install OV RC + run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - name: Print installed modules run: pip list - name: Run examples test scope From f172c9f7cfe4c12d57d15b80c68e4d5091ddef8a Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 24 Mar 2025 11:43:56 +0100 Subject: [PATCH 10/11] Update yml --- .github/workflows/call_precommit.yml | 2 +- .github/workflows/conformance_weight_compression.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/call_precommit.yml b/.github/workflows/call_precommit.yml index 5508fd63fd8..4fffefbbf85 100644 --- a/.github/workflows/call_precommit.yml +++ b/.github/workflows/call_precommit.yml @@ -92,7 +92,7 @@ jobs: - name: Install NNCF and test requirements run: pip install . 
-r tests/openvino/requirements.txt - name: Install OV RC - run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release - name: Print installed modules run: pip list - name: Run OV precommit test scope diff --git a/.github/workflows/conformance_weight_compression.yml b/.github/workflows/conformance_weight_compression.yml index 0e7896d879e..aa5548eb631 100644 --- a/.github/workflows/conformance_weight_compression.yml +++ b/.github/workflows/conformance_weight_compression.yml @@ -43,7 +43,7 @@ jobs: - name: Install NNCF and test requirements run: pip install -e . -r tests/post_training/requirements.txt - name: Install OV RC - run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release - name: Print installed modules run: pip list - name: Run examples test scope From 043252739a92477d3db91d2f9341d08b70dad456 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 24 Mar 2025 11:52:59 +0100 Subject: [PATCH 11/11] Also install other openvino packages --- .github/workflows/conformance_weight_compression.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/conformance_weight_compression.yml b/.github/workflows/conformance_weight_compression.yml index aa5548eb631..61da775449d 100644 --- a/.github/workflows/conformance_weight_compression.yml +++ b/.github/workflows/conformance_weight_compression.yml @@ -43,7 +43,7 @@ jobs: - name: Install NNCF and test requirements run: pip install -e . -r tests/post_training/requirements.txt - name: Install OV RC - run: pip install -U --pre openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release + run: pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release - name: Print installed modules run: pip list - name: Run examples test scope