@@ -203,7 +203,6 @@ def get_compress_weight_model(
203
203
scale_shape : Optional [Tuple ] = None ,
204
204
zero_point_shape : Optional [Tuple ] = None ,
205
205
reduction_axes : Optional [ReductionAxes ] = None ,
206
- return_nodes : Optional [bool ] = False ,
207
206
) -> Union [ModelCallable , ModelAsNodes ]:
208
207
"""
209
208
Get a model that compresses weights using the given configuration.
@@ -217,8 +216,6 @@ def get_compress_weight_model(
217
216
as an input.
218
217
:param reduction_axes: Optional axes to reduce the weight tensor. Not needed if scale (and z.p.) are provided as
219
218
inputs.
220
- :param return_nodes: Whether to return the OV model inputs parameters and results nodes instead of the model
221
- callable.
222
219
:return: A model callable that compresses weights using the given configuration. Or a model as nodes, if
223
220
`return_nodes` is True.
224
221
"""
@@ -233,7 +230,6 @@ def get_compress_weight_model(
233
230
scale_shape ,
234
231
zero_point_shape ,
235
232
reduction_axes ,
236
- return_nodes = return_nodes ,
237
233
)
238
234
239
235
@@ -278,6 +274,35 @@ def get_compress_decompress_weight_model(
278
274
)
279
275
280
276
277
+ def get_quantization_error_model (
278
+ ov_model_params : OVModelParameters ,
279
+ config : WeightCompressionConfig ,
280
+ original_weight_shape : Tuple ,
281
+ weight_shape : Tuple ,
282
+ original_reduction_axes : ReductionAxes ,
283
+ reduction_axes : ReductionAxes ,
284
+ ) -> ModelCallable :
285
+ """
286
+ Get a model that calculates the quantization error for a given weight.
287
+
288
+ This function builds a model that compresses and then decompresses the given weight, and calculates the
289
+ quantization error by comparing the original weight with the decompressed weight.
290
+
291
+ :param ov_model_params: OV model parameters.
292
+ :param config: Compression configuration.
293
+ :param original_weight_shape: Shape of the original weight tensor.
294
+ :param weight_shape: Shape of the weight tensor to be compressed.
295
+ :param original_reduction_axes: Reduction axes of the original weight tensor before reshaping.
296
+ :param reduction_axes: Axes to reduce the weight tensor.
297
+ :return: A model callable that returns the quantization error.
298
+ """
299
+ weight_shape , _ , _ = _prepare_compression_model_inputs (ov_model_params , weight_shape , None , None , reduction_axes )
300
+
301
+ return _build_quantization_error_model (
302
+ config , ov_model_params , original_weight_shape , weight_shape , original_reduction_axes , reduction_axes
303
+ )
304
+
305
+
281
306
@cache_results (OV_MODEL_CACHE )
282
307
def _build_compress_model (
283
308
config : WeightCompressionConfig ,
@@ -437,7 +462,8 @@ def _build_compress_decompress_model(
437
462
zero_point_shape : Optional [Tuple ] = None ,
438
463
reduction_axes : Optional [ReductionAxes ] = None ,
439
464
return_compressed_weight : Optional [bool ] = False ,
440
- ) -> ModelCallable :
465
+ return_nodes : Optional [bool ] = False ,
466
+ ) -> Union [ModelCallable , ModelAsNodes ]:
441
467
default_output_dtypes = {"decompressed_weight" : TensorDataType .float32 }
442
468
if not return_compressed_weight :
443
469
# If compressed weight is not returned to a user, we can keep it in float32 to avoid additional conversion
@@ -451,8 +477,8 @@ def _build_compress_decompress_model(
451
477
raise ValueError (msg )
452
478
453
479
# Get compression model as input/result nodes and potentially modified ov model parameters
454
- ov_parameters , ov_results , ov_model_params = get_compress_weight_model (
455
- ov_model_params , config , weight_shape , scale_shape , zero_point_shape , reduction_axes , return_nodes = True
480
+ ov_parameters , ov_results , ov_model_params = _build_compress_model (
481
+ config , ov_model_params , weight_shape , scale_shape , zero_point_shape , reduction_axes , return_nodes = True
456
482
)
457
483
458
484
if config .is_asym_mode :
@@ -477,12 +503,51 @@ def _build_compress_decompress_model(
477
503
decompressed_weight = opset .multiply (scale , convert_op (compressed_weight , ov .Type .f32 ))
478
504
479
505
ov_results = [decompressed_weight ] + ov_results if return_compressed_weight else [decompressed_weight ]
506
+
507
+ if return_nodes :
508
+ return ov_parameters , ov_results , ov_model_params
509
+
480
510
model = ov .Model (ov_results , ov_parameters )
481
511
compiled_model = _compile_ov_model (model , device_name = "CPU" , config = {inference_precision (): ov .Type .f32 })
482
512
483
513
return partial (_infer_ov_model , ov_model_params , compiled_model )
484
514
485
515
516
+ @cache_results (OV_MODEL_CACHE )
517
+ def _build_quantization_error_model (
518
+ config : WeightCompressionConfig ,
519
+ ov_model_params : OVModelParameters ,
520
+ original_weight_shape : Tuple ,
521
+ weight_shape : Tuple ,
522
+ original_reduction_axes : ReductionAxes ,
523
+ reduction_axes : ReductionAxes ,
524
+ ) -> ModelCallable :
525
+ ov_parameters , ov_results , ov_model_params = _build_compress_decompress_model (
526
+ config ,
527
+ ov_model_params ,
528
+ weight_shape ,
529
+ reduction_axes = reduction_axes ,
530
+ return_compressed_weight = False ,
531
+ return_nodes = True ,
532
+ )
533
+
534
+ weight = ov_parameters [0 ]
535
+ decompressed_weight = ov_results [0 ]
536
+
537
+ weight = convert_op (opset .reshape (weight , original_weight_shape , special_zero = False ), ov .Type .f32 )
538
+ decompressed_weight = convert_op (
539
+ opset .reshape (decompressed_weight , original_weight_shape , special_zero = False ), ov .Type .f32
540
+ )
541
+ diff = opset .squared_difference (decompressed_weight , weight )
542
+ layer_err = opset .reduce_mean (diff , reduction_axes = original_reduction_axes )
543
+ quantization_error = opset .reduce_max (layer_err , reduction_axes = tuple (range (len (layer_err .shape ))))
544
+
545
+ model = ov .Model ([quantization_error ], ov_parameters )
546
+ compiled_model = _compile_ov_model (model , device_name = "CPU" , config = {inference_precision (): ov .Type .f32 })
547
+
548
+ return partial (_infer_ov_model , ov_model_params , compiled_model )
549
+
550
+
486
551
def get_astype_model (ov_model_params : OVModelParameters , input_shape : Tuple ) -> ModelCallable :
487
552
"""
488
553
Return a model that cast the input of the given shape to the given data type. Especially useful for
0 commit comments