
Commit d5de30d

Fixed weight compression for float16/bfloat16 Torch models (#3330)
### Changes

Adapted the AWQ and Scale Estimation algorithms for the case when weights and activations are float16 or bfloat16.

### Reason for changes

Otherwise, compression fails with errors like this:

`RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float`

### Related tickets

n/a

### Tests

- tests/torch/ptq/test_weights_compression.py::test_half_precision_models
- PTWC: https://github.com/openvinotoolkit/nncf/actions/runs/13680175191
- PTWC Performance:

| 51 build on develop | 52 build on PR |
|:-------------------------:|:-------------------------:|
| ![image](https://github.com/user-attachments/assets/3d5b9c96-cf4c-47a2-89e8-f1b6b4f48113) | ![image](https://github.com/user-attachments/assets/0b943e42-f407-4cd5-9e4a-3215561ea9e9) |
1 parent 59c978f commit d5de30d
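For context, the failure described under "Reason for changes" is easy to reproduce in plain PyTorch: AWQ built float32 intermediates, and any matmul mixing them with half-precision tensors fails. A minimal sketch, not NNCF code (shapes are illustrative; the exact message varies slightly across torch versions):

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)   # half-precision activations
w = torch.randn(16, 8, dtype=torch.float32)   # float32 intermediate tensor

try:
    torch.nn.functional.linear(x, w)          # mat1 (bfloat16) vs mat2 (float32)
except RuntimeError as e:
    print(e)  # expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float
```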

File tree

3 files changed: +26 -3 lines


nncf/quantization/algorithms/weight_compression/awq.py (+7 -3)
@@ -201,6 +201,8 @@ def apply(
             config = wp.compression_config
 
             s, X = process_stats(statistics[k], self._subset_size)
+            s = s.astype(TensorDataType.float32)
+            X = X.astype(TensorDataType.float32)
 
             top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
             topk_idxs = fns.argsort(-s)[:top_k]
@@ -218,6 +220,8 @@ def apply(
             weight = self._backend_entity.get_weight(
                 wp.node_with_weight, weight_port_id, model, graph
             )  # get_const_value(wp.weight_node)
+            weight_dtype = weight.dtype
+            weight = weight.astype(TensorDataType.float32)
             assert isinstance(wp.reduction_axes, tuple) and len(wp.reduction_axes) == 1
             reduction_axis = wp.reduction_axes[0]
 
@@ -279,19 +283,19 @@ def apply(
             w_scale = fns.unsqueeze(w_scale, 0)
             a_scale = fns.unsqueeze(1.0 / a_scale, 1)
 
-            scaled_weight = weight * w_scale
+            scaled_weight = (weight * w_scale).astype(weight_dtype)
             self._backend_entity.set_weight(wp.node_with_weight, weight_port_id, model, graph, scaled_weight)
 
             if self._backend_entity.is_node_with_weights(
                 merge_node, graph
             ):  # for MatMul->Multiply->MatMul pattern scale merged to first MatMul
                 for _, port_id in self._backend_entity.get_weight_names_and_port_ids(merge_node, graph):
                     merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
-                    merge_weight = merge_weight * a_scale
+                    merge_weight = (merge_weight * a_scale).astype(weight_dtype)
                     self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
                 a_scale = fns.transpose(a_scale)
             else:  # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
-                a_scale = fns.transpose(a_scale)
+                a_scale = fns.transpose(a_scale).astype(weight_dtype)
                 next_nodes = graph.get_next_nodes(merge_node)
                 source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
                 scale_insertion_command = self._backend_entity.scale_insertion_command(
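Both hunks follow the same upcast-compute-downcast discipline: statistics and weights are promoted to float32 before the scale arithmetic, and every tensor written back into the model is cast to the original weight dtype so downstream matmuls see consistent types. A standalone sketch of that pattern in plain PyTorch (the function name is hypothetical, not NNCF's Tensor API):

```python
import torch

def apply_scale_in_fp32(weight: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    """Scale a (possibly half-precision) weight in float32, then restore its dtype."""
    original_dtype = weight.dtype                  # remember float16/bfloat16
    scaled = weight.to(torch.float32) * w_scale    # numerically sensitive math in float32
    return scaled.to(original_dtype)               # cast back before writing into the model

w = torch.randn(16, 8, dtype=torch.bfloat16)
scale = torch.rand(1, 8, dtype=torch.float32)
assert apply_scale_in_fp32(w, scale).dtype == torch.bfloat16
```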

nncf/quantization/algorithms/weight_compression/scale_estimation.py (+1)
@@ -185,6 +185,7 @@ def calculate_quantization_params(
 
         s, X = process_stats(statistics, subset_size)
 
+        X = X.astype(TensorDataType.float32)
         weight = weight.astype(TensorDataType.float32)
         eps = fns.finfo(weight).eps
 
tests/torch/ptq/test_weights_compression.py (+18)
@@ -15,6 +15,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
 
 import nncf
 from nncf import BackupMode
@@ -436,6 +438,22 @@ def test_pack_int4():
     assert torch.all(unpacked_w == w_int8)
 
 
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+def test_half_precision_models(dtype):
+    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    inputs = tokenizer("dummy_input", return_tensors="pt")
+    compress_weights(
+        model,
+        group_size=2,
+        mode=CompressWeightsMode.INT4_SYM,
+        scale_estimation=True,
+        awq=True,
+        dataset=nncf.Dataset([dict(inputs)]),
+    )
+
+
 class TestPTTemplateWeightCompression(TemplateWeightCompression):
     @staticmethod
     def get_matmul_model() -> torch.nn.Module:
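With the fix in place, the new test can be run on its own, e.g. `pytest tests/torch/ptq/test_weights_compression.py -k test_half_precision_models` (assuming `transformers` is installed); it exercises AWQ and scale estimation end-to-end on a tiny OPT model in both bfloat16 and float16.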
