Skip to content

Commit d6b41c1

Browse files
committed
replied to comments
1 parent 160d632 commit d6b41c1

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

nncf/quantization/algorithms/weight_compression/algorithm.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig
4646
from nncf.scopes import IgnoredScope
4747
from nncf.scopes import get_ignored_node_names_from_ignored_scope
48+
from nncf.tensor.definitions import TensorDataType
4849

4950
TModel = TypeVar("TModel")
5051
TTensor = TypeVar("TTensor")
@@ -56,6 +57,12 @@
5657
CompressWeightsMode.NF4,
5758
CompressWeightsMode.E2M1,
5859
]
60+
SUPPORTED_DATA_TYPES = [
61+
TensorDataType.float16,
62+
TensorDataType.bfloat16,
63+
TensorDataType.float32,
64+
TensorDataType.float64,
65+
]
5966

6067

6168
def get_weight_compression_configuration(
@@ -489,7 +496,7 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph)
489496
continue
490497
for _, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
491498
weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
492-
if weight_dtype.is_float():
499+
if weight_dtype in SUPPORTED_DATA_TYPES:
493500
continue
494501
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
495502
weight_size = reduce(operator.mul, weight_shape, 1)
@@ -535,7 +542,7 @@ def apply(
535542
continue
536543

537544
weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
538-
if not weight_dtype.is_float():
545+
if weight_dtype not in SUPPORTED_DATA_TYPES:
539546
continue
540547
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
541548
weight_size = reduce(operator.mul, weight_shape, 1)

nncf/tensor/definitions.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,15 @@ def is_float(self) -> bool:
5353
"""
5454
:return: True if the tensor data type is a floating-point type, else False.
5555
"""
56-
return self in [TensorDataType.float16, TensorDataType.bfloat16, TensorDataType.float32, TensorDataType.float64]
56+
return self in [
57+
TensorDataType.float16,
58+
TensorDataType.bfloat16,
59+
TensorDataType.float32,
60+
TensorDataType.float64,
61+
TensorDataType.f8e4m3,
62+
TensorDataType.f8e5m2,
63+
TensorDataType.nf4,
64+
]
5765

5866

5967
class TensorDeviceType(Enum):

0 commit comments

Comments
 (0)