|
45 | 45 | from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig
|
46 | 46 | from nncf.scopes import IgnoredScope
|
47 | 47 | from nncf.scopes import get_ignored_node_names_from_ignored_scope
|
| 48 | +from nncf.tensor.definitions import TensorDataType |
48 | 49 |
|
49 | 50 | TModel = TypeVar("TModel")
|
50 | 51 | TTensor = TypeVar("TTensor")
|
|
56 | 57 | CompressWeightsMode.NF4,
|
57 | 58 | CompressWeightsMode.E2M1,
|
58 | 59 | ]
|
| 60 | +SUPPORTED_DATA_TYPES = [ |
| 61 | + TensorDataType.float16, |
| 62 | + TensorDataType.bfloat16, |
| 63 | + TensorDataType.float32, |
| 64 | + TensorDataType.float64, |
| 65 | +] |
59 | 66 |
|
60 | 67 |
|
61 | 68 | def get_weight_compression_configuration(
|
@@ -489,7 +496,7 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph)
|
489 | 496 | continue
|
490 | 497 | for _, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
|
491 | 498 | weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
|
492 |
| - if weight_dtype.is_float(): |
| 499 | + if weight_dtype in SUPPORTED_DATA_TYPES: |
493 | 500 | continue
|
494 | 501 | weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
|
495 | 502 | weight_size = reduce(operator.mul, weight_shape, 1)
|
@@ -535,7 +542,7 @@ def apply(
|
535 | 542 | continue
|
536 | 543 |
|
537 | 544 | weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
|
538 |
| - if not weight_dtype.is_float(): |
| 545 | + if weight_dtype not in SUPPORTED_DATA_TYPES: |
539 | 546 | continue
|
540 | 547 | weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
|
541 | 548 | weight_size = reduce(operator.mul, weight_shape, 1)
|
|
0 commit comments