diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 6169f4bada3..d484a7bd4ee 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -160,9 +160,27 @@ def check_user_compression_configuration(
         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
         raise nncf.ValidationError(msg)
 
-    if subset_size <= 0:
-        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
-        raise nncf.ValidationError(msg)
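+    # Validate every user-configurable subset_size, including the per-algorithm overrides below.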
+    values_to_check = [subset_size]
+    ranks = []
+    if advanced_parameters:
+        values_to_check.extend(
+            [
+                advanced_parameters.awq_params.subset_size,
+                advanced_parameters.scale_estimation_params.subset_size,
+                advanced_parameters.gptq_params.subset_size,
+                advanced_parameters.lora_correction_params.subset_size,
+            ]
+        )
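+        # The adapter rank can be set both globally and inside lora_correction_params; validate both.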
+        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
+    for size in values_to_check:
+        if size <= 0:
+            msg = f"The subset_size value should be positive, but subset_size={size} is given."
+            raise nncf.ValidationError(msg)
+
+    for rank in ranks:
+        if rank <= 0:
+            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
+            raise nncf.ValidationError(msg)
 
     if (
         ratio
@@ -656,6 +674,7 @@ def apply(
             zero_points,
             lora_correction_algo,
             self._compression_format,
+            self._advanced_parameters,
         )
 
         self._backend_entity.dump_parameters(
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 79c12868d0a..46629bc380e 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -768,6 +768,10 @@ def signed(self, signed: bool):
         self.set_levels()
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # In the multi-device case, quantizers can end up on a different device after loading an NNCF checkpoint.
+            self.to(x.device)
         return symmetric_quantize(
             x, self.levels, self.level_low, self.level_high, self.scale, self.eps, skip=execute_traced_op_as_identity
         )
@@ -955,6 +959,10 @@ def set_levels(self):
         self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # In the multi-device case, quantizers can end up on a different device after loading an NNCF checkpoint.
+            self.to(x.device)
         return asymmetric_quantize(
             x,
             self.levels,
@@ -1067,9 +1075,14 @@ class LoraMixin:
 
     def init_lora(self, lspec: PTLoraSpec):
         self._lspec = lspec
+        default_lora_dtype = torch.bfloat16
         out_features, in_features = lspec.orig_weight_shape
-        self.lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16))
-        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16))
+        rank = lspec.lora_rank
+        if rank > out_features or rank > in_features:
+            msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
+            raise nncf.ValidationError(msg)
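+        # lora_B starts as zeros, so the adapter product (lora_B @ lora_A) is initially a no-op.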
+        self.lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))
 
     def enable_gradients(self):
         self.lora_A.requires_grad = True
@@ -1097,6 +1110,10 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # In the multi-device case, quantizers can end up on a different device after loading an NNCF checkpoint.
+            self.to(x.device)
         return asymmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
@@ -1142,6 +1159,10 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # In the multi-device case, quantizers can end up on a different device after loading an NNCF checkpoint.
+            self.to(x.device)
         return symmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
diff --git a/tests/torch/ptq/test_fq_lora.py b/tests/torch/ptq/test_fq_lora.py
index 327733c95b8..a6c667b6823 100644
--- a/tests/torch/ptq/test_fq_lora.py
+++ b/tests/torch/ptq/test_fq_lora.py
@@ -15,9 +15,18 @@
 from transformers import AutoTokenizer
 
 import nncf
+from nncf.data.dataset import Dataset
+from nncf.errors import ValidationError
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.scopes import IgnoredScope
+from nncf.torch import load_from_config
 from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
 from nncf.torch.quantization.layers import LoraMixin
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
+from tests.torch.test_models.synthetic import LinearModel
 
 
 @pytest.mark.parametrize(
@@ -80,3 +89,67 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable
 
     assert first_loss > 8
     assert float(loss) < 1
+
+
+def test_checkpoint_loading(tmp_path):
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA test case for CPU-only setups.")
+    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
+    device = "cuda"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    example_input = tokenizer("dummy", return_tensors="pt").to(device)
+    except_lm_head_and_3rd_layer_qkv = (
+        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
+    )
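+    # The negative lookahead matches every node name except embed_out and the layer-2 QKV projection,
+    # so only those two modules escape the ignored scope and actually get compressed.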
+    model = compress_weights(
+        model,
+        group_size=32,
+        mode=CompressWeightsMode.INT4_ASYM,
+        backup_mode=CompressWeightsMode.INT8_ASYM,
+        dataset=Dataset([dict(example_input)]),
+        compression_format=CompressionFormat.FQ_LORA,
+        ignored_scope=IgnoredScope(patterns=[except_lm_head_and_3rd_layer_qkv]),
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2),
+    )
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+
+    # save checkpoint
+    ckpt_path = tmp_path / "nncf_ckpt.pth"
+    torch.save(
+        {
+            "nncf_state_dict": model.nncf.state_dict(),
+            "nncf_config": model.nncf.get_config(),
+        },
+        ckpt_path,
+    )
+    del model
+
+    # load checkpoint
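+    # weights_only=False is needed because the checkpoint stores the NNCF config object, not only tensors.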
+    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input))
+    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])
+
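+    # Greedy decoding must be bit-exact after the save/load round trip.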
+    actual_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0],
+        skip_special_tokens=True,
+    )
+    assert actual_output == ref_output
+
+
+def test_invalid_lora_rank():
+    too_big_rank = 4
+    model = LinearModel(torch.ones(2, 2))
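+    # With a 2x2 weight, any rank above 2 exceeds both dimensions and must be rejected.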
+    with pytest.raises(ValidationError):
+        compress_weights(
+            model,
+            mode=CompressWeightsMode.INT4_ASYM,
+            group_size=2,
+            all_layers=True,
+            dataset=Dataset([torch.ones(2, 2)]),
+            compression_format=CompressionFormat.FQ_LORA,
+            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank),
+        )
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index b71c4179ca3..7a9455c7be4 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -39,6 +39,7 @@
 from nncf.torch.quantization.quantize_functions import unpack_int4
 from nncf.torch.quantization.quantize_functions import unpack_uint4
 from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
+from tests.torch.test_models.synthetic import LinearModel
 from tests.torch.test_models.synthetic import ShortTransformer
 from tests.torch.test_tensor import cast_to
 
@@ -82,16 +83,6 @@ def forward(self, input):
         return input @ self.w
 
 
-class LinearModel(torch.nn.Module):
-    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
-        super().__init__()
-        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
-        self.linear.weight = torch.nn.Parameter(weight)
-
-    def forward(self, input):
-        return self.linear(input)
-
-
 class AWQActLinearModel(nn.Module):
     def __init__(self, with_multiply=False, n_layers=8):
         super().__init__()
diff --git a/tests/torch/quantization/test_layers.py b/tests/torch/quantization/test_layers.py
index 97cbb4c41ab..e42a2e88c7b 100644
--- a/tests/torch/quantization/test_layers.py
+++ b/tests/torch/quantization/test_layers.py
@@ -36,7 +36,7 @@ def test_quantizer_layers_accepts_return_type(registred):
     )
     if mode in [QuantizationMode.ASYMMETRIC_LORA, QuantizationMode.SYMMETRIC_LORA]:
         shape = actual_input.unsqueeze(dim=0).shape
-        lora_spec = PTLoraSpec(2, shape, shape)
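+        # Rank 0 keeps the adapters empty and trivially passes the new rank-vs-dimension check.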
+        lora_spec = PTLoraSpec(0, shape, shape)
         quantizer = quantizer_cls(quantizer_spec, lora_spec)
     else:
         quantizer = quantizer_cls(quantizer_spec)
diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py
index e64a234d1c6..c6b9d9678c5 100644
--- a/tests/torch/test_models/synthetic.py
+++ b/tests/torch/test_models/synthetic.py
@@ -660,6 +660,16 @@ def forward(self, input_ids):
         return res
 
 
+class LinearModel(torch.nn.Module):
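+    """Minimal single-Linear model shared by the weight compression tests."""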
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        # nn.Linear stores its weight as (out_features, in_features), hence the reversed shape order.
+        self.linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        self.linear.weight = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
 class YOLO11N_SDPABlock(torch.nn.Module):
     INPUT_SIZE = (1, 2, 4)