21
21
from nncf .tensor .functions import numeric
22
22
23
23
DTYPE_MAP : Dict [TensorDataType , ov .Type ] = {
24
+ TensorDataType .nf4 : ov .Type .nf4 ,
25
+ TensorDataType .f8e4m3 : ov .Type .f8e4m3 ,
26
+ TensorDataType .f8e5m2 : ov .Type .f8e5m2 ,
24
27
TensorDataType .float16 : ov .Type .f16 ,
25
28
TensorDataType .bfloat16 : ov .Type .bf16 ,
26
29
TensorDataType .float32 : ov .Type .f32 ,
@@ -48,12 +51,17 @@ def _(a: ov.Tensor) -> TensorBackend:
48
51
49
52
@numeric .astype .register
50
53
def _ (a : ov .Tensor , dtype : TensorDataType ) -> ov .Tensor :
51
- if a . get_element_type () in [ ov . Type . bf16 , ov . Type . i4 , ov . Type . u4 ] or dtype in [
54
+ ov_cast_types = [
52
55
TensorDataType .bfloat16 ,
53
56
TensorDataType .int4 ,
54
57
TensorDataType .uint4 ,
55
- ]:
56
- # Cannot cast to/from bfloat16, uint4, int4 directly
58
+ TensorDataType .nf4 ,
59
+ TensorDataType .f8e4m3 ,
60
+ TensorDataType .f8e5m2 ,
61
+ ]
62
+ a_dtype = DTYPE_MAP_REV [a .get_element_type ()]
63
+ if a_dtype in ov_cast_types or dtype in ov_cast_types :
64
+ # Cast using OpenVINO because the target or source dtype requires special handling
57
65
return _astype_ov (a , dtype )
58
66
return ov .Tensor (numeric .astype (a .data , dtype ).data )
59
67
@@ -75,9 +83,16 @@ def _(a: ov.Tensor, shape: Union[int, Tuple[int, ...]]) -> ov.Tensor:
75
83
76
84
@numeric .as_numpy_tensor .register
77
85
def _ (a : ov .Tensor ) -> NDArray [Any ]:
78
- # Cannot convert bfloat16, uint4, int4 to numpy directly
86
+ # Cannot convert bfloat16, uint4, int4, nf4, f8e4m3, f8e5m2 to numpy directly
79
87
a_dtype = DTYPE_MAP_REV [a .get_element_type ()]
80
- if a_dtype in [TensorDataType .bfloat16 , TensorDataType .uint4 , TensorDataType .int4 ]:
88
+ if a_dtype in [
89
+ TensorDataType .bfloat16 ,
90
+ TensorDataType .uint4 ,
91
+ TensorDataType .int4 ,
92
+ TensorDataType .nf4 ,
93
+ TensorDataType .f8e4m3 ,
94
+ TensorDataType .f8e5m2 ,
95
+ ]:
81
96
dtype = TensorDataType .float32
82
97
if a_dtype == TensorDataType .uint4 :
83
98
dtype = TensorDataType .uint8
0 commit comments