diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 53f5298d83e..20fa1f4f75a 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -506,7 +506,7 @@ def apply(
         nodes_to_compress = self.get_nodes_to_compress(graph)
 
         statistics = None
-        if self._data_aware_mixed_precision or self._data_aware_compression:
+        if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
             matmul_nodes_to_compress = [
                 node for node in nodes_to_compress if node.metatype in self._backend_entity.matmul_metatypes
             ]
diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py
index 0b37a20c92e..2a5ec7f2dd2 100644
--- a/nncf/quantization/algorithms/weight_compression/awq.py
+++ b/nncf/quantization/algorithms/weight_compression/awq.py
@@ -132,8 +132,153 @@ def apply(
         :return: A resulting model.
         """
         self._set_backend_entity(model, wc_backend_entity)
-        matches = []
+        awq_data = self._get_awq_data(graph, all_weight_params, nodes_to_compress)
+        if len(awq_data) == 0:
+            return model
+
+        transformation_layout = TransformationLayout()
+        model_transformer = ModelTransformerFactory.create(model, inplace=True)
+
+        is_data_free = statistics is None
+        description = "Applying data-free AWQ" if is_data_free else "Applying AWQ"
+
+        for k, awq_data_item in track(awq_data.items(), description=description):
+            wp = awq_data_item.weight_params
+            merge_node = awq_data_item.merge_node
+            weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
+            if len(weight_data) != 1:  # not supported by the algorithm
+                continue
+
+            nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")
+
+            _, weight_port_id = weight_data[0]
+            weight = self._backend_entity.get_weight(
+                wp.node_with_weight, weight_port_id, model, graph
+            )  # get_const_value(wp.weight_node)
+
+            if is_data_free:
+                scale = self._data_free_step(weight)
+            else:
+                scale = self._data_aware_step(wp, weight, statistics[k])
+
+            w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
+            a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])
+
+            scaled_weight = weight * w_scale
+            self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)
+
+            if self._backend_entity.is_node_with_weights(
+                merge_node, graph
+            ):  # for MatMul->Multiply->MatMul pattern scale merged to first MatMul
+                for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
+                    merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
+                    merge_weight = merge_weight * a_scale
+                    self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
+                a_scale = fns.transpose(a_scale)
+            else:  # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
+                a_scale = fns.transpose(a_scale)
+                next_nodes = graph.get_next_nodes(merge_node)
+                source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
+                scale_insertion_command = self._backend_entity.scale_insertion_command(
+                    merge_node, next_nodes, source_node_output_port, a_scale.data
+                )
+                transformation_layout.register(scale_insertion_command)
+
+            self._scale_per_target_node[k] = a_scale
+
+        transformed_model = model_transformer.transform(transformation_layout)
+
+        return transformed_model
+
+    def _data_aware_step(self, wp, weight, statistics):
+        alpha_step = (self._alpha_max - self._alpha_min) / self._steps
+        config = wp.compression_config
+        s, X = process_stats(statistics, self._subset_size)
+
+        top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
+        topk_idxs = fns.argsort(-s)[:top_k]
+
+        group_size = config.group_size
+        if group_size == -1:
+            group_size = s.shape[0]
+
+        groups_to_correct = set()
+        for idx in topk_idxs:
+            groups_to_correct.add(idx.data // group_size)
+
+        groups_to_correct = list(groups_to_correct)
+
+        assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
+        reduction_axis = wp.reduction_axes[0]
+
+        if reduction_axis == 0:
+            weight = fns.transpose(weight)
+            reduction_axis = 1
+
+        shape_vector = fns.mean(X, axis=1)
+        scale = fns.ones_like(shape_vector)
+
+        awq_config = deepcopy(config)
+        awq_config.group_size = -1
+
+        for gi in groups_to_correct:
+            offset = gi * group_size
+            gscale = s[offset : offset + group_size]
+
+            a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
+            a_max = 1e2
+            gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)
+
+            gweight = weight[:, offset : offset + group_size]
+            gacts = X[offset : offset + group_size, :]
+
+            fp32_out = fns.matmul(gweight, gacts)
+            min_diff = fns.max(fns.abs(fp32_out))
+            best_scale = None
+
+            alpha = self._alpha_min
+            for _ in range(self._steps):
+                cur_scale = gscale**alpha
+                weights_to_fake_quantize = gweight * cur_scale
+                if config.mode == CompressWeightsMode.NF4:
+                    g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
+                    g_compressed_weights = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
+                    g_decompressed_weights = do_nf4_dequantization(g_compressed_weights, g_c_scale)
+                else:
+                    g_decompressed_weights = quantize_dequantize_weight(
+                        weights_to_fake_quantize, awq_config, reduction_axis
+                    )
+                sacts = gacts / fns.unsqueeze(cur_scale, 1)
+
+                cur_out = fns.matmul(g_decompressed_weights, sacts)
+                cur_diff = fns.mean(fns.abs(cur_out - fp32_out))
+                if cur_diff < min_diff:
+                    min_diff = cur_diff
+                    best_scale = cur_scale
+                alpha += alpha_step
+
+            if best_scale is not None:
+                scale.data[offset : offset + group_size] = best_scale.data
+
+        return scale
+
+    def _data_free_step(self, weight):
+        eps = fns.finfo(weight).eps
+        scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps)
+        return 1 / scale
+
+    def _get_awq_data(
+        self, graph: NNCFGraph, all_weight_params: List[WeightCompressionParameters], nodes_to_compress: List[NNCFNode]
+    ) -> Dict[str, AWQCompressionInfo]:
+        """
+        Finds AWQ patterns in the graph and returns them.
+        :param graph: Model graph.
+        :param all_weight_params: List of all weight parameters.
+        :param nodes_to_compress: List of nodes for processing.
+        :return: A dict with node names and matched AWQ patterns.
+ """ + matches = [] inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], []) nx_graph = inference_nncf_graph.get_nx_graph_copy() for _, pattern_graph in self._patterns.items(): @@ -141,10 +286,7 @@ def apply( if len(matches) == 0: nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.") - return model - - transformation_layout = TransformationLayout() - model_transformer = ModelTransformerFactory.create(model, inplace=True) + return {} awq_data = {} name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)} @@ -183,129 +325,12 @@ def apply( merge_node = nncf_node awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node) - - alpha_step = (self._alpha_max - self._alpha_min) / self._steps - - for k, awq_data_item in track(awq_data.items(), description="Applying AWQ"): - wp = awq_data_item.weight_params - target_node = awq_data_item.target_node - merge_node = awq_data_item.merge_node - weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) - if len(weight_data) != 1: # not supported by the algorithm - continue - - nncf_logger.debug(f"Apply AWQ for: {wp.node_with_weight.node_name}") - - _, weight_port_id = weight_data[0] - - config = wp.compression_config - - s, X = process_stats(statistics[k], self._subset_size) - - top_k = max(int(s.shape[0] * self._percent_to_apply), 1) - topk_idxs = fns.argsort(-s)[:top_k] - - group_size = config.group_size - if group_size == -1: - group_size = s.shape[0] - - groups_to_correct = set() - for idx in topk_idxs: - groups_to_correct.add(idx.data // group_size) - - groups_to_correct = list(groups_to_correct) - - weight = self._backend_entity.get_weight( - wp.node_with_weight, weight_port_id, model, graph - ) # get_const_value(wp.weight_node) - assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1 - reduction_axis = wp.reduction_axes[0] - - if reduction_axis == 0: - weight = fns.transpose(weight) - reduction_axis = 1 - - shape_vector = fns.mean(X, axis=1) - scale = fns.ones_like(shape_vector) - - awq_config = deepcopy(config) - awq_config.group_size = -1 - - for gi in groups_to_correct: - offset = gi * group_size - gscale = s[offset : offset + group_size] - - a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32) - a_max = 1e2 - gscale = fns.clip(gscale, a_min=a_min, a_max=a_max) - - gweight = weight[:, offset : offset + group_size] - gacts = X[offset : offset + group_size, :] - - fp32_out = fns.matmul(gweight, gacts) - min_diff = fns.max(fns.abs(fp32_out)) - best_scale = None - - alpha = self._alpha_min - for _ in range(self._steps): - cur_scale = gscale**alpha - weights_to_fake_quantize = gweight * cur_scale - if config.mode == CompressWeightsMode.NF4: - g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis) - g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) - g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) - else: - g_decompressed_weighs = quantize_dequantize_weight( - weights_to_fake_quantize, awq_config, reduction_axis - ) - sacts = gacts / fns.unsqueeze(cur_scale, 1) - - cur_out = fns.matmul(g_decompressed_weighs, sacts) - cur_diff = fns.mean(fns.abs(cur_out - fp32_out)) - if cur_diff < min_diff: - min_diff = cur_diff - best_scale = cur_scale - alpha += alpha_step - - if best_scale is not None: - scale.data[offset : offset + group_size] = best_scale.data - - a_scale = scale - 
w_scale = scale - if wp.reduction_axes[0] == 0: - w_scale = fns.unsqueeze(w_scale, 1) - a_scale = fns.unsqueeze(1.0 / a_scale, 0) - else: - w_scale = fns.unsqueeze(w_scale, 0) - a_scale = fns.unsqueeze(1.0 / a_scale, 1) - - scaled_weight = weight * w_scale - self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight) - - if self._backend_entity.is_node_with_weights( - merge_node, graph - ): # for MatMul->Multiply->MatMul pattern scale merged to first MatMul - for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph): - merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph) - merge_weight = merge_weight * a_scale - self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight) - a_scale = fns.transpose(a_scale) - else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node - a_scale = fns.transpose(a_scale) - next_nodes = graph.get_next_nodes(merge_node) - source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id - scale_insertion_command = self._backend_entity.scale_insertion_command( - merge_node, next_nodes, source_node_output_port, a_scale.data - ) - transformation_layout.register(scale_insertion_command) - - self._scale_per_target_node[k] = a_scale - - transformed_model = model_transformer.transform(transformation_layout) - - return transformed_model + return awq_data def update_statistics(self, statistics): + if statistics is None: + return statistics + # Multiply activations by the computed scales for node_name, scale in self._scale_per_target_node.items(): for mean_stat in statistics[node_name].mean_values: diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 7c2126b720b..bdc1a50e143 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -435,7 +435,7 @@ def compress_weights( sensitivity_metric: Optional[SensitivityMetric] = None, *, subset_size: int = 128, - awq: Optional[bool] = None, + awq: Optional[bool] = True, scale_estimation: Optional[bool] = None, gptq: Optional[bool] = None, lora_correction: Optional[bool] = None, @@ -585,13 +585,12 @@ def compress_weights( if backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl - if any((awq, scale_estimation, gptq, lora_correction)) and ( - dataset is None or mode == CompressWeightsMode.E2M1 - ): - msg = ( - "Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, " - "but dataset is None or mode is E2M1." - ) + if any((scale_estimation, gptq, lora_correction)) and dataset is None: + msg = "Scale estimation, GPTQ or Lora Correction algorithm is defined, but dataset is None." + raise nncf.ParameterNotSupportedError(msg) + + if any((awq, scale_estimation, gptq, lora_correction)) and mode == CompressWeightsMode.E2M1: + msg = "AWQ, Scale estimation, GPTQ or Lora Correction algorithm is defined, but mode is E2M1." 
raise nncf.ParameterNotSupportedError(msg) if gptq and lora_correction: diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 76b9c380e45..bdfe6ba1dea 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -691,7 +691,6 @@ def test_raise_error_with_unsupported_params_for_e2m1(algo): "algo", ( "lora_correction", - "awq", "scale_estimation", "gptq", ), @@ -972,7 +971,7 @@ def test_call_gptq_with_dataset_scale_estimation_neg_group_size(mode): ) def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids): model = SequentialMatmulModel().ov_model - dataset = Dataset([np.ones([1, 4, 4]), np.arange(16).reshape(4, 4)]) + dataset = Dataset([np.ones([1, 4, 4]), np.arange(16).reshape(1, 4, 4)]) compressed_model = compress_weights( model, mode=CompressWeightsMode.E2M1,