@@ -417,13 +417,13 @@ def _quantize_torchmodel(
417
417
model = patch_model_with_bettertransformer (model )
418
418
419
419
dummy_inputs = onnx_config .generate_dummy_inputs (framework = "pt" )
420
- device = self . model .device
420
+ device = model .device
421
421
dummy_inputs = tree_map (
422
422
lambda value : value .to (device ) if isinstance (value , torch .Tensor ) else value , dummy_inputs
423
423
)
424
424
check_dummy_inputs_are_allowed (model , dummy_inputs )
425
425
426
- nncf .compress_weights (self . model , dataset = nncf .Dataset ([dummy_inputs ]))
426
+ nncf .compress_weights (model , dataset = nncf .Dataset ([dummy_inputs ]))
427
427
else :
428
428
if stateful :
429
429
logger .warn (
@@ -443,10 +443,10 @@ def _quantize_torchmodel(
443
443
quantization_config .add_input_info (model_inputs )
444
444
nncf_config = NNCFConfig .from_dict (quantization_config .__dict__ )
445
445
nncf_config = register_default_init_args (nncf_config , calibration_dataloader )
446
- controller , compressed_model = create_compressed_model (
447
- self . model , nncf_config , wrap_inputs_fn = wrap_nncf_model_inputs_with_objwalk
446
+ controller , model = create_compressed_model (
447
+ model , nncf_config , wrap_inputs_fn = wrap_nncf_model_inputs_with_objwalk
448
448
)
449
- compressed_model = controller .strip (do_copy = False )
449
+ model = controller .strip (do_copy = False )
450
450
451
451
model_path = save_directory / (onnx_file_name if quantization_config .save_onnx_model else ov_file_name )
452
452
onnx_path = save_directory / onnx_file_name
0 commit comments