From 488cacc2be70b7ae7e417c555d2aeea29163f5b6 Mon Sep 17 00:00:00 2001 From: Aleksandr Suslov Date: Mon, 10 Jun 2024 19:17:08 +0400 Subject: [PATCH 1/8] Support scale estimation inside GPTQ --- .../algorithms/layerwise/scheduler.py | 34 +- .../weight_compression/activation_stats.py | 7 +- .../weight_compression/algorithm.py | 59 ++-- .../algorithms/weight_compression/gptq.py | 41 ++- .../weight_compression/scale_estimation.py | 316 ++++++++++-------- nncf/quantization/quantize_model.py | 5 - .../openvino/native/quantization/test_gptq.py | 5 +- .../quantization/test_weights_compression.py | 5 +- 8 files changed, 271 insertions(+), 201 deletions(-) diff --git a/nncf/quantization/algorithms/layerwise/scheduler.py b/nncf/quantization/algorithms/layerwise/scheduler.py index 8eee99fad28..8abc03400c0 100644 --- a/nncf/quantization/algorithms/layerwise/scheduler.py +++ b/nncf/quantization/algorithms/layerwise/scheduler.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass from dataclasses import field @@ -177,26 +178,31 @@ def schedule( old_input_nodes = set() new_input_nodes = set() for p in paths: - target_output_nodes = set() + target_outputs = [] additional_output_nodes = set() for output_node in p.output_nodes: - if output_node in target_nodes: - target_output_nodes.add(output_node) - elif output_node in p.input_nodes: - reuse_input_nodes.add(output_node) - else: - # filter additional output nodes - for prev_node in inference_graph.get_previous_nodes(output_node): - if prev_node not in p.output_nodes: - additional_output_nodes.add(output_node) - break - if not target_output_nodes: + try: + target_node_index = target_nodes.index(output_node) + target_outputs.append((target_node_index, output_node)) + except ValueError: + if output_node in p.input_nodes: + reuse_input_nodes.add(output_node) + else: + # filter additional output nodes + for prev_node in inference_graph.get_previous_nodes(output_node): + if prev_node not in p.output_nodes: + additional_output_nodes.add(output_node) + break + if not target_outputs: continue + target_outputs.sort(key=lambda target_output: target_output[0]) + target_output_nodes = [output[1] for output in target_outputs] + old_input_nodes |= p.input_nodes - new_input_nodes |= target_output_nodes | additional_output_nodes + new_input_nodes |= set(target_output_nodes) | additional_output_nodes subgraph_inputs = list(p.inputs) - step_target_nodes = {} + step_target_nodes = OrderedDict() subgraph_outputs = [] for node in target_output_nodes: target_edge = {} diff --git a/nncf/quantization/algorithms/weight_compression/activation_stats.py b/nncf/quantization/algorithms/weight_compression/activation_stats.py index eb8286e6383..359887e7769 100644 --- a/nncf/quantization/algorithms/weight_compression/activation_stats.py +++ b/nncf/quantization/algorithms/weight_compression/activation_stats.py @@ -9,14 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Tuple, TypeVar +from typing import List, Tuple +from nncf.tensor import Tensor from nncf.tensor import functions as fns -TTensor = TypeVar("TTensor") - -def process_stats(stats: List[TTensor], subset_size: int) -> Tuple[TTensor, TTensor]: +def process_stats(stats: List[Tensor], subset_size: int) -> Tuple[Tensor, Tensor]: """ It's a processing of activations shared between AWQ, Scale Estimation and LoRA Correction algorithms. diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 3499521bce3..1b2af0fd9a3 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -124,7 +124,12 @@ def __init__( if self._gptq: gptq_params = self._advanced_parameters.gptq_params - self._gptq_algo = GPTQ(gptq_params.damp_percent, gptq_params.block_size, gptq_params.subset_size) + self._gptq_algo = GPTQ( + damp_percent=gptq_params.damp_percent, + block_size=gptq_params.block_size, + subset_size=gptq_params.subset_size, + scale_estimation=self._scale_estimation, + ) self._gptq_statistics = None @property @@ -379,25 +384,8 @@ def apply( scales = {} zero_points = {} - if ( - self._scale_estimation - and activations is not None - and self._mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] - ): - scale_estimation_params = self._advanced_parameters.scale_estimation_params - scale_algo = ScaleEstimation( - model, - self._backend_entity.name_to_node_mapping, - all_weight_params, - nodes_to_compress, - activations, - scale_estimation_params.subset_size, - scale_estimation_params.initial_steps, - scale_estimation_params.scale_steps, - scale_estimation_params.weight_penalty, - ) - scales = scale_algo.apply(model, graph) - + lora_correction_algo = None + description = "Applying Weight Compression" if self._gptq: model, scales, zero_points = self._gptq_algo.apply( model=model, @@ -407,13 +395,30 @@ def apply( statistic_points=self._gptq_statistics, backend_entity=self._backend_entity, ) + else: + if ( + self._scale_estimation + and activations is not None + and self._mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] + ): + scale_estimation_params = self._advanced_parameters.scale_estimation_params + scale_algo = ScaleEstimation( + model, + self._backend_entity.name_to_node_mapping, + all_weight_params, + nodes_to_compress, + activations, + scale_estimation_params.subset_size, + scale_estimation_params.initial_steps, + scale_estimation_params.scale_steps, + scale_estimation_params.weight_penalty, + ) + scales = scale_algo.apply(model, graph) - lora_correction_algo = None - description = "Applying Weight Compression" - if self._lora_correction: - lora_correction_params = self._advanced_parameters.lora_correction_params - lora_correction_algo = LoraCorrectionAlgorithm(activations, lora_correction_params) - description += " with correction of low-rank adapters" + if self._lora_correction: + lora_correction_params = self._advanced_parameters.lora_correction_params + lora_correction_algo = LoraCorrectionAlgorithm(activations, lora_correction_params) + description += " with correction of low-rank adapters" # Sort weight params to start compression with the bigger constants. This lowers peak memory footprint. 
all_weight_params = sorted(all_weight_params, key=lambda wp: wp.num_weights, reverse=True) @@ -542,7 +547,7 @@ def _get_activations( statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) statistics_aggregator.register_statistic_points(statistic_container) - if self._gptq: + if self._gptq and not self._awq: self._gptq_statistics = self._gptq_algo.get_statistic_points( model, graph, nodes_to_compress, self._backend_entity ) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index b595e080533..b1101916da3 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -25,6 +25,7 @@ from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_weight @@ -44,10 +45,7 @@ class GPTQ: """ def __init__( - self, - damp_percent: float = 0.1, - block_size: int = 128, - subset_size: int = 128, + self, damp_percent: float = 0.1, block_size: int = 128, subset_size: int = 128, scale_estimation: bool = False ): """ :param damp_percent: The percent of the average Hessian diagonal to use for dampening, @@ -58,6 +56,7 @@ def __init__( self._damp_percent = damp_percent self._block_size = block_size self._subset_size = subset_size + self._scale_estimation = scale_estimation self._backend = None self._backend_entity = None @@ -124,10 +123,9 @@ def apply( CompressWeightsMode.INT8_SYM, ]: continue - assert len(inputs) == 1 _, input_tensors = next(iter(inputs.items())) hessian = self._calculate_hessian(node, input_tensors) - scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian) + scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors) scales[wc_params.weight_name] = scale zero_points[wc_params.weight_name] = zero_point @@ -193,7 +191,12 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: return hessian def _quantize_weights( - self, model: TModel, graph: NNCFGraph, wc_params: WeightCompressionParameters, hessian: Tensor + self, + model: TModel, + graph: NNCFGraph, + wc_params: WeightCompressionParameters, + hessian: Tensor, + inputs: List[Tensor], ): """ Quantizes the weights of the model based on the calculated Hessian matrix. 
@@ -260,11 +263,25 @@ def _quantize_weights( scale = calculate_nf4_scale(weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes) scales.append(scale) else: - scale, zero_point = calculate_integer_quantization_params( - weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config - ) - scales.append(scale) - zero_points.append(zero_point) + if self._scale_estimation and block_compression_config.num_bits == 4: + activations = [inp.squeeze()[:, (i1 + i) : (i1 + i + group_size)] for inp in inputs] + scale, zero_point = ScaleEstimation.calculate_quantization_params( + self._backend_entity, + activations, + weight_tensor[:, (i1 + i) : (i1 + i + group_size)], + reduction_axes, + wc_params.compression_config, + ) + scales.append(scale.squeeze(axis=1)) + zero_points.append(zero_point) + else: + scale, zero_point = calculate_integer_quantization_params( + weight_tensor[:, (i1 + i) : (i1 + i + group_size)], + reduction_axes, + block_compression_config, + ) + scales.append(scale) + zero_points.append(zero_point) if block_compression_config.mode == CompressWeightsMode.NF4: compressed_weights = do_nf4_quantization( fns.unsqueeze(weight_col, 1), scales[-1], is_normalized_weight=False diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 6d1110c108f..712c5fd955d 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -20,16 +20,17 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats +from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization +from nncf.tensor import Tensor from nncf.tensor import TensorDataType from nncf.tensor import functions as fns TModel = TypeVar("TModel") -TTensor = TypeVar("TTensor") -TWeightType = TypeVar("TWeightType") class ScaleEstimation: @@ -37,13 +38,15 @@ class ScaleEstimation: Scale estimation algorithm implementation. """ + compress_decompress_cache = {} + def __init__( self, model: TModel, name_to_node_mapping: Dict[str, Any], all_weight_params: List[WeightCompressionParameters], nodes_to_compress: List[NNCFNode], - activations: Optional[Dict[str, TTensor]] = None, + activations: Optional[Dict[str, List[Tensor]]] = None, subset_size: int = 32, initial_steps: int = 5, scale_steps: int = 10, @@ -103,7 +106,7 @@ def apply( graph: NNCFGraph, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> Dict[str, TTensor]: + ) -> Dict[str, Tensor]: """ Estimates better scale for the int4 nodes in the model. Minimizes per-group difference between floating point MatMul and @@ -118,8 +121,7 @@ def apply( :return: Dict with pairs (weight name, estimated scale). 
""" - compress_decompress_cache = {} - res = dict() + scales = dict() for wp in track(self._all_weight_params, description="Applying Scale Estimation"): weight_name = wp.weight_name @@ -127,11 +129,10 @@ def apply( config = wp.compression_config if config.num_bits != 4 or node_name not in self._activations: - res[weight_name] = None + scales[weight_name] = None continue - s, X = process_stats(self._activations[node_name], self._subset_size) - reduction_axis = wp.reduction_axes[0] + stats = self._activations[node_name] weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) if len(weight_data) != 1: # not supported by the algorithm @@ -139,162 +140,211 @@ def apply( _, weight_port_id = weight_data[0] weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph) - weight = weight.astype(TensorDataType.float32) - eps = fns.finfo(weight).eps - if reduction_axis == 0: - weight = fns.transpose(weight) - reduction_axis = 1 + scales[weight_name], _ = self.calculate_quantization_params( + self._backend_entity, + stats, + weight, + wp.reduction_axes, + config, + self._subset_size, + self._initial_steps, + self._scale_steps, + self._weight_penalty, + ) - group_size = config.group_size if config.group_size != -1 else weight.shape[reduction_axis] - cur_config = deepcopy(config) - cur_config.group_size = group_size + return scales - original_weight = fns.zeros_like(weight) + weight + @staticmethod + def calculate_quantization_params( + backend_entity: WeightCompressionAlgoBackend, + activations: List[Tensor], + weight: Tensor, + reduction_axes: Tuple[int, ...], + config: WeightCompressionConfig, + subset_size: int = 32, + initial_steps: int = 5, + scale_steps: int = 10, + weight_penalty: float = -1.0, + ) -> Tensor: + """ + Calculates the quantization parameters for a given set of weights and activations. + This function estimates the optimal quantization scale for weight compression by + minimizing the difference between floating-point operations and operations with + quantized weights. + + The function uses an iterative process: + 1. Initial scale rectification based on activation statistics. + 2. A grid search to further refine the scale parameters. + + :param backend_entity: The backend-specific implementation of the weight compression algorithm. + :param activations: List of activation tensors corresponding to the layers being quantized. + :param weight: The weight tensor that is being quantized. + :param reduction_axes: Tuple specifying the axes along which the reduction is performed for quantization. + :param config: Configuration parameters for the weight compression, including quantization settings. + :param subset_size: The number of samples to use for scale estimation. Defaults to 32. + :param initial_steps: The number of steps for initial scale rectification using activation statistics. + Defaults to 5. + :param scale_steps: The number of steps for refining the scale using a grid search. Defaults to 10. + :param weight_penalty: Penalty coefficient applied to the difference between floating-point + and quantized weights. A value of -1 disables the penalty. Defaults to -1.0. + :return: A tensor containing the calculated quantization scales and zero points if applicable. 
+ """ + reduction_axis = reduction_axes[0] - compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config) - if zp is not None: - zp = zp.astype(scale.dtype) - q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis) + s, X = process_stats(activations, subset_size) - s = fns.unsqueeze(s, 0) - s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size) + weight = weight.astype(TensorDataType.float32) + eps = fns.finfo(weight).eps - original_weight, _ = reshape_weight_for_grouped_quantization(original_weight, reduction_axis, group_size) + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 - # all weight in group has importance based on corresponding input activations - importance = fns.ones_like(original_weight) - importance = importance * s + group_size = config.group_size if config.group_size != -1 else weight.shape[reduction_axis] + cur_config = deepcopy(config) + cur_config.group_size = group_size - target, zero_mask = get_target_zero_mask(compressed_weights, zp) - importance = fns.where(zero_mask, 0.0, importance) - - # normalize importances for every group of weights to make sum of them equal to 1.0 - denum = fns.sum(importance, axis=2, keepdims=True) - importance = importance / (denum + eps) - - X, _ = reshape_weight_for_grouped_quantization(X, 0, group_size) - q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) - best_diffs = None - result_scale = None - - fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X) - q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X) - - # metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE - min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) - min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0)) - if self._weight_penalty > 0.0: - min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1) - - zp_shape = zp.shape if zp is not None else None - key = [(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape] - if zp is not None: - key += zp_shape - key = tuple(key) - if key in compress_decompress_cache: - compress_decompress_model = compress_decompress_cache[key]["compress_decompress_model"] - compress_model = compress_decompress_cache[key]["compress_model"] - else: - compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline( - wp.compression_config, q_weights.shape, scale.shape, zp_shape - ) - compress_model = self._backend_entity.get_compress_pipeline( - wp.compression_config, q_weights.shape, scale.shape, zp_shape - ) - compress_decompress_cache[key] = { - "compress_decompress_model": compress_decompress_model, - "compress_model": compress_model, - } - - scale_sign = scale / fns.abs(scale) - zero_scale = 0.001 - zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + original_weight = fns.zeros_like(weight) + weight - input_tensors = [original_weight.data, None] - if zp is not None: - input_tensors.append(zp.data) - # iterative rectification of initial scale - for i in range(self._initial_steps): - near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) - near_to_ideal_scale = near_to_ideal_scale * scale_sign - input_tensors[1] = near_to_ideal_scale.data + compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config) + if zp is not None: + zp = 
zp.astype(scale.dtype) + q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis) - out = compress_decompress_model(input_tensors) - q_weights_ = fns.zeros_like(original_weight) + out - q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) + s = fns.unsqueeze(s, 0) + s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size) - ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) - ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0)) - if self._weight_penalty > 0.0: - ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1) + original_weight, _ = reshape_weight_for_grouped_quantization(original_weight, reduction_axis, group_size) - if best_diffs is None: - best_diffs = min_max_scale_diffs + # all weight in group has importance based on corresponding input activations + importance = fns.ones_like(original_weight) + importance = importance * s - mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype) + target, zero_mask = get_target_zero_mask(compressed_weights, zp) + importance = fns.where(zero_mask, 0.0, importance) - best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs + # normalize importances for every group of weights to make sum of them equal to 1.0 + denum = fns.sum(importance, axis=2, keepdims=True) + importance = importance / (denum + eps) - mask = fns.unsqueeze(mask, axis=2) + X, _ = reshape_weight_for_grouped_quantization(X, 0, group_size) + q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) + best_diffs = None + result_scale = None - if result_scale is None: - near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale - else: - near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale - result_scale = near_to_ideal_scale - input_tensors[1] = near_to_ideal_scale.data + fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X) + q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X) - if i < self._initial_steps - 1: - out = compress_model(input_tensors) - compressed_weights = fns.zeros_like(original_weight) + out - target, zero_mask = get_target_zero_mask(compressed_weights, zp) - zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + # metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE + min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) + min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0)) + if weight_penalty > 0.0: + min_max_scale_diffs += weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1) - # iterative rectification of scale based on grid search - for scale_steps in range(self._scale_steps): - factor = 1.0 - 0.05 * scale_steps - scaled_scale = factor * scale + zp_shape = zp.shape if zp is not None else None + key = (config.mode, config.num_bits) + q_weights.shape + scale.shape + if zp is not None: + key += zp_shape + if key in ScaleEstimation.compress_decompress_cache: + compress_decompress_model = ScaleEstimation.compress_decompress_cache[key]["compress_decompress_model"] + compress_model = ScaleEstimation.compress_decompress_cache[key]["compress_model"] + else: + compress_decompress_model = backend_entity.get_compress_decompress_pipeline( + config, q_weights.shape, scale.shape, zp_shape + ) + compress_model = backend_entity.get_compress_pipeline(config, q_weights.shape, scale.shape, zp_shape) + ScaleEstimation.compress_decompress_cache[key] = { + 
"compress_decompress_model": compress_decompress_model, + "compress_model": compress_model, + } + scale_sign = scale / fns.abs(scale) + zero_scale = 0.001 + zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + + input_tensors = [original_weight.data, None] + if zp is not None: + input_tensors.append(zp.data) + # iterative rectification of initial scale + for i in range(initial_steps): + near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) + near_to_ideal_scale = near_to_ideal_scale * scale_sign + input_tensors[1] = near_to_ideal_scale.data + + out = compress_decompress_model(input_tensors) + q_weights_ = fns.zeros_like(original_weight) + out + q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) + + ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) + ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0)) + if weight_penalty > 0.0: + ideal_scale_diffs += weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1) + + if best_diffs is None: + best_diffs = min_max_scale_diffs + + mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype) + + best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs + + mask = fns.unsqueeze(mask, axis=2) + + if result_scale is None: + near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale + else: + near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale + result_scale = near_to_ideal_scale + input_tensors[1] = near_to_ideal_scale.data - input_tensors[1] = scaled_scale.data + if i < initial_steps - 1: out = compress_model(input_tensors) compressed_weights = fns.zeros_like(original_weight) + out - target, zero_mask = get_target_zero_mask(compressed_weights, zp) zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) - near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) - near_to_ideal_scale = near_to_ideal_scale * scale_sign - input_tensors[1] = near_to_ideal_scale.data - out = compress_decompress_model(input_tensors) - q_weights_ = fns.zeros_like(original_weight) + out + # iterative rectification of scale based on grid search + for scale_steps in range(scale_steps): + factor = 1.0 - 0.05 * scale_steps + scaled_scale = factor * scale + + input_tensors[1] = scaled_scale.data + out = compress_model(input_tensors) + compressed_weights = fns.zeros_like(original_weight) + out - q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) - ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) - ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0)) - if self._weight_penalty > 0.0: - ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1) + target, zero_mask = get_target_zero_mask(compressed_weights, zp) + zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) + near_to_ideal_scale = near_to_ideal_scale * scale_sign + + input_tensors[1] = near_to_ideal_scale.data + out = compress_decompress_model(input_tensors) + q_weights_ = fns.zeros_like(original_weight) + out - mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype) + q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) + ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) + ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0)) + if weight_penalty > 0.0: + ideal_scale_diffs += weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1) 
- best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs + mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype) - mask = fns.unsqueeze(mask, axis=2) + best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs - if result_scale is None: - near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale - else: - near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale - result_scale = near_to_ideal_scale + mask = fns.unsqueeze(mask, axis=2) + + if result_scale is None: + near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale + else: + near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale + result_scale = near_to_ideal_scale - if config.group_size == -1: - result_scale = fns.squeeze(result_scale, axis=1) - res[weight_name] = result_scale + if config.group_size == -1: + result_scale = fns.squeeze(result_scale, axis=1) - return res + return result_scale, zp -def get_target_zero_mask(compressed_weights: TTensor, zp: Optional[TTensor] = None) -> Tuple[TTensor, TTensor]: +def get_target_zero_mask(compressed_weights: Tensor, zp: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: """ Computes the target values and a mask indicating zero values in the target. @@ -310,7 +360,7 @@ def get_target_zero_mask(compressed_weights: TTensor, zp: Optional[TTensor] = No return target, zero_mask -def estimate_scales(weight: TTensor, target: TTensor, zero_mask: TTensor, importance: TTensor) -> TTensor: +def estimate_scales(weight: Tensor, target: Tensor, zero_mask: Tensor, importance: Tensor) -> Tensor: """ Estimates scales for the given weight, target, zero mask, and importance. diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index e96c4526c51..60baeacc48e 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -482,11 +482,6 @@ def compress_weights( if any((gptq, lora_correction)) and (dataset is None or mode == CompressWeightsMode.E2M1): raise AttributeError("GPTQ or Lora Correction algorithm is defined, but dataset is None or mode is E2M1.") - if gptq and scale_estimation: - raise AttributeError( - "Simultaneous use of Scale estimation and GPTQ algorithms is not supported. Select one of them." - ) - if gptq and lora_correction: raise AttributeError( "Simultaneous use of Lora correction and GPTQ algorithms is not supported. Select one of them." 
diff --git a/tests/openvino/native/quantization/test_gptq.py b/tests/openvino/native/quantization/test_gptq.py index 1202b216ec7..ad19990eac0 100644 --- a/tests/openvino/native/quantization/test_gptq.py +++ b/tests/openvino/native/quantization/test_gptq.py @@ -341,7 +341,8 @@ def test_calculate_scale_linear(): gptq._set_backend_entity(ov_model) nodes = graph.get_all_nodes() - H = gptq._calculate_hessian(nodes[1], [Tensor(inp) for inp in inputs]) + wrapped_inputs = [Tensor(inp) for inp in inputs] + H = gptq._calculate_hessian(nodes[1], wrapped_inputs) ref_H = ref_gptq.H.numpy() assert np.all(np.isclose(ref_H, H.data)) @@ -351,7 +352,7 @@ def test_calculate_scale_linear(): ) wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16) - scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H) + scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs) ref_scale = ref_scale.numpy() scale = scale.reshape(ref_scale.shape) assert np.all(np.isclose(ref_scale, scale.data)) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index bb9b5c373c7..c51cf667ca2 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -713,10 +713,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): @pytest.mark.parametrize("mode", INT4_MODES) @pytest.mark.parametrize( "params", - ( - {"dataset": "anything", "scale_estimation": True, "gptq": True}, - {"dataset": "anything", "lora_correction": True, "gptq": True}, - ), + ({"dataset": "anything", "lora_correction": True, "gptq": True},), ) def test_raise_error_with_unsupported_params_for_int4(mode, params): with pytest.raises(AttributeError): From ee648777dcb951f4c7bdadd3997680a5083645a7 Mon Sep 17 00:00:00 2001 From: Aleksandr Suslov Date: Wed, 4 Sep 2024 13:25:22 +0400 Subject: [PATCH 2/8] fix for INT4_ASYM --- nncf/quantization/algorithms/weight_compression/gptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index b1101916da3..bd6518c86ad 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -273,7 +273,7 @@ def _quantize_weights( wc_params.compression_config, ) scales.append(scale.squeeze(axis=1)) - zero_points.append(zero_point) + zero_points.append(zero_point if zero_point is None else zero_point.squeeze(axis=1)) else: scale, zero_point = calculate_integer_quantization_params( weight_tensor[:, (i1 + i) : (i1 + i + group_size)], From f6f46934cbf83ae0dbf48ed3295ef274349c7e7d Mon Sep 17 00:00:00 2001 From: Andrei Anufriev Date: Tue, 25 Feb 2025 11:23:44 +0100 Subject: [PATCH 3/8] Data-free AWQ prototype. --- .../algorithms/weight_compression/awq.py | 271 +++++++++++++++--- 1 file changed, 224 insertions(+), 47 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 0b37a20c92e..ea4046071dc 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -132,63 +132,18 @@ def apply( :return: A resulting model. 
""" self._set_backend_entity(model, wc_backend_entity) - matches = [] - - inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) - nx_graph = inference_nncf_graph.get_nx_graph_copy() - for _, pattern_graph in self._patterns.items(): - matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False)) - if len(matches) == 0: - nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.") + awq_data = self.get_awq_data(graph, all_weight_params, nodes_to_compress) + if len(awq_data) == 0: return model transformation_layout = TransformationLayout() model_transformer = ModelTransformerFactory.create(model, inplace=True) - awq_data = {} - name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)} - - for match in matches: - nncf_node = graph.get_node_by_key(match[-1]) - if not self._backend_entity.is_node_with_weights(nncf_node, graph): - continue - - target_node_names = [] - for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): - target_node_names.append(weight_op_friendly_name) - - # skip node if it is in IgnoredScope or should not be compressed - if target_node_names[-1] not in name_mapping: - continue - - weight_params = all_weight_params[name_mapping[target_node_names[-1]]] - - if weight_params.compression_config.num_bits != 4: - continue - target_node = nodes_to_compress[name_mapping[target_node_names[-1]]] - - # avoid matching different patterns for the same node - if target_node.node_name in awq_data: - continue - - nncf_node = graph.get_node_by_key(match[0]) - - if self._backend_entity.is_node_with_weights(nncf_node, graph): # pattern MatMul->Multiply->MatMul - merge_node_names = [] - for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): - merge_node_names.append(weight_op_friendly_name) - merge_node = nodes_to_compress[name_mapping[merge_node_names[-1]]] - else: # pattern Act->MatMul or Act->Multiply->MatMul - merge_node = nncf_node - - awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node) - alpha_step = (self._alpha_max - self._alpha_min) / self._steps for k, awq_data_item in track(awq_data.items(), description="Applying AWQ"): wp = awq_data_item.weight_params - target_node = awq_data_item.target_node merge_node = awq_data_item.merge_node weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) if len(weight_data) != 1: # not supported by the algorithm @@ -305,6 +260,228 @@ def apply( return transformed_model + def data_aware_step(self, wp, weight, statistics): + alpha_step = (self._alpha_max - self._alpha_min) / self._steps + config = wp.compression_config + s, X = process_stats(statistics, self._subset_size) + + top_k = max(int(s.shape[0] * self._percent_to_apply), 1) + topk_idxs = fns.argsort(-s)[:top_k] + + group_size = config.group_size + if group_size == -1: + group_size = s.shape[0] + + groups_to_correct = set() + for idx in topk_idxs: + groups_to_correct.add(idx.data // group_size) + + groups_to_correct = list(groups_to_correct) + + assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 + reduction_axis = wp.reduction_axes[0] + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + shape_vector = fns.mean(X, axis=1) + scale = fns.ones_like(shape_vector) + + awq_config = deepcopy(config) + awq_config.group_size = -1 + + for gi in 
groups_to_correct: + offset = gi * group_size + gscale = s[offset : offset + group_size] + + a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32) + a_max = 1e2 + gscale = fns.clip(gscale, a_min=a_min, a_max=a_max) + + gweight = weight[:, offset : offset + group_size] + gacts = X[offset : offset + group_size, :] + + fp32_out = fns.matmul(gweight, gacts) + min_diff = fns.max(fns.abs(fp32_out)) + best_scale = None + + alpha = self._alpha_min + for _ in range(self._steps): + cur_scale = gscale**alpha + weights_to_fake_quantize = gweight * cur_scale + if config.mode == CompressWeightsMode.NF4: + g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) + g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) + g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) + else: + g_decompressed_weighs = quantize_dequantize_weight( + weights_to_fake_quantize, awq_config, reduction_axis + ) + sacts = gacts / fns.unsqueeze(cur_scale, 1) + + cur_out = fns.matmul(g_decompressed_weighs, sacts) + cur_diff = fns.mean(fns.abs(cur_out - fp32_out)) + if cur_diff < min_diff: + min_diff = cur_diff + best_scale = cur_scale + alpha += alpha_step + + if best_scale is not None: + scale.data[offset : offset + group_size] = best_scale.data + + return scale + + def data_free_step(self, weight): + eps = fns.finfo(weight).eps + scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps) + return 1 / scale + + def apply_data_free( + self, + model: TModel, + graph: NNCFGraph, + all_weight_params: List[WeightCompressionParameters], + nodes_to_compress: List[NNCFNode], + wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None, + ) -> TModel: + """ + Applies the algorithm to the model. + :param model: Model for applying algorithm. + :param graph: Model graph. + :param all_weight_params: List of all weight parameters. + :param nodes_to_compress: List of nodes for processing. + :param wc_backend_entity: Weight compression algorithm backend. + :return: A resulting model. 
+ """ + self._set_backend_entity(model, wc_backend_entity) + + awq_data = self.get_awq_data(graph, all_weight_params, nodes_to_compress) + if len(awq_data) == 0: + return model + + transformation_layout = TransformationLayout() + model_transformer = ModelTransformerFactory.create(model, inplace=True) + for k, awq_data_item in track(awq_data.items(), description="Applying data-free AWQ"): + wp = awq_data_item.weight_params + merge_node = awq_data_item.merge_node + weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) + if len(weight_data) != 1: # not supported by the algorithm + continue + + nncf_logger.debug(f"Apply data-free AWQ for: {wp.node_with_weight.node_name}") + + _, weight_port_id = weight_data[0] + + weight = self._backend_entity.get_weight( + wp.node_with_weight, weight_port_id, model, graph + ) # get_const_value(wp.weight_node) + assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 + reduction_axis = wp.reduction_axes[0] + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + eps = fns.finfo(weight).eps + + scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps) + + a_scale = scale + w_scale = scale + if wp.reduction_axes[0] == 0: + w_scale = fns.unsqueeze(1.0 / w_scale, 1) + a_scale = fns.unsqueeze(a_scale, 0) + else: + w_scale = fns.unsqueeze(1 / w_scale, 0) + a_scale = fns.unsqueeze(a_scale, 1) + + scaled_weight = weight * w_scale + self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight) + + if self._backend_entity.is_node_with_weights( + merge_node, graph + ): # for MatMul->Multiply->MatMul pattern scale merged to first MatMul + for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph): + merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph) + merge_weight = merge_weight * a_scale + self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight) + a_scale = fns.transpose(a_scale) + else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node + a_scale = fns.transpose(a_scale) + next_nodes = graph.get_next_nodes(merge_node) + source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id + scale_insertion_command = self._backend_entity.scale_insertion_command( + merge_node, next_nodes, source_node_output_port, a_scale.data + ) + transformation_layout.register(scale_insertion_command) + + self._scale_per_target_node[k] = a_scale + + transformed_model = model_transformer.transform(transformation_layout) + + return transformed_model + + def get_awq_data(self, + graph: NNCFGraph, + all_weight_params: List[WeightCompressionParameters], + nodes_to_compress: List[NNCFNode]) -> Dict[str, AWQCompressionInfo]: + """ + Finds awq patterns in graph and returns it. + :param graph: Model graph. + :param all_weight_params: List of all weight parameters. + :param nodes_to_compress: List of nodes for processing. + :return: A dict with node names and matched AWQ patterns. 
+ """ + matches = [] + inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) + nx_graph = inference_nncf_graph.get_nx_graph_copy() + for _, pattern_graph in self._patterns.items(): + matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False)) + + if len(matches) == 0: + nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.") + return {} + + awq_data = {} + name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)} + + for match in matches: + nncf_node = graph.get_node_by_key(match[-1]) + if not self._backend_entity.is_node_with_weights(nncf_node, graph): + continue + + target_node_names = [] + for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): + target_node_names.append(weight_op_friendly_name) + + # skip node if it is in IgnoredScope or should not be compressed + if target_node_names[-1] not in name_mapping: + continue + + weight_params = all_weight_params[name_mapping[target_node_names[-1]]] + + if weight_params.compression_config.num_bits != 4: + continue + target_node = nodes_to_compress[name_mapping[target_node_names[-1]]] + + # avoid matching different patterns for the same node + if target_node.node_name in awq_data: + continue + + nncf_node = graph.get_node_by_key(match[0]) + + if self._backend_entity.is_node_with_weights(nncf_node, graph): # pattern MatMul->Multiply->MatMul + merge_node_names = [] + for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): + merge_node_names.append(weight_op_friendly_name) + merge_node = nodes_to_compress[name_mapping[merge_node_names[-1]]] + else: # pattern Act->MatMul or Act->Multiply->MatMul + merge_node = nncf_node + + awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node) + return awq_data + def update_statistics(self, statistics): # Multiply activations by the computed scales for node_name, scale in self._scale_per_target_node.items(): From 19a64ac4ad820f6d2b665ab63807704c587a947c Mon Sep 17 00:00:00 2001 From: Andrei Anufriev Date: Wed, 26 Feb 2025 19:44:01 +0100 Subject: [PATCH 4/8] Data free AWQ. 
--- .../weight_compression/algorithm.py | 2 +- .../algorithms/weight_compression/awq.py | 189 ++---------------- nncf/quantization/quantize_model.py | 4 +- 3 files changed, 22 insertions(+), 173 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 53f5298d83e..20fa1f4f75a 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -506,7 +506,7 @@ def apply( nodes_to_compress = self.get_nodes_to_compress(graph) statistics = None - if self._data_aware_mixed_precision or self._data_aware_compression: + if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset: matmul_nodes_to_compress = [ node for node in nodes_to_compress if node.metatype in self._backend_entity.matmul_metatypes ] diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index ea4046071dc..55349d4fdfc 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -133,106 +133,37 @@ def apply( """ self._set_backend_entity(model, wc_backend_entity) - awq_data = self.get_awq_data(graph, all_weight_params, nodes_to_compress) + awq_data = self._get_awq_data(graph, all_weight_params, nodes_to_compress) if len(awq_data) == 0: return model transformation_layout = TransformationLayout() model_transformer = ModelTransformerFactory.create(model, inplace=True) - alpha_step = (self._alpha_max - self._alpha_min) / self._steps - - for k, awq_data_item in track(awq_data.items(), description="Applying AWQ"): + is_data_free = statistics is None + description = "Applying data-free AWQ" if is_data_free else "Applying AWQ" + + for k, awq_data_item in track(awq_data.items(), description=description): wp = awq_data_item.weight_params merge_node = awq_data_item.merge_node weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) if len(weight_data) != 1: # not supported by the algorithm continue - nncf_logger.debug(f"Apply AWQ for: {wp.node_with_weight.node_name}") + nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}") _, weight_port_id = weight_data[0] - - config = wp.compression_config - - s, X = process_stats(statistics[k], self._subset_size) - - top_k = max(int(s.shape[0] * self._percent_to_apply), 1) - topk_idxs = fns.argsort(-s)[:top_k] - - group_size = config.group_size - if group_size == -1: - group_size = s.shape[0] - - groups_to_correct = set() - for idx in topk_idxs: - groups_to_correct.add(idx.data // group_size) - - groups_to_correct = list(groups_to_correct) - weight = self._backend_entity.get_weight( wp.node_with_weight, weight_port_id, model, graph ) # get_const_value(wp.weight_node) - assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 - reduction_axis = wp.reduction_axes[0] - - if reduction_axis == 0: - weight = fns.transpose(weight) - reduction_axis = 1 - - shape_vector = fns.mean(X, axis=1) - scale = fns.ones_like(shape_vector) - - awq_config = deepcopy(config) - awq_config.group_size = -1 - - for gi in groups_to_correct: - offset = gi * group_size - gscale = s[offset : offset + group_size] - - a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32) - a_max = 1e2 - gscale = fns.clip(gscale, a_min=a_min, a_max=a_max) - - gweight = weight[:, offset : offset + group_size] - gacts = X[offset : 
offset + group_size, :] - - fp32_out = fns.matmul(gweight, gacts) - min_diff = fns.max(fns.abs(fp32_out)) - best_scale = None - - alpha = self._alpha_min - for _ in range(self._steps): - cur_scale = gscale**alpha - weights_to_fake_quantize = gweight * cur_scale - if config.mode == CompressWeightsMode.NF4: - g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) - g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) - g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) - else: - g_decompressed_weighs = quantize_dequantize_weight( - weights_to_fake_quantize, awq_config, reduction_axis - ) - sacts = gacts / fns.unsqueeze(cur_scale, 1) - - cur_out = fns.matmul(g_decompressed_weighs, sacts) - cur_diff = fns.mean(fns.abs(cur_out - fp32_out)) - if cur_diff < min_diff: - min_diff = cur_diff - best_scale = cur_scale - alpha += alpha_step - - if best_scale is not None: - scale.data[offset : offset + group_size] = best_scale.data - - a_scale = scale - w_scale = scale - if wp.reduction_axes[0] == 0: - w_scale = fns.unsqueeze(w_scale, 1) - a_scale = fns.unsqueeze(1.0 / a_scale, 0) + + if is_data_free: + scale = self._data_free_step(weight) else: - w_scale = fns.unsqueeze(w_scale, 0) - a_scale = fns.unsqueeze(1.0 / a_scale, 1) + scale = self._data_aware_step(wp, weight, statistics[k]) + + w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0]) + a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0]) scaled_weight = weight * w_scale self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight) @@ -260,7 +191,7 @@ def apply( return transformed_model - def data_aware_step(self, wp, weight, statistics): + def _data_aware_step(self, wp, weight, statistics): alpha_step = (self._alpha_max - self._alpha_min) / self._steps config = wp.compression_config s, X = process_stats(statistics, self._subset_size) @@ -332,97 +263,12 @@ def data_aware_step(self, wp, weight, statistics): return scale - def data_free_step(self, weight): + def _data_free_step(self, weight): eps = fns.finfo(weight).eps scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps) return 1 / scale - def apply_data_free( - self, - model: TModel, - graph: NNCFGraph, - all_weight_params: List[WeightCompressionParameters], - nodes_to_compress: List[NNCFNode], - wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None, - ) -> TModel: - """ - Applies the algorithm to the model. - :param model: Model for applying algorithm. - :param graph: Model graph. - :param all_weight_params: List of all weight parameters. - :param nodes_to_compress: List of nodes for processing. - :param wc_backend_entity: Weight compression algorithm backend. - :return: A resulting model. 
- """ - self._set_backend_entity(model, wc_backend_entity) - - awq_data = self.get_awq_data(graph, all_weight_params, nodes_to_compress) - if len(awq_data) == 0: - return model - - transformation_layout = TransformationLayout() - model_transformer = ModelTransformerFactory.create(model, inplace=True) - for k, awq_data_item in track(awq_data.items(), description="Applying data-free AWQ"): - wp = awq_data_item.weight_params - merge_node = awq_data_item.merge_node - weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) - if len(weight_data) != 1: # not supported by the algorithm - continue - - nncf_logger.debug(f"Apply data-free AWQ for: {wp.node_with_weight.node_name}") - - _, weight_port_id = weight_data[0] - - weight = self._backend_entity.get_weight( - wp.node_with_weight, weight_port_id, model, graph - ) # get_const_value(wp.weight_node) - assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 - reduction_axis = wp.reduction_axes[0] - - if reduction_axis == 0: - weight = fns.transpose(weight) - reduction_axis = 1 - - eps = fns.finfo(weight).eps - - scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps) - - a_scale = scale - w_scale = scale - if wp.reduction_axes[0] == 0: - w_scale = fns.unsqueeze(1.0 / w_scale, 1) - a_scale = fns.unsqueeze(a_scale, 0) - else: - w_scale = fns.unsqueeze(1 / w_scale, 0) - a_scale = fns.unsqueeze(a_scale, 1) - - scaled_weight = weight * w_scale - self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight) - - if self._backend_entity.is_node_with_weights( - merge_node, graph - ): # for MatMul->Multiply->MatMul pattern scale merged to first MatMul - for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph): - merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph) - merge_weight = merge_weight * a_scale - self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight) - a_scale = fns.transpose(a_scale) - else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node - a_scale = fns.transpose(a_scale) - next_nodes = graph.get_next_nodes(merge_node) - source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id - scale_insertion_command = self._backend_entity.scale_insertion_command( - merge_node, next_nodes, source_node_output_port, a_scale.data - ) - transformation_layout.register(scale_insertion_command) - - self._scale_per_target_node[k] = a_scale - - transformed_model = model_transformer.transform(transformation_layout) - - return transformed_model - - def get_awq_data(self, + def _get_awq_data(self, graph: NNCFGraph, all_weight_params: List[WeightCompressionParameters], nodes_to_compress: List[NNCFNode]) -> Dict[str, AWQCompressionInfo]: @@ -483,6 +329,9 @@ def get_awq_data(self, return awq_data def update_statistics(self, statistics): + if statistics is None: + return statistics + # Multiply activations by the computed scales for node_name, scale in self._scale_per_target_node.items(): for mean_stat in statistics[node_name].mean_values: diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 7c2126b720b..738aa357502 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -585,11 +585,11 @@ def compress_weights( if backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl - if 
any((awq, scale_estimation, gptq, lora_correction)) and ( + if any((scale_estimation, gptq, lora_correction)) and ( dataset is None or mode == CompressWeightsMode.E2M1 ): msg = ( - "Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, " + "Scale estimation, GPTQ or Lora Correction algorithm is defined, " "but dataset is None or mode is E2M1." ) raise nncf.ParameterNotSupportedError(msg) From bf215d525b8c78394cb2f040502417aff44be0ef Mon Sep 17 00:00:00 2001 From: Andrei Anufriev Date: Wed, 26 Feb 2025 21:08:12 +0100 Subject: [PATCH 5/8] Fixed style. --- .../quantization/algorithms/weight_compression/awq.py | 11 +++++------ nncf/quantization/quantize_model.py | 4 +--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 55349d4fdfc..e3b69340c75 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -142,7 +142,7 @@ def apply( is_data_free = statistics is None description = "Applying data-free AWQ" if is_data_free else "Applying AWQ" - + for k, awq_data_item in track(awq_data.items(), description=description): wp = awq_data_item.weight_params merge_node = awq_data_item.merge_node @@ -268,17 +268,16 @@ def _data_free_step(self, weight): scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps) return 1 / scale - def _get_awq_data(self, - graph: NNCFGraph, - all_weight_params: List[WeightCompressionParameters], - nodes_to_compress: List[NNCFNode]) -> Dict[str, AWQCompressionInfo]: + def _get_awq_data( + self, graph: NNCFGraph, all_weight_params: List[WeightCompressionParameters], nodes_to_compress: List[NNCFNode] + ) -> Dict[str, AWQCompressionInfo]: """ Finds awq patterns in graph and returns it. :param graph: Model graph. :param all_weight_params: List of all weight parameters. :param nodes_to_compress: List of nodes for processing. :return: A dict with node names and matched AWQ patterns. - """ + """ matches = [] inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) nx_graph = inference_nncf_graph.get_nx_graph_copy() diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 738aa357502..b2e9370e8ce 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -585,9 +585,7 @@ def compress_weights( if backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl - if any((scale_estimation, gptq, lora_correction)) and ( - dataset is None or mode == CompressWeightsMode.E2M1 - ): + if any((scale_estimation, gptq, lora_correction)) and (dataset is None or mode == CompressWeightsMode.E2M1): msg = ( "Scale estimation, GPTQ or Lora Correction algorithm is defined, " "but dataset is None or mode is E2M1." From 566ebe70ec9646241e843c8abb4b4c307c5987fc Mon Sep 17 00:00:00 2001 From: Andrei Anufriev Date: Thu, 27 Feb 2025 10:18:34 +0100 Subject: [PATCH 6/8] Fixed shape of data item int test. 
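
The second sample passed to Dataset in test_mixed_precision_e2m1 now carries an
explicit batch dimension, matching the [1, 4, 4] layout of the first sample.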
---
 tests/openvino/native/quantization/test_weights_compression.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 76b9c380e45..2ec50c08b10 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -972,7 +972,7 @@ def test_call_gptq_with_dataset_scale_estimation_neg_group_size(mode):
 )
 def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
     model = SequentialMatmulModel().ov_model
-    dataset = Dataset([np.ones([1, 4, 4]), np.arange(16).reshape(4, 4)])
+    dataset = Dataset([np.ones([1, 4, 4]), np.arange(16).reshape(1, 4, 4)])
     compressed_model = compress_weights(
         model,
         mode=CompressWeightsMode.E2M1,

From 70e47c8e574261d2b9dda1f367a79318af0825dc Mon Sep 17 00:00:00 2001
From: Andrei Anufriev
Date: Thu, 27 Feb 2025 11:24:16 +0100
Subject: [PATCH 7/8] Fixed test case for E2M1.

---
 nncf/quantization/quantize_model.py | 11 ++++++-----
 .../native/quantization/test_weights_compression.py | 1 -
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index b2e9370e8ce..94b5ad06df6 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -585,11 +585,12 @@ def compress_weights(
     if backend == BackendType.OPENVINO:
         from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl

-        if any((scale_estimation, gptq, lora_correction)) and (dataset is None or mode == CompressWeightsMode.E2M1):
-            msg = (
-                "Scale estimation, GPTQ or Lora Correction algorithm is defined, "
-                "but dataset is None or mode is E2M1."
-            )
+        if any((scale_estimation, gptq, lora_correction)) and dataset is None:
+            msg = "Scale estimation, GPTQ or Lora Correction algorithm is defined, but dataset is None."
+            raise nncf.ParameterNotSupportedError(msg)
+
+        if any((awq, scale_estimation, gptq, lora_correction)) and mode == CompressWeightsMode.E2M1:
+            msg = "AWQ, Scale estimation, GPTQ or Lora Correction algorithm is defined, but mode is E2M1."
             raise nncf.ParameterNotSupportedError(msg)

         if gptq and lora_correction:
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 2ec50c08b10..bdfe6ba1dea 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -691,7 +691,6 @@ def test_raise_error_with_unsupported_params_for_e2m1(algo):
     "algo",
     (
         "lora_correction",
-        "awq",
         "scale_estimation",
         "gptq",
     ),

From 566ebe70ec9646241e843c8abb4b4c307c5987fc Mon Sep 17 00:00:00 2001
From: Andrei Anufriev
Date: Tue, 4 Mar 2025 11:56:26 +0100
Subject: [PATCH 8/8] Enable awq by default.
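
With this change AWQ is requested by default, and because statistics collection is
skipped when no dataset is provided, the AWQ step can fall back to the weight-only
(data-free) scales instead of requiring a calibration dataset. A minimal usage sketch
(`ov_model` is a placeholder for an OpenVINO model obtained elsewhere):

    import nncf
    from nncf import CompressWeightsMode

    # No dataset: AWQ still runs, using the data-free scales.
    compressed = nncf.compress_weights(
        ov_model,
        mode=CompressWeightsMode.INT4_SYM,
    )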
---
 nncf/quantization/algorithms/weight_compression/awq.py | 2 +-
 nncf/quantization/quantize_model.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py
index e3b69340c75..2a5ec7f2dd2 100644
--- a/nncf/quantization/algorithms/weight_compression/awq.py
+++ b/nncf/quantization/algorithms/weight_compression/awq.py
@@ -140,7 +140,7 @@ def apply(
         transformation_layout = TransformationLayout()
         model_transformer = ModelTransformerFactory.create(model, inplace=True)

-        is_data_free = statistics is None
+        is_data_free = True #statistics is None
         description = "Applying data-free AWQ" if is_data_free else "Applying AWQ"

         for k, awq_data_item in track(awq_data.items(), description=description):
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 94b5ad06df6..bdc1a50e143 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -435,7 +435,7 @@ def compress_weights(
     sensitivity_metric: Optional[SensitivityMetric] = None,
     *,
     subset_size: int = 128,
-    awq: Optional[bool] = None,
+    awq: Optional[bool] = True,
     scale_estimation: Optional[bool] = None,
     gptq: Optional[bool] = None,
    lora_correction: Optional[bool] = None,
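
Taken together, the series also allows the combination that was previously rejected:
scale estimation running inside the GPTQ loop to refine the per-group scales. A sketch
of that call, assuming `ov_model` is an OpenVINO model and `calibration_items` an
iterable of model inputs (both placeholders):

    import nncf
    from nncf import CompressWeightsMode

    compressed = nncf.compress_weights(
        ov_model,
        mode=CompressWeightsMode.INT4_SYM,
        group_size=64,
        dataset=nncf.Dataset(calibration_items),
        gptq=True,
        scale_estimation=True,  # no longer mutually exclusive with gptq
    )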