Skip to content

Commit 5414dd6

Browse files
OV nf4 support (#3209)
### Changes - Added `nf4` precision for OV `GraphConverter`. ### Reason for changes - `nf4` precision support. ### Related tickets - 153357 ### Tests - Added
1 parent f3f232f commit 5414dd6

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

nncf/openvino/graph/nncf_graph_builder.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def convert_to_nncf_dtype(ov_type: ov.Type) -> Dtype:
4848
"bf16": "float",
4949
"f32": "float",
5050
"f64": "float",
51+
"nf4": "float",
5152
"i4": "int",
5253
"i8": "int",
5354
"i16": "int",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
strict digraph {
2+
"0 Input" [id=0, type=Parameter];
3+
"1 MatMul" [id=1, type=MatMul];
4+
"2 Add" [id=2, type=Add];
5+
"3 Result_Add" [id=3, type=Result];
6+
"4 Convert_6" [id=4, type=Convert];
7+
"5 MatMul_bias" [id=5, type=Constant];
8+
"6 Convert_3" [id=6, type=Convert];
9+
"7 MatMul_const" [id=7, type=Constant];
10+
"0 Input" -> "1 MatMul" [label="[1, 3, 4, 2]", style=solid];
11+
"1 MatMul" -> "2 Add" [label="[1, 3, 2, 5]", style=solid];
12+
"2 Add" -> "3 Result_Add" [label="[1, 3, 2, 5]", style=solid];
13+
"4 Convert_6" -> "2 Add" [label="[1, 3, 1, 1]", style=solid];
14+
"5 MatMul_bias" -> "4 Convert_6" [label="[1, 3, 1, 1]", style=solid];
15+
"6 Convert_3" -> "1 MatMul" [label="[1, 3, 4, 5]", style=solid];
16+
"7 MatMul_const" -> "6 Convert_3" [label="[1, 3, 4, 5]", style=solid];
17+
}

tests/openvino/native/test_nncf_graph_builder.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tests.openvino.native.common import convert_torch_model
2222
from tests.openvino.native.common import get_actual_reference_for_current_openvino
2323
from tests.openvino.native.models import SYNTHETIC_MODELS
24+
from tests.openvino.native.models import FPModel
2425
from tests.openvino.native.models import ParallelEdgesModel
2526
from tests.openvino.native.models import get_torch_model_info
2627

@@ -34,6 +35,19 @@ def test_compare_nncf_graph_synthetic_models(model_cls_to_test):
3435
compare_nncf_graphs(model_to_test.ov_model, path_to_dot)
3536

3637

38+
@pytest.mark.parametrize(
39+
"model,precision",
40+
[
41+
(FPModel(const_dtype=ov.Type.nf4), "nf4"),
42+
],
43+
)
44+
def test_compare_nncf_graph_precision_synthetic_models(model, precision):
45+
path_to_dot = get_actual_reference_for_current_openvino(
46+
REFERENCE_GRAPHS_DIR / f"{precision}_{model.ref_graph_name}"
47+
)
48+
compare_nncf_graphs(model.ov_model, path_to_dot)
49+
50+
3751
@pytest.mark.parametrize(
3852
"model_name",
3953
(
@@ -101,6 +115,7 @@ def _get_default_nncf_graph_edge(from_node, to_node, input_port_id, output_port_
101115
(ov.Type.f16, Dtype.FLOAT),
102116
(ov.Type.f32, Dtype.FLOAT),
103117
(ov.Type.f64, Dtype.FLOAT),
118+
(ov.Type.nf4, Dtype.FLOAT),
104119
(ov.Type.i4, Dtype.INTEGER),
105120
(ov.Type.i8, Dtype.INTEGER),
106121
(ov.Type.i16, Dtype.INTEGER),
@@ -124,7 +139,6 @@ def test_convert_to_nncf_dtype_supported_types(ov_type: ov.Type, expected_nncf_d
124139
@pytest.mark.parametrize(
125140
"ov_type",
126141
[
127-
ov.Type.nf4,
128142
ov.Type.undefined,
129143
ov.Type.f8e4m3,
130144
ov.Type.f8e5m2,

0 commit comments

Comments
 (0)