Draft PR #3315: Aanuf/data free awq

andreyanufr wants to merge 20 commits into openvinotoolkit:develop from andreyanufr:aanuf/data_free_awq (base: develop)
+160 −137

Changes from all 20 commits:
488cacc  Support scale estimation inside GPTQ (alexsu52)
ee64877  fix for INT4_ASYM (alexsu52)
f22e411  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
51b4d7b  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
f66cd1e  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
7ce5a53  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
f74d156  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
5288c79  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
1becf15  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
047d7d9  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
c0c7e57  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
b74dea1  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
26a9a77  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
25fcc2c  Merge remote-tracking branch 'upstream/develop' into develop (andreyanufr)
f6f4693  Data-free AWQ prototype. (andreyanufr)
19a64ac  Data free AWQ. (andreyanufr)
bf215d5  Fixed style. (andreyanufr)
566ebe7  Fixed shape of data item int test. (andreyanufr)
70e47c8  Fixed test case for E2M1. (andreyanufr)
6b3310b  Enable awq by default. (andreyanufr)
@@ -132,19 +132,161 @@ def apply(
        :return: A resulting model.
        """
        self._set_backend_entity(model, wc_backend_entity)

        awq_data = self._get_awq_data(graph, all_weight_params, nodes_to_compress)
        if len(awq_data) == 0:
            return model

        transformation_layout = TransformationLayout()
        model_transformer = ModelTransformerFactory.create(model, inplace=True)

        is_data_free = True  # statistics is None
        description = "Applying data-free AWQ" if is_data_free else "Applying AWQ"

[Review comment, with a suggested change: "maybe more clear"]

        for k, awq_data_item in track(awq_data.items(), description=description):
            wp = awq_data_item.weight_params
            merge_node = awq_data_item.merge_node
            weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
            if len(weight_data) != 1:  # not supported by the algorithm
                continue

            nncf_logger.debug(f"{description} for: {wp.node_with_weight.node_name}")

            _, weight_port_id = weight_data[0]
            weight = self._backend_entity.get_weight(
                wp.node_with_weight, weight_port_id, model, graph
            )  # get_const_value(wp.weight_node)

            if is_data_free:
                scale = self._data_free_step(weight)
            else:
                scale = self._data_aware_step(wp, weight, statistics[k])

            w_scale = fns.unsqueeze(scale, 1 - wp.reduction_axes[0])
            a_scale = fns.unsqueeze(1.0 / scale, wp.reduction_axes[0])

            scaled_weight = weight * w_scale

[Review comment: "Please rebase from develop, I've made some casting for float16 models"]

            self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)

            if self._backend_entity.is_node_with_weights(
                merge_node, graph
            ):  # for the MatMul->Multiply->MatMul pattern the scale is merged into the first MatMul
                for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
                    merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
                    merge_weight = merge_weight * a_scale
                    self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
                a_scale = fns.transpose(a_scale)
            else:  # for the Act->Multiply->MatMul and Act->MatMul patterns the scale is inserted after Act as an extra node
                a_scale = fns.transpose(a_scale)
                next_nodes = graph.get_next_nodes(merge_node)
                source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
                scale_insertion_command = self._backend_entity.scale_insertion_command(
                    merge_node, next_nodes, source_node_output_port, a_scale.data
                )
                transformation_layout.register(scale_insertion_command)

            self._scale_per_target_node[k] = a_scale

        transformed_model = model_transformer.transform(transformation_layout)

        return transformed_model
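The rescaling in the loop above relies on the standard AWQ identity: multiplying the weight columns by a per-channel scale and the matching activation rows by its inverse leaves the floating-point matmul unchanged, so only the quantization error is affected. A minimal NumPy sketch of that identity (names and shapes are illustrative, not taken from the PR):

import numpy as np

# W: (out_channels, in_channels), X: (in_channels, tokens), s: (in_channels,)
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))
X = rng.normal(size=(16, 4))
s = np.abs(rng.normal(size=16)) + 0.1  # positive per-channel scales

W_scaled = W * s[None, :]   # counterpart of w_scale folded into the weight
X_scaled = X / s[:, None]   # counterpart of a_scale folded into the activations

# The float product is identical; quantizing W_scaled is what changes the error.
assert np.allclose(W @ X, W_scaled @ X_scaled)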
    def _data_aware_step(self, wp, weight, statistics):
        alpha_step = (self._alpha_max - self._alpha_min) / self._steps
        config = wp.compression_config
        s, X = process_stats(statistics, self._subset_size)

        top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
        topk_idxs = fns.argsort(-s)[:top_k]

        group_size = config.group_size
        if group_size == -1:
            group_size = s.shape[0]

        groups_to_correct = set()
        for idx in topk_idxs:
            groups_to_correct.add(idx.data // group_size)

        groups_to_correct = list(groups_to_correct)

        assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
        reduction_axis = wp.reduction_axes[0]

        if reduction_axis == 0:
            weight = fns.transpose(weight)
            reduction_axis = 1

        shape_vector = fns.mean(X, axis=1)
        scale = fns.ones_like(shape_vector)

        awq_config = deepcopy(config)
        awq_config.group_size = -1

        for gi in groups_to_correct:
            offset = gi * group_size
            gscale = s[offset : offset + group_size]

            a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
            a_max = 1e2
            gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

            gweight = weight[:, offset : offset + group_size]
            gacts = X[offset : offset + group_size, :]

            fp32_out = fns.matmul(gweight, gacts)
            min_diff = fns.max(fns.abs(fp32_out))
            best_scale = None

            alpha = self._alpha_min
            for _ in range(self._steps):
                cur_scale = gscale**alpha
                weights_to_fake_quantize = gweight * cur_scale
                if config.mode == CompressWeightsMode.NF4:
                    g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
                    g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
                    g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
                else:
                    g_decompressed_weighs = quantize_dequantize_weight(
                        weights_to_fake_quantize, awq_config, reduction_axis
                    )
                sacts = gacts / fns.unsqueeze(cur_scale, 1)

                cur_out = fns.matmul(g_decompressed_weighs, sacts)
                cur_diff = fns.mean(fns.abs(cur_out - fp32_out))
                if cur_diff < min_diff:
                    min_diff = cur_diff
                    best_scale = cur_scale
                alpha += alpha_step

            if best_scale is not None:
                scale.data[offset : offset + group_size] = best_scale.data

        return scale
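In effect, _data_aware_step runs a per-group grid search: for each alpha between alpha_min and alpha_max it scores the candidate scale s**alpha by the mean absolute deviation of the quantized, rescaled output from the float output, and keeps the best candidate. The sketch below reproduces that search with a stand-in round-to-nearest quantizer; it is an assumption for illustration, whereas the PR itself uses NNCF's quantize_dequantize_weight and the NF4 helpers:

import numpy as np

def fake_quant(w: np.ndarray, bits: int = 4) -> np.ndarray:
    # Stand-in per-row symmetric round-to-nearest quantizer (an assumption,
    # not NNCF's quantize_dequantize_weight).
    q_max = 2 ** (bits - 1) - 1
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True), 1e-12) / q_max
    return np.clip(np.round(w / scale), -q_max - 1, q_max) * scale

def search_alpha(gweight, gacts, gscale, alpha_min=0.0, alpha_max=1.0, steps=10):
    # gweight: (out, group), gacts: (group, tokens), gscale: positive (group,).
    # Mirrors the inner loop of _data_aware_step: score each gscale**alpha by
    # the mean absolute error against the unquantized group output.
    fp32_out = gweight @ gacts
    best_alpha, min_diff = None, np.abs(fp32_out).max()
    for alpha in np.linspace(alpha_min, alpha_max, steps):
        cur_scale = gscale ** alpha
        cur_out = fake_quant(gweight * cur_scale) @ (gacts / cur_scale[:, None])
        cur_diff = np.abs(cur_out - fp32_out).mean()
        if cur_diff < min_diff:
            min_diff, best_alpha = cur_diff, alpha
    return best_alpha, min_diff

rng = np.random.default_rng(0)
best_alpha, err = search_alpha(
    rng.normal(size=(8, 16)), rng.normal(size=(16, 32)), np.abs(rng.normal(size=16)) + 0.1
)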
    def _data_free_step(self, weight):
        eps = fns.finfo(weight).eps
        scale = fns.maximum(fns.mean(fns.abs(weight), axis=0), eps)
        return 1 / scale
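Unlike the data-aware path, _data_free_step needs no activation statistics: the scale is the inverse of each input channel's mean absolute weight, clamped away from zero, so heavy channels are shrunk and light ones boosted before quantization. The same arithmetic in plain NumPy (an illustrative mirror of the method above, not the PR's code):

import numpy as np

def data_free_scale(weight: np.ndarray) -> np.ndarray:
    # Per-input-channel inverse of the mean absolute weight, as in _data_free_step.
    eps = np.finfo(weight.dtype).eps
    return 1.0 / np.maximum(np.abs(weight).mean(axis=0), eps)

w = np.array([[0.5, 4.0], [1.5, 8.0]], dtype=np.float32)
s = data_free_scale(w)     # mean |w| per column is [1.0, 6.0], so s == [1.0, 1/6]
balanced = w * s[None, :]  # the large second column is pulled into the first's range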
    def _get_awq_data(
        self, graph: NNCFGraph, all_weight_params: List[WeightCompressionParameters], nodes_to_compress: List[NNCFNode]
    ) -> Dict[str, AWQCompressionInfo]:
        """
        Finds AWQ patterns in the graph and returns them.

        :param graph: Model graph.
        :param all_weight_params: List of all weight parameters.
        :param nodes_to_compress: List of nodes for processing.
        :return: A dict with node names and matched AWQ patterns.
        """
        matches = []
        inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], [])
        nx_graph = inference_nncf_graph.get_nx_graph_copy()
        for _, pattern_graph in self._patterns.items():
            matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False))

        if len(matches) == 0:
            nncf_logger.info("No matching patterns were found for applying AWQ algorithm, it will be skipped.")
            return {}

        awq_data = {}
        name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)}
@@ -183,129 +325,12 @@ def apply(
                merge_node = nncf_node

            awq_data[target_node.node_name] = AWQCompressionInfo(weight_params, target_node, merge_node)

        return awq_data
The remainder of the hunk removes the old inline AWQ loop from apply; its logic now lives in _data_aware_step and the rescaling loop above:

        alpha_step = (self._alpha_max - self._alpha_min) / self._steps

        for k, awq_data_item in track(awq_data.items(), description="Applying AWQ"):
            wp = awq_data_item.weight_params
            target_node = awq_data_item.target_node
            merge_node = awq_data_item.merge_node
            weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
            if len(weight_data) != 1:  # not supported by the algorithm
                continue

            nncf_logger.debug(f"Apply AWQ for: {wp.node_with_weight.node_name}")

            _, weight_port_id = weight_data[0]

            config = wp.compression_config

            s, X = process_stats(statistics[k], self._subset_size)

            top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
            topk_idxs = fns.argsort(-s)[:top_k]

            group_size = config.group_size
            if group_size == -1:
                group_size = s.shape[0]

            groups_to_correct = set()
            for idx in topk_idxs:
                groups_to_correct.add(idx.data // group_size)

            groups_to_correct = list(groups_to_correct)

            weight = self._backend_entity.get_weight(
                wp.node_with_weight, weight_port_id, model, graph
            )  # get_const_value(wp.weight_node)
            assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
            reduction_axis = wp.reduction_axes[0]

            if reduction_axis == 0:
                weight = fns.transpose(weight)
                reduction_axis = 1

            shape_vector = fns.mean(X, axis=1)
            scale = fns.ones_like(shape_vector)

            awq_config = deepcopy(config)
            awq_config.group_size = -1

            for gi in groups_to_correct:
                offset = gi * group_size
                gscale = s[offset : offset + group_size]

                a_min = fns.astype(fns.quantile(gscale, 0.1), TensorDataType.float32)
                a_max = 1e2
                gscale = fns.clip(gscale, a_min=a_min, a_max=a_max)

                gweight = weight[:, offset : offset + group_size]
                gacts = X[offset : offset + group_size, :]

                fp32_out = fns.matmul(gweight, gacts)
                min_diff = fns.max(fns.abs(fp32_out))
                best_scale = None

                alpha = self._alpha_min
                for _ in range(self._steps):
                    cur_scale = gscale**alpha
                    weights_to_fake_quantize = gweight * cur_scale
                    if config.mode == CompressWeightsMode.NF4:
                        g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
                        g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
                        g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
                    else:
                        g_decompressed_weighs = quantize_dequantize_weight(
                            weights_to_fake_quantize, awq_config, reduction_axis
                        )
                    sacts = gacts / fns.unsqueeze(cur_scale, 1)

                    cur_out = fns.matmul(g_decompressed_weighs, sacts)
                    cur_diff = fns.mean(fns.abs(cur_out - fp32_out))
                    if cur_diff < min_diff:
                        min_diff = cur_diff
                        best_scale = cur_scale
                    alpha += alpha_step

                if best_scale is not None:
                    scale.data[offset : offset + group_size] = best_scale.data

            a_scale = scale
            w_scale = scale
            if wp.reduction_axes[0] == 0:
                w_scale = fns.unsqueeze(w_scale, 1)
                a_scale = fns.unsqueeze(1.0 / a_scale, 0)
            else:
                w_scale = fns.unsqueeze(w_scale, 0)
                a_scale = fns.unsqueeze(1.0 / a_scale, 1)

            scaled_weight = weight * w_scale
            self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)

            if self._backend_entity.is_node_with_weights(
                merge_node, graph
            ):  # for the MatMul->Multiply->MatMul pattern the scale is merged into the first MatMul
                for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
                    merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
                    merge_weight = merge_weight * a_scale
                    self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
                a_scale = fns.transpose(a_scale)
            else:  # for the Act->Multiply->MatMul and Act->MatMul patterns the scale is inserted after Act as an extra node
                a_scale = fns.transpose(a_scale)
                next_nodes = graph.get_next_nodes(merge_node)
                source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
                scale_insertion_command = self._backend_entity.scale_insertion_command(
                    merge_node, next_nodes, source_node_output_port, a_scale.data
                )
                transformation_layout.register(scale_insertion_command)

            self._scale_per_target_node[k] = a_scale

        transformed_model = model_transformer.transform(transformation_layout)

        return transformed_model
    def update_statistics(self, statistics):
        if statistics is None:
            return statistics

        # Multiply activations by the computed scales
        for node_name, scale in self._scale_per_target_node.items():
            for mean_stat in statistics[node_name].mean_values:
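The diff is cut off inside this loop, so the body of update_statistics is not visible here. Conceptually, the cached mean activation values have to absorb the a_scale stored per target node in apply, so that algorithms running after AWQ observe the rescaled activations. A hypothetical sketch of that bookkeeping (the names and the elementwise multiply are assumptions, not the PR's code):

import numpy as np

# Hypothetical: fold the stored activation scales into cached mean statistics.
scale_per_target_node = {"node_a": np.array([0.5, 2.0, 1.0])}
mean_values = {"node_a": [np.ones(3), np.array([2.0, 0.5, 3.0])]}

for node_name, scale in scale_per_target_node.items():
    mean_values[node_name] = [stat * scale.reshape(-1) for stat in mean_values[node_name]]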
[Review comment: "I'd redefine"]