Aanuf/data free awq #3315

Draft pull request: wants to merge 20 commits into base: develop

Changes from all commits (20 commits)
488cacc
Support scale estimation inside GPTQ
alexsu52 Jun 10, 2024
ee64877
fix for INT4_ASYM
alexsu52 Sep 4, 2024
f22e411
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 23, 2024
51b4d7b
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 26, 2024
f66cd1e
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 30, 2024
7ce5a53
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Oct 2, 2024
f74d156
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
5288c79
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
1becf15
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 14, 2024
047d7d9
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 10, 2024
c0c7e57
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 16, 2024
b74dea1
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 27, 2024
26a9a77
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jan 7, 2025
25fcc2c
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Feb 25, 2025
f6f4693
Data-free AWQ prototype.
andreyanufr Feb 25, 2025
19a64ac
Data free AWQ.
andreyanufr Feb 26, 2025
bf215d5
Fixed style.
andreyanufr Feb 26, 2025
566ebe7
Fixed shape of data item int test.
andreyanufr Feb 27, 2025
70e47c8
Fixed test case for E2M1.
andreyanufr Feb 27, 2025
6b3310b
Enable awq by default.
andreyanufr Mar 4, 2025
@@ -506,7 +506,7 @@ def apply(
nodes_to_compress = self.get_nodes_to_compress(graph)

statistics = None
if self._data_aware_mixed_precision or self._data_aware_compression:
if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
Contributor comment:
I'd redefine

self._data_aware_compression = (self._awq and dataset) or (...)

matmul_nodes_to_compress = [
node for node in nodes_to_compress if node.metatype in self._backend_entity.matmul_metatypes
]
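The review comment above suggests gating data-aware compression on the presence of a dataset. A minimal sketch of that idea (hypothetical constructor and attribute names; `scale_estimation` / `gptq` stand in for the elided "(...)", this is not the PR's actual code):

```python
class WeightCompression:
    def __init__(self, awq=True, scale_estimation=False, gptq=False, dataset=None):
        self._awq = awq
        self._dataset = dataset
        # AWQ counts as data-aware only when a calibration dataset is supplied;
        # otherwise the data-free branch runs and no activation statistics are needed.
        # `scale_estimation` / `gptq` are hypothetical stand-ins for the elided "(...)".
        self._data_aware_compression = (awq and dataset is not None) or scale_estimation or gptq
```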
277 changes: 151 additions & 126 deletions nncf/quantization/algorithms/weight_compression/awq.py
@@ -132,19 +132,161 @@ def apply(
:return: A resulting model.
"""
self._set_backend_entity(model, wc_backend_entity)
matches = []

awq_data = self._get_awq_data(graph, all_weight_params, nodes_to_compress)
if len(awq_data) == 0:
return model

transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model, inplace=True)

is_data_free = True  # statistics is None
description = "Applying data-free AWQ" if is_data_free else "Applying AWQ"
Contributor:
Suggested change
description = "Applying data-free AWQ" if is_data_free else "Applying AWQ"
description = "Applying data-free AWQ" if is_data_free else "Applying data-aware AWQ"

maybe more clear


for k, awq_data_item in track(awq_data.items(), description=description):
wp = awq_data_item.weight_params
merge_node = awq_data_item.merge_node
weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
if len(weight_data) != 1: # not supported by the algorithm
continue

nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")

_, weight_port_id = weight_data[0]
weight = self._backend_entity.get_weight(
wp.node_with_weight, weight_port_id, model, graph
) # get_const_value(wp.weight_node)

if is_data_free:
scale = self._data_free_step(weight)
else:
scale = self._data_aware_step(wp, weight, statistics[k])

w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])

scaled_weight = weight * w_scale
Contributor:
Please rebase from develop, I've made some casting for float16 models

self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)

if self._backend_entity.is_node_with_weights(
merge_node, graph
): # for MatMul->Multiply->MatMul pattern scale merged to first MatMul
for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
merge_weight = merge_weight * a_scale
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
a_scale = fns.transpose(a_scale)
else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
a_scale = fns.transpose(a_scale)
next_nodes = graph.get_next_nodes(merge_node)
source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
scale_insertion_command = self._backend_entity.scale_insertion_command(
merge_node, next_nodes, source_node_output_port, a_scale.data
)
transformation_layout.register(scale_insertion_command)

self._scale_per_target_node[k] = a_scale

transformed_model = model_transformer.transform(transformation_layout)

return transformed_model
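The inline comments above describe two insertion patterns (folding 1/scale into a preceding MatMul's weight vs. adding an extra multiply after the activation); both rely on the same identity. A standalone NumPy sketch (not part of the PR) showing that scaling the weight columns by s and the activations by 1/s leaves the float output unchanged, so any accuracy effect comes only from quantizing the scaled weight:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))                                    # activations: [tokens, in_channels]
w = rng.normal(size=(32, 16))                                   # weight: [out_channels, in_channels]
s = np.maximum(np.abs(w).mean(axis=0), np.finfo(w.dtype).eps)   # per-in-channel scale

y_ref = x @ w.T                  # original layer output
y_awq = (x / s) @ (w * s).T      # scale folded into weight, 1/s moved to activations

assert np.allclose(y_ref, y_awq)
```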

def _data_aware_step(self, wp, weight, statistics):
alpha_step = (self._alpha_max - self._alpha_min) / self._steps
config = wp.compression_config
s, X = process_stats(statistics, self._subset_size)

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]

group_size = config.group_size
if group_size == -1:
group_size = s.shape[0]

groups_to_correct = set()
for idx in topk_idxs:
groups_to_correct.add(idx.data // group_size)

groups_to_correct = list(groups_to_correct)

assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
reduction_axis = wp.reduction_axes[0]

if reduction_axis == 0:
weight = fns.transpose(weight)
reduction_axis = 1

shape_vector = fns.mean(X, axis=1)
scale = fns.ones_like(shape_vector)

awq_config = deepcopy(config)
awq_config.group_size = -1

for gi in groups_to_correct:
offset = gi * group_size
gscale = s[offset : offset + group_size]

a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
a_max = 1e2
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

gweight = weight[:, offset : offset + group_size]
gacts = X[offset : offset + group_size, :]

fp32_out = fns.matmul(gweight, gacts)
min_diff = fns.max(fns.abs(fp32_out))
best_scale = None

alpha = self._alpha_min
for _ in range(self._steps):
cur_scale = gscale**alpha
weights_to_fake_quantize = gweight * cur_scale
if config.mode == CompressWeightsMode.NF4:
g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
else:
g_decompressed_weighs = quantize_dequantize_weight(
weights_to_fake_quantize, awq_config, reduction_axis
)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
cur_diff = fns.mean(fns.abs(cur_out - fp32_out))
if cur_diff < min_diff:
min_diff = cur_diff
best_scale = cur_scale
alpha += alpha_step

if best_scale is not None:
scale.data[offset : offset + group_size] = best_scale.data

return scale
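In other words, `_data_aware_step` runs a per-group grid search over alpha; using the names from the code above (gweight, gacts, clipped gscale, and the quantize-dequantize pair), the selected scale `s_g**alpha` minimizes the mean absolute error against the float output:

```latex
% Objective of the alpha grid search in _data_aware_step (notation mirrors the code:
% W_g = gweight, X_g = gacts, s_g = clipped gscale, QDQ = quantize-dequantize).
\alpha^{*} = \operatorname*{arg\,min}_{\alpha \in \{\alpha_{\min},\, \alpha_{\min}+\Delta,\, \dots,\, \alpha_{\max}\}}
\; \operatorname{mean}\Big(\big|\, \mathrm{QDQ}\!\big(W_g\, s_g^{\alpha}\big)\,\big(X_g / s_g^{\alpha}\big) \;-\; W_g X_g \,\big|\Big),
\qquad \Delta = \tfrac{\alpha_{\max}-\alpha_{\min}}{\text{steps}}
```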

def _data_free_step(self, weight):
eps = fns.finfo(weight).eps
scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps)
return 1 / scale
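A plain NumPy mirror of `_data_free_step` (a sketch; the real code goes through the backend-agnostic `fns` API): without activation statistics, per-channel importance is taken from the weight itself, so channels with a large average magnitude are damped before quantization and the inverse factor is moved to the activation side.

```python
import numpy as np

def data_free_scale(weight: np.ndarray) -> np.ndarray:
    # Per-channel scale from the weight only: 1 / mean(|W|) along axis 0,
    # clamped by eps to avoid division by zero for all-zero channels.
    eps = np.finfo(weight.dtype).eps
    return 1.0 / np.maximum(np.abs(weight).mean(axis=0), eps)
```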

def _get_awq_data(
self, graph: NNCFGraph, all_weight_params: List[WeightCompressionParameters], nodes_to_compress: List[NNCFNode]
) -> Dict[str, AWQCompressionInfo]:
"""
Finds AWQ patterns in the graph and returns them.
:param graph: Model graph.
:param all_weight_params: List of all weight parameters.
:param nodes_to_compress: List of nodes for processing.
:return: A dict with node names and matched AWQ patterns.
"""
matches = []
inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], [])
nx_graph = inference_nncf_graph.get_nx_graph_copy()
for _, pattern_graph in self._patterns.items():
matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False))

if len(matches) == 0:
nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.")
return model

transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model, inplace=True)
return {}

awq_data = {}
name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)}
@@ -183,129 +325,12 @@ def apply
merge_node = nncf_node

awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node)

alpha_step = (self._alpha_max - self._alpha_min) / self._steps

for k, awq_data_item in track(awq_data.items(), description="Applying AWQ"):
wp = awq_data_item.weight_params
target_node = awq_data_item.target_node
merge_node = awq_data_item.merge_node
weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
if len(weight_data) != 1: # not supported by the algorithm
continue

nncf_logger.debug(f"Apply AWQ for: {wp.node_with_weight.node_name}")

_, weight_port_id = weight_data[0]

config = wp.compression_config

s, X = process_stats(statistics[k], self._subset_size)

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]

group_size = config.group_size
if group_size == -1:
group_size = s.shape[0]

groups_to_correct = set()
for idx in topk_idxs:
groups_to_correct.add(idx.data // group_size)

groups_to_correct = list(groups_to_correct)

weight = self._backend_entity.get_weight(
wp.node_with_weight, weight_port_id, model, graph
) # get_const_value(wp.weight_node)
assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
reduction_axis = wp.reduction_axes[0]

if reduction_axis == 0:
weight = fns.transpose(weight)
reduction_axis = 1

shape_vector = fns.mean(X, axis=1)
scale = fns.ones_like(shape_vector)

awq_config = deepcopy(config)
awq_config.group_size = -1

for gi in groups_to_correct:
offset = gi * group_size
gscale = s[offset : offset + group_size]

a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
a_max = 1e2
gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

gweight = weight[:, offset : offset + group_size]
gacts = X[offset : offset + group_size, :]

fp32_out = fns.matmul(gweight, gacts)
min_diff = fns.max(fns.abs(fp32_out))
best_scale = None

alpha = self._alpha_min
for _ in range(self._steps):
cur_scale = gscale**alpha
weights_to_fake_quantize = gweight * cur_scale
if config.mode == CompressWeightsMode.NF4:
g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
else:
g_decompressed_weighs = quantize_dequantize_weight(
weights_to_fake_quantize, awq_config, reduction_axis
)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
cur_diff = fns.mean(fns.abs(cur_out - fp32_out))
if cur_diff < min_diff:
min_diff = cur_diff
best_scale = cur_scale
alpha += alpha_step

if best_scale is not None:
scale.data[offset : offset + group_size] = best_scale.data

a_scale = scale
w_scale = scale
if wp.reduction_axes[0] == 0:
w_scale = fns.unsqueeze(w_scale, 1)
a_scale = fns.unsqueeze(1.0 / a_scale, 0)
else:
w_scale = fns.unsqueeze(w_scale, 0)
a_scale = fns.unsqueeze(1.0 / a_scale, 1)

scaled_weight = weight * w_scale
self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)

if self._backend_entity.is_node_with_weights(
merge_node, graph
): # for MatMul->Multiply->MatMul pattern scale merged to first MatMul
for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
merge_weight = merge_weight * a_scale
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
a_scale = fns.transpose(a_scale)
else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
a_scale = fns.transpose(a_scale)
next_nodes = graph.get_next_nodes(merge_node)
source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
scale_insertion_command = self._backend_entity.scale_insertion_command(
merge_node, next_nodes, source_node_output_port, a_scale.data
)
transformation_layout.register(scale_insertion_command)

self._scale_per_target_node[k] = a_scale

transformed_model = model_transformer.transform(transformation_layout)

return transformed_model
return awq_data

def update_statistics(self, statistics):
if statistics is None:
return statistics

# Multiply activations by the computed scales
for node_name, scale in self._scale_per_target_node.items():
for mean_stat in statistics[node_name].mean_values:
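The loop body of `update_statistics` is truncated in the diff, but the comment states the intent. A hedged sketch of the idea (hypothetical helper, not the PR's code): the scales registered in `_scale_per_target_node` are the activation-side multipliers (1/scale), so cached mean-activation statistics for each target node are multiplied by them to stay consistent with the transformed model that later algorithms (scale estimation, GPTQ) will see.

```python
def _rescale_mean_stats(mean_values, a_scale):
    # a_scale is the activation-side multiplier (1/scale) stored per target node;
    # applying it to the cached mean activations mirrors the Multiply that AWQ
    # inserted (or folded) in front of the compressed MatMul.
    return [m * a_scale for m in mean_values]
```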
15 changes: 7 additions & 8 deletions nncf/quantization/quantize_model.py
@@ -435,7 +435,7 @@ def compress_weights(
sensitivity_metric: Optional[SensitivityMetric] = None,
*,
subset_size: int = 128,
awq: Optional[bool] = None,
awq: Optional[bool] = True,
scale_estimation: Optional[bool] = None,
gptq: Optional[bool] = None,
lora_correction: Optional[bool] = None,
@@ -585,13 +585,12 @@ def compress_weights(
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl

if any((awq, scale_estimation, gptq, lora_correction)) and (
dataset is None or mode == CompressWeightsMode.E2M1
):
msg = (
"Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, "
"but dataset is None or mode is E2M1."
)
if any((scale_estimation, gptq, lora_correction)) and dataset is None:
msg = "Scale estimation, GPTQ or Lora Correction algorithm is defined, but dataset is None."
raise nncf.ParameterNotSupportedError(msg)

if any((awq, scale_estimation, gptq, lora_correction)) and mode == CompressWeightsMode.E2M1:
msg = "AWQ, Scale estimation, GPTQ or Lora Correction algorithm is defined, but mode is E2M1."
raise nncf.ParameterNotSupportedError(msg)

if gptq and lora_correction:
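What the change above means for the API (a usage sketch, assuming an already-loaded OpenVINO model): with `awq` defaulting to True and removed from the dataset-required check, calling `compress_weights` without a dataset now takes the data-free AWQ path instead of raising ParameterNotSupportedError, while E2M1 still rejects these algorithms.

```python
import nncf
from nncf import CompressWeightsMode

def compress_int4_data_free(ov_model):
    # `ov_model` is assumed to be an already-loaded openvino.Model.
    # No dataset is passed: after this PR, AWQ (now enabled by default)
    # runs in its data-free mode instead of raising an error.
    return nncf.compress_weights(
        ov_model,
        mode=CompressWeightsMode.INT4_SYM,
        ratio=1.0,
        group_size=128,
    )

# Passing a dataset instead would switch AWQ to its data-aware path, e.g.:
# nncf.compress_weights(ov_model, mode=CompressWeightsMode.INT4_SYM,
#                       dataset=nncf.Dataset(calibration_items))
```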
@@ -691,7 +691,6 @@ def test_raise_error_with_unsupported_params_for_e2m1(algo):
"algo",
(
"lora_correction",
"awq",
"scale_estimation",
"gptq",
),
@@ -972,7 +971,7 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
)
def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
model = SequentialMatmulModel().ov_model
dataset = Dataset([np.ones([1, 4, 4]), np.arange(16).reshape(4, 4)])
dataset = Dataset([np.ones([1, 4, 4]), np.arange(16).reshape(1, 4, 4)])
compressed_model = compress_weights(
model,
mode=CompressWeightsMode.E2M1,