@@ -413,13 +413,13 @@ def _quantize_torchmodel(
413
413
model = patch_model_with_bettertransformer (model )
414
414
415
415
dummy_inputs = onnx_config .generate_dummy_inputs (framework = "pt" )
416
- device = self . model .device
416
+ device = model .device
417
417
dummy_inputs = tree_map (
418
418
lambda value : value .to (device ) if isinstance (value , torch .Tensor ) else value , dummy_inputs
419
419
)
420
420
check_dummy_inputs_are_allowed (model , dummy_inputs )
421
421
422
- nncf .compress_weights (self . model , dataset = nncf .Dataset ([dummy_inputs ]))
422
+ nncf .compress_weights (model , dataset = nncf .Dataset ([dummy_inputs ]))
423
423
else :
424
424
if stateful :
425
425
logger .warn (
@@ -439,10 +439,10 @@ def _quantize_torchmodel(
439
439
quantization_config .add_input_info (model_inputs )
440
440
nncf_config = NNCFConfig .from_dict (quantization_config .__dict__ )
441
441
nncf_config = register_default_init_args (nncf_config , calibration_dataloader )
442
- controller , compressed_model = create_compressed_model (
443
- self . model , nncf_config , wrap_inputs_fn = wrap_nncf_model_inputs_with_objwalk
442
+ controller , model = create_compressed_model (
443
+ model , nncf_config , wrap_inputs_fn = wrap_nncf_model_inputs_with_objwalk
444
444
)
445
- compressed_model = controller .strip (do_copy = False )
445
+ model = controller .strip (do_copy = False )
446
446
447
447
model_path = save_directory / (onnx_file_name if quantization_config .save_onnx_model else ov_file_name )
448
448
onnx_path = save_directory / onnx_file_name
0 commit comments