46 | 46 | from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig
47 | 47 | from nncf.scopes import IgnoredScope
48 | 48 | from nncf.scopes import get_ignored_node_names_from_ignored_scope
| 49 | +from nncf.tensor.definitions import TensorDataType
49 | 50 |
50 | 51 | TModel = TypeVar("TModel")
51 | 52 | TTensor = TypeVar("TTensor")
57 | 58 |     CompressWeightsMode.NF4,
58 | 59 |     CompressWeightsMode.E2M1,
59 | 60 | ]
| 61 | +SUPPORTED_DATA_TYPES = [
| 62 | +    TensorDataType.float16,
| 63 | +    TensorDataType.bfloat16,
| 64 | +    TensorDataType.float32,
| 65 | +    TensorDataType.float64,
| 66 | +]
60 | 67 |
61 | 68 |
62 | 69 | def get_weight_compression_configuration(
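
For reference, a minimal sketch (not part of the change) of how the new SUPPORTED_DATA_TYPES list acts as a membership filter for weight dtypes. The helper name is_supported_weight_dtype is hypothetical, and TensorDataType.int8 is assumed to be a member of the enum purely to illustrate an unsupported dtype:

    from nncf.tensor.definitions import TensorDataType

    SUPPORTED_DATA_TYPES = [
        TensorDataType.float16,
        TensorDataType.bfloat16,
        TensorDataType.float32,
        TensorDataType.float64,
    ]

    def is_supported_weight_dtype(dtype: TensorDataType) -> bool:
        # Same membership test the later hunks use in place of dtype.is_float().
        return dtype in SUPPORTED_DATA_TYPES

    print(is_supported_weight_dtype(TensorDataType.bfloat16))  # True
    print(is_supported_weight_dtype(TensorDataType.int8))      # False (int8 assumed to exist in the enum)
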
@@ -160,9 +167,27 @@ def check_user_compression_configuration(
160 | 167 |         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
161 | 168 |         raise nncf.ValidationError(msg)
162 | 169 |
163 |     | -    if subset_size <= 0:
164 |     | -        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
165 |     | -        raise nncf.ValidationError(msg)
| 170 | +    values_to_check = [subset_size]
| 171 | +    ranks = []
| 172 | +    if advanced_parameters:
| 173 | +        values_to_check.extend(
| 174 | +            [
| 175 | +                advanced_parameters.awq_params.subset_size,
| 176 | +                advanced_parameters.scale_estimation_params.subset_size,
| 177 | +                advanced_parameters.gptq_params.subset_size,
| 178 | +                advanced_parameters.lora_correction_params.subset_size,
| 179 | +            ]
| 180 | +        )
| 181 | +        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
| 182 | +    for size in values_to_check:
| 183 | +        if size <= 0:
| 184 | +            msg = f"The subset_size value should be positive, but subset_size={size} is given."
| 185 | +            raise nncf.ValidationError(msg)
| 186 | +
| 187 | +    for rank in ranks:
| 188 | +        if rank <= 0:
| 189 | +            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
| 190 | +            raise nncf.ValidationError(msg)
166 | 191 |
167 | 192 |     if (
168 | 193 |         ratio
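
A minimal standalone sketch of the validation pattern this hunk introduces: gather every subset_size-like value and every LoRA rank, then reject non-positive ones. Params is a hypothetical stand-in for the advanced-parameters object, and ValueError stands in for nncf.ValidationError so the sketch has no NNCF dependency:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Params:  # hypothetical stand-in for advanced_parameters
        awq_subset_size: int = 32
        lora_adapter_rank: int = 8

    def validate(subset_size: int, params: Optional[Params]) -> None:
        values_to_check = [subset_size]
        ranks = []
        if params:
            values_to_check.append(params.awq_subset_size)
            ranks.append(params.lora_adapter_rank)
        for size in values_to_check:
            if size <= 0:
                raise ValueError(f"The subset_size value should be positive, but subset_size={size} is given.")
        for rank in ranks:
            if rank <= 0:
                raise ValueError(f"The lora adapter rank should be positive, but rank={rank} is given.")

    validate(128, Params(awq_subset_size=-1))  # raises ValueError
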
@@ -498,7 +523,7 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph)
498 | 523 |                 continue
499 | 524 |             for _, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
500 | 525 |                 weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
501 |     | -                if weight_dtype.is_float():
| 526 | +                if weight_dtype in SUPPORTED_DATA_TYPES:
502 | 527 |                     continue
503 | 528 |                 weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
504 | 529 |                 weight_size = reduce(operator.mul, weight_shape, 1)
@@ -544,7 +569,7 @@ def apply(
544 | 569 |                     continue
545 | 570 |
546 | 571 |                 weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
547 |     | -                if not weight_dtype.is_float():
| 572 | +                if weight_dtype not in SUPPORTED_DATA_TYPES:
548 | 573 |                     continue
549 | 574 |                 weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
550 | 575 |                 weight_size = reduce(operator.mul, weight_shape, 1)
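
As a side note, the weight_size computed in the context lines above via reduce(operator.mul, weight_shape, 1) is simply the element count of the weight tensor; the shape below is an arbitrary example, not taken from the diff:

    import operator
    from functools import reduce

    weight_shape = [4096, 11008]  # arbitrary example shape
    weight_size = reduce(operator.mul, weight_shape, 1)  # initializer 1 makes an empty (scalar) shape count as 1
    print(weight_size)  # 45088768
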
@@ -656,6 +681,7 @@ def apply(
656 | 681 |             zero_points,
657 | 682 |             lora_correction_algo,
658 | 683 |             self._compression_format,
| 684 | +            self._advanced_parameters,
659 | 685 |         )
660 | 686 |
661 | 687 |         self._backend_entity.dump_parameters(