
Commit 7dc55a2

Added optimized nf4 quantization
1 parent dd94ac5 commit 7dc55a2

17 files changed: +365 -156 lines changed


nncf/openvino/optimized_functions/__init__.py

+1
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 from nncf.openvino.optimized_functions.functions import astype as astype
+from nncf.openvino.optimized_functions.functions import do_float_quantization as do_float_quantization
 from nncf.openvino.optimized_functions.functions import do_integer_quantization as do_integer_quantization
 from nncf.openvino.optimized_functions.functions import get_integer_quantization_error as get_integer_quantization_error
 from nncf.openvino.optimized_functions.functions import (

nncf/openvino/optimized_functions/functions.py

+44
@@ -11,10 +11,12 @@
 
 from typing import Optional, Tuple, Union
 
+from nncf import CompressWeightsMode
 from nncf.common.utils.caching import disable_results_caching
 from nncf.openvino.optimized_functions.models import OV_MODEL_CACHE
 from nncf.openvino.optimized_functions.models import OVModelParameters
 from nncf.openvino.optimized_functions.models import get_astype_model
+from nncf.openvino.optimized_functions.models import get_float_quantization_model
 from nncf.openvino.optimized_functions.models import get_integer_quantization_error_model
 from nncf.openvino.optimized_functions.models import get_integer_quantization_model
 from nncf.openvino.optimized_functions.models import get_integer_quantize_dequantize_weight_model
@@ -97,6 +99,48 @@ def do_integer_quantization(
     return compressed_weight, scale, zero_point
 
 
+def do_float_quantization(
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    reduction_axes: Optional[ReductionAxes] = None,
+    precomputed_scale: Tensor = None,
+) -> Tuple[Tensor, Tensor]:
+    weight_shape = weight.shape
+    scale_shape = None if precomputed_scale is None else precomputed_scale.shape
+
+    ov_model_params = OVModelParameters()
+    ov_model_params.input_dtypes["weight"] = weight.dtype
+    if precomputed_scale is not None:
+        ov_model_params.input_dtypes["scale"] = precomputed_scale.dtype
+    if config.num_bits == 4 and weight.backend == TensorBackend.ov:
+        # Return ov tensors in target precision to seamlessly insert them into openvino model later
+        ov_model_params.return_ov_tensors = True
+        dtype = TensorDataType.nf4 if config.mode == CompressWeightsMode.NF4 else TensorDataType.f4e2m1
+        ov_model_params.output_dtypes.update({"compressed_weight": dtype})
+
+    model = get_float_quantization_model(
+        ov_model_params,
+        config,
+        weight_shape,
+        scale_shape,
+        reduction_axes,
+    )
+
+    if precomputed_scale is None:
+        # weight -> compressed_weight, scale
+        compressed_weight, scale = model([weight])
+
+        # Scale is always in fp32 so there is no need to store it in ov.Tensor
+        if scale.backend == TensorBackend.ov:
+            scale = scale.as_numpy_tensor()
+    else:
+        # weight, scale -> compressed_weight
+        compressed_weight = model([weight, precomputed_scale])[0]
+        scale = precomputed_scale
+
+    return compressed_weight, scale
+
+
 def integer_quantize_dequantize_weight(
     weight: Tensor,
     config: WeightCompressionConfig,
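
Note: for context, a minimal usage sketch of the new optimized entry point added above. This is not part of the commit; the toy weight, the NF4 config, and the reduction axis are illustrative assumptions, and OpenVINO is required because the call compiles a small OV model on CPU.

import numpy as np

from nncf import CompressWeightsMode
from nncf.openvino.optimized_functions import do_float_quantization
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.tensor import Tensor

# Hypothetical example: per-row NF4 quantization of a random fp32 weight.
config = WeightCompressionConfig(mode=CompressWeightsMode.NF4)
weight = Tensor(np.random.rand(128, 256).astype(np.float32))

# No precomputed_scale is passed, so the compiled model returns both the compressed weight and the scale.
compressed_weight, scale = do_float_quantization(weight, config, reduction_axes=(1,))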

nncf/openvino/optimized_functions/models.py

+112
@@ -22,6 +22,7 @@
 from openvino.runtime import Node
 from openvino.runtime import opset13 as opset
 
+from nncf import CompressWeightsMode
 from nncf.common.utils.backend import is_openvino_at_least
 from nncf.common.utils.caching import ResultsCache
 from nncf.common.utils.caching import cache_results
@@ -233,6 +234,26 @@ def get_integer_quantization_model(
     )
 
 
+def get_float_quantization_model(
+    ov_model_params: OVModelParameters,
+    config: WeightCompressionConfig,
+    weight_shape: Tuple,
+    scale_shape: Optional[Tuple] = None,
+    reduction_axes: Optional[ReductionAxes] = None,
+) -> Union[ModelCallable, ModelAsNodes]:
+    weight_shape, scale_shape, _ = _prepare_quantization_model_inputs(
+        ov_model_params, weight_shape, scale_shape, zero_point_shape=None, reduction_axes=reduction_axes
+    )
+
+    return _build_float_quantization_model(
+        config,
+        ov_model_params,
+        weight_shape,
+        scale_shape,
+        reduction_axes,
+    )
+
+
 def get_integer_quantize_dequantize_weight_model(
     ov_model_params: OVModelParameters,
     config: WeightCompressionConfig,
@@ -453,6 +474,97 @@ def _build_integer_quantization_model(
     return partial(_infer_ov_model, ov_model_params, compiled_model)
 
 
+@cache_results(OV_MODEL_CACHE)
+def _build_float_quantization_model(
+    config: WeightCompressionConfig,
+    ov_model_params: OVModelParameters,
+    weight_shape: Tuple,
+    scale_shape: Optional[Tuple] = None,
+    reduction_axes: Optional[ReductionAxes] = None,
+    return_nodes: bool = False,
+) -> Union[ModelCallable, ModelAsNodes]:
+    default_input_dtypes = {"scale": TensorDataType.float32}
+    default_output_dtypes = {"compressed_weight": TensorDataType.float32, "scale": TensorDataType.float32}
+
+    # Update input and output dtypes with the default values
+    ov_model_params = copy.deepcopy(ov_model_params)
+    ov_model_params.input_dtypes = {**default_input_dtypes, **ov_model_params.input_dtypes}
+    ov_model_params.output_dtypes = {**default_output_dtypes, **ov_model_params.output_dtypes}
+
+    if "weight" not in ov_model_params.input_dtypes:
+        msg = "Input weight dtype is required!"
+        raise ValueError(msg)
+
+    weight_dtype = ov_model_params.input_dtypes["weight"]
+    input_scale_dtype = ov_model_params.input_dtypes["scale"]
+    compressed_weight_dtype = ov_model_params.output_dtypes["compressed_weight"]
+    output_scale_dtype = ov_model_params.output_dtypes["scale"]
+
+    # Validate input dtypes
+    valid_weight_dtypes = [TensorDataType.float32, TensorDataType.float16, TensorDataType.bfloat16]
+    if weight_dtype not in valid_weight_dtypes:
+        msg = f"Weight must be one of the following data types: {valid_weight_dtypes}. But found: {weight_dtype}."
+        raise ValueError(msg)
+    if scale_shape is not None and input_scale_dtype != TensorDataType.float32:
+        msg = f"Input scale must be of float32 data type. But found: {input_scale_dtype}."
+        raise ValueError(msg)
+
+    # Validate output dtypes
+    # TODO: Enable f4e2m1
+    valid_compressed_weight_dtypes = [TensorDataType.float32, TensorDataType.nf4]
+    if compressed_weight_dtype not in valid_compressed_weight_dtypes:
+        msg = (
+            f"Compressed weight must be one of the following data types: {valid_compressed_weight_dtypes}. "
+            f"But found: {compressed_weight_dtype}."
+        )
+        raise ValueError(msg)
+    if scale_shape is None and output_scale_dtype != TensorDataType.float32:
+        msg = f"Output scale must be of float32 data type. But found: {output_scale_dtype}."
+        raise ValueError(msg)
+
+    # Build OV model
+    weight = opset.parameter(weight_shape, name="weight", dtype=DTYPE_MAP_OV[weight_dtype])
+    ov_parameters = [weight]
+    weight = convert_op(weight, ov.Type.f32)
+
+    divide_op = opset.divide if ov_model_params.convertable_division else non_convertable_divide_op
+    if scale_shape is not None:
+        # Scale is given as an input
+        scale = opset.parameter(scale_shape, name="scale", dtype=DTYPE_MAP_OV[input_scale_dtype])
+        ov_parameters.append(scale)
+    else:
+        # Compute scale
+        scale = opset.reduce_max(opset.abs(weight), reduction_axes=reduction_axes, keep_dims=True)
+        # NOTE: adding machine epsilon to avoid division by zero
+        eps = np.finfo(np.float32).eps
+        scale = opset.select(opset.less(opset.abs(scale), eps), eps, scale)
+
+        if config.mode == CompressWeightsMode.E2M1:
+            max_val = opset.constant(6, ov.Type.f32)  # Maximal value of e2m1 type.
+            constant_2 = opset.constant(2, ov.Type.f32)
+            scale = divide_op(scale, max_val)
+            scale = opset.log(scale) / opset.log(constant_2)
+            scale = opset.ceil(scale)
+            scale = opset.clamp(scale, -127, 127)
+            scale = opset.power(constant_2, scale)
+
+    compressed_weight = divide_op(weight, scale)
+    compressed_weight = convert_op(compressed_weight, ov.Type.nf4)
+    compressed_weight = convert_op(compressed_weight, DTYPE_MAP_OV[compressed_weight_dtype])
+
+    ov_results = [compressed_weight]
+    if len(ov_parameters) == 1:
+        ov_results.append(scale)
+
+    if return_nodes:
+        return ov_parameters, ov_results, ov_model_params
+
+    model = ov.Model(ov_results, ov_parameters)
+    compiled_model = _compile_ov_model(model, device_name="CPU", config={inference_precision(): ov.Type.f32})
+
+    return partial(_infer_ov_model, ov_model_params, compiled_model)
+
+
 @cache_results(OV_MODEL_CACHE)
 def _build_integer_quantize_dequantize_weight_model(
     config: WeightCompressionConfig,
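
Note: for intuition, a NumPy mirror of the scale computation encoded by the OV graph above. This sketch is an illustrative assumption (the commit itself only builds OpenVINO ops): an absmax scale per reduction axis for NF4, plus the power-of-two adjustment applied for E2M1.

import numpy as np

def reference_float_scale(weight: np.ndarray, axis: int, e2m1: bool = False) -> np.ndarray:
    # Per-group absmax scale, guarded against division by zero with machine epsilon.
    scale = np.max(np.abs(weight), axis=axis, keepdims=True).astype(np.float32)
    eps = np.finfo(np.float32).eps
    scale = np.where(np.abs(scale) < eps, eps, scale)
    if e2m1:
        # Divide by 6 (the maximal e2m1 value) and round up to a power of two,
        # clamping the exponent to [-127, 127], as the E2M1 branch above does.
        scale = 2.0 ** np.clip(np.ceil(np.log2(scale / 6.0)), -127, 127)
    return scale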

nncf/quantization/algorithms/weight_compression/awq.py

+4-4
@@ -30,9 +30,8 @@
 from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
-from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight
-from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization
+from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight
 from nncf.quantization.passes import transform_to_inference_graph
 from nncf.tensor import TensorDataType
@@ -255,8 +254,9 @@ def apply(
                 cur_scale = gscale**alpha
 
                 weights_to_fake_quantize = gweight * cur_scale
                 if config.mode == CompressWeightsMode.NF4:
-                    g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
-                    g_compressed_weighs = calculate_nf4_quantized_weight(weights_to_fake_quantize, g_c_scale)
+                    g_compressed_weighs, g_c_scale = do_float_quantization(
+                        weights_to_fake_quantize, config, reduction_axis
+                    )
                     g_decompressed_weighs = do_float_dequantization(g_compressed_weighs, g_c_scale)
                 else:
                     g_decompressed_weighs = integer_quantize_dequantize_weight(
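
Note: the AWQ change collapses the old calculate_nf4_scale + calculate_nf4_quantized_weight pair into a single quantize call followed by a dequantize. A hedged round-trip sketch of that pattern (the toy weight and axis are assumptions; the function names come from the imports in this diff):

import numpy as np

from nncf import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization
from nncf.tensor import Tensor

config = WeightCompressionConfig(mode=CompressWeightsMode.NF4)
w = Tensor(np.random.rand(8, 64).astype(np.float32))

q, s = do_float_quantization(w, config, -1)  # quantize to NF4 with a per-row absmax scale
w_hat = do_float_dequantization(q, s)        # approximate reconstruction of w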

nncf/quantization/algorithms/weight_compression/gptq.py

+7-5
@@ -26,10 +26,10 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
+from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params
 from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
-from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight
-from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization
+from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import integer_quantize_dequantize_weight
 from nncf.tensor import Tensor
 from nncf.tensor import functions as fns
@@ -262,7 +262,9 @@ def _quantize_weights(
 
                 if (i1 + i) % group_size == 0:
                     if block_compression_config.mode == CompressWeightsMode.NF4:
-                        scale = calculate_nf4_scale(weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes)
+                        scale = calculate_float_quantization_params(
+                            weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config
+                        )
                         scales.append(scale)
                     else:
                         if self._scale_estimation and block_compression_config.num_bits == 4:
@@ -284,8 +286,8 @@ def _quantize_weights(
                        zero_points.append(zero_point)
 
                 if block_compression_config.mode == CompressWeightsMode.NF4:
-                    compressed_weights = calculate_nf4_quantized_weight(
-                        fns.unsqueeze(weight_col, 1), scales[-1], is_normalized_weight=False
+                    compressed_weights, _ = do_float_quantization(
+                        fns.unsqueeze(weight_col, 1), block_compression_config, precomputed_scale=scales[-1]
                     )
                     quantized_col = do_float_dequantization(compressed_weights, scales[-1], reduction_axis=-1)
                 else:
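
Note: the GPTQ path now computes the per-group scale once with calculate_float_quantization_params and reuses it through precomputed_scale. A hedged sketch of that two-step flow (the toy block and shapes are assumptions; the call order mirrors the diff):

import numpy as np

from nncf import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization
from nncf.tensor import Tensor

config = WeightCompressionConfig(mode=CompressWeightsMode.NF4)
block = Tensor(np.random.rand(16, 32).astype(np.float32))  # one block of weight columns

scale = calculate_float_quantization_params(block, 1, config)  # one scale per output row
compressed, _ = do_float_quantization(block, config, precomputed_scale=scale)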

nncf/quantization/algorithms/weight_compression/lora_correction.py

+1-5
@@ -25,7 +25,6 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.weight_lowering import CompressedWeight
-from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_dequantization
 from nncf.tensor import Tensor
@@ -177,10 +176,7 @@ def calculate_low_rank_matrices(
                 reduction_axis,
             )
         elif mode == CompressWeightsMode.NF4:
-            indexes = calculate_nf4_quantized_weight(
-                compressed_weight.tensor, compressed_weight.scale, is_normalized_weight=True
-            )
-            fq_weights = do_float_dequantization(indexes, compressed_weight.scale, reduction_axis)
+            fq_weights = do_float_dequantization(compressed_weight.tensor, compressed_weight.scale, reduction_axis)
         else:
             msg = (
                 f"{mode.value} mode is invalid for Lora Correction algorithm. Supported modes: INT4_SYM, INT4_ASYM, NF4"

nncf/quantization/algorithms/weight_compression/scale_estimation.py

+9-7
@@ -23,7 +23,6 @@
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
-from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_quantized_weight
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_float_quantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
@@ -199,8 +198,7 @@ def calculate_quantization_params(
 
         original_weight = fns.zeros_like(weight) + weight
         if config.mode == CompressWeightsMode.NF4:
-            norm_weight, scale = do_float_quantization(original_weight, reduction_axis, cur_config.group_size)
-            compressed_weights = calculate_nf4_quantized_weight(norm_weight, scale, is_normalized_weight=True)
+            compressed_weights, scale = do_float_quantization(original_weight, cur_config, reduction_axis)
             q_weights = do_float_dequantization(compressed_weights, scale, reduction_axis)
             q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size)
             zp = None
@@ -249,7 +247,9 @@ def calculate_quantization_params(
             near_to_ideal_scale = near_to_ideal_scale * scale_sign
 
             if config.mode == CompressWeightsMode.NF4:
-                g_compressed_weighs = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale)
+                g_compressed_weighs, _ = do_float_quantization(
+                    original_weight, config, precomputed_scale=near_to_ideal_scale
+                )
                 out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale)
             else:
                 out = integer_quantize_dequantize_weight(
@@ -284,7 +284,7 @@ def calculate_quantization_params(
 
             if i < initial_steps - 1:
                 if config.mode == CompressWeightsMode.NF4:
-                    out = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale)
+                    out, _ = do_float_quantization(original_weight, config, precomputed_scale=near_to_ideal_scale)
                 else:
                     out, _, _ = do_integer_quantization(
                         original_weight,
@@ -302,7 +302,7 @@ def calculate_quantization_params(
             scaled_scale = factor * scale
 
             if config.mode == CompressWeightsMode.NF4:
-                out = calculate_nf4_quantized_weight(original_weight, scaled_scale)
+                out, _ = do_float_quantization(original_weight, config, precomputed_scale=scaled_scale)
             else:
                 out, _, _ = do_integer_quantization(
                     original_weight,
@@ -318,7 +318,9 @@ def calculate_quantization_params(
             near_to_ideal_scale = near_to_ideal_scale * scale_sign
 
             if config.mode == CompressWeightsMode.NF4:
-                g_compressed_weighs = calculate_nf4_quantized_weight(original_weight, near_to_ideal_scale)
+                g_compressed_weighs, _ = do_float_quantization(
+                    original_weight, config, precomputed_scale=near_to_ideal_scale
+                )
                 out = do_float_dequantization(g_compressed_weighs, near_to_ideal_scale)
             else:
                 out = integer_quantize_dequantize_weight(
