
Commit 555f16e

Tests' fixes
1 parent d840360 commit 555f16e

File tree: 2 files changed (+31 -17 lines)

nncf/quantization/algorithms/weight_compression/algorithm.py (+12 -8)

@@ -160,18 +160,22 @@ def check_user_compression_configuration(
         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
         raise nncf.ValidationError(msg)
 
-    for size in [
-        subset_size,
-        advanced_parameters.awq_params.subset_size,
-        advanced_parameters.scale_estimation_params.subset_size,
-        advanced_parameters.gptq_params.subset_size,
-        advanced_parameters.lora_correction_params.subset_size,
-    ]:
+    values_to_check = [subset_size]
+    ranks = []
+    if advanced_parameters:
+        values_to_check.extend([
+            advanced_parameters.awq_params.subset_size,
+            advanced_parameters.scale_estimation_params.subset_size,
+            advanced_parameters.gptq_params.subset_size,
+            advanced_parameters.lora_correction_params.subset_size,
+        ])
+        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
+    for size in values_to_check:
         if size <= 0:
             msg = f"The subset_size value should be positive, but subset_size={size} is given."
             raise nncf.ValidationError(msg)
 
-    for rank in [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]:
+    for rank in ranks:
         if rank <= 0:
             msg = f"The lora adapter rank should be positive, but rank={rank} is given."
             raise nncf.ValidationError(msg)
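
The new `if advanced_parameters:` guard is what lets the check run when no advanced parameters are supplied: previously, passing `advanced_parameters=None` crashed with `AttributeError` on the first field access before any validation happened. A minimal standalone sketch of the fixed pattern, using hypothetical stand-in types rather than the actual NNCF dataclasses:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _AwqParams:  # hypothetical stand-in for the AWQ parameter type
    subset_size: int = 32


@dataclass
class _AdvancedParams:  # hypothetical stand-in for the advanced-parameters object
    awq_params: _AwqParams = field(default_factory=_AwqParams)
    lora_adapter_rank: int = 8


def check_configuration(subset_size: int, advanced: Optional[_AdvancedParams] = None) -> None:
    values_to_check = [subset_size]
    ranks = []
    if advanced:  # only dereference advanced fields when the object exists
        values_to_check.append(advanced.awq_params.subset_size)
        ranks.append(advanced.lora_adapter_rank)
    for size in values_to_check:
        if size <= 0:
            raise ValueError(f"The subset_size value should be positive, but subset_size={size} is given.")
    for rank in ranks:
        if rank <= 0:
            raise ValueError(f"The lora adapter rank should be positive, but rank={rank} is given.")


check_configuration(64)                     # no AttributeError: advanced checks are skipped
check_configuration(64, _AdvancedParams())  # advanced fields are validated as before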

nncf/torch/model_transformer.py (+19 -9)

@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 from collections import defaultdict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple
@@ -20,6 +21,7 @@
 from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.errors import InternalError
 from nncf.torch.extractor import extract_model
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
@@ -80,15 +82,23 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork:
         for command in transformation_layout.transformations:
             compression_module = command.fn
             if isinstance(compression_module, nn.Module):
-                target_point = command.target_points[0]
-                node_with_weight = graph.get_node_by_name(target_point.target_node_name)
-                weight_node = get_const_node(node_with_weight, target_point.input_port_id, graph)
-                if weight_node is None:
-                    weight_node = node_with_weight  # Decompression in DQ compression format is applied to const.
-                const_data = get_const_data(weight_node, model)
-                # Compression module and the corresponding layer may have a different device in multi-device setup
-                # (e.g. when HF model was loaded with device_map='auto'). Need to align devices.
-                compression_module.to(const_data.device)
+                points = [command.target_point]
+                if hasattr(command, "target_points"):
+                    points = command.target_points
+                for target_point in points:
+                    target_node = graph.get_node_by_name(target_point.target_node_name)
+                    weight_node = None
+                    with contextlib.suppress(InternalError):
+                        weight_node = get_const_node(target_node, target_point.input_port_id, graph)
+                    if weight_node is None:
+                        weight_node = target_node  # Decompression in DQ compression format is applied to const
+                    const_data = None
+                    with contextlib.suppress(AttributeError):
+                        const_data = get_const_data(weight_node, model)
+                    if const_data is not None:
+                        # Compression module and the corresponding layer may have a different device in multi-device
+                        # setup (e.g. when HF model was loaded with device_map='auto'). Need to align devices.
+                        compression_module.to(const_data.device)
         model.nncf.rebuild_graph()
 
     return model
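
Two small patterns carry the new logic: `hasattr` lets the same loop serve commands that expose a single `target_point` or a list of `target_points`, and the stdlib `contextlib.suppress` replaces a try/except-pass around lookups that may legitimately fail. A minimal sketch of both, with hypothetical command classes rather than NNCF's actual command types:

import contextlib


class _SingleTargetCommand:  # hypothetical command carrying one target point
    target_point = "linear_1"


class _MultiTargetCommand:  # hypothetical command carrying both attributes, like the diff assumes
    target_point = "linear_1"
    target_points = ["linear_1", "linear_2"]


def resolve_points(command):
    # Same shape as the diff: start from the single point, prefer the list when present.
    points = [command.target_point]
    if hasattr(command, "target_points"):
        points = command.target_points
    return points


print(resolve_points(_SingleTargetCommand()))  # ['linear_1']
print(resolve_points(_MultiTargetCommand()))   # ['linear_1', 'linear_2']

# contextlib.suppress: attempt a lookup that may raise, then fall back.
weights = {"linear_1": "w1"}
weight = None
with contextlib.suppress(KeyError):
    weight = weights["linear_2"]  # KeyError is raised here and silently suppressed
if weight is None:
    weight = "fallback"  # mirrors the weight_node / const_data fallbacks above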
