@@ -893,6 +893,48 @@ struct ConvertFrom4BitPrecision<std::tuple<src_t, dst_t>> {
893
893
}
894
894
};
895
895
896
// (source, destination) precision pairs for conversions that produce a 4-bit output.
// The list is passed to OV_SWITCH both to dispatch ConvertTo4BitPrecision in cpu_convert
// and to answer is_supported_convert queries.
#define INTEL_CPU_CVT_TO_4BIT_LIST \
    INTEL_CPU_CVT(f32, nf4), INTEL_CPU_CVT(f16, nf4), INTEL_CPU_CVT(bf16, nf4)
898
+
899
+ struct ConvertTo4BitContext {
900
+ ov::element::Type_t outType;
901
+ const void * srcPtr;
902
+ void * dstPtr;
903
+ size_t size;
904
+ bool converted;
905
+ };
906
+
907
+ template <typename T>
908
+ struct ConvertTo4BitPrecision ;
909
+
910
+ template <typename src_t , typename dst_t >
911
+ struct ConvertTo4BitPrecision <std::tuple<src_t , dst_t >> {
912
+ void operator ()(ConvertTo4BitContext& ctx) {
913
+ auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
914
+ uint8_t shift = high_half ? 4 : 0 ;
915
+ return dst | (uint8_t ) (val << shift);
916
+ };
917
+
918
+ auto src = static_cast <const src_t *>(ctx.srcPtr );
919
+ auto dst = static_cast <uint8_t *>(ctx.dstPtr );
920
+ // each byte must be fully processed within same thread
921
+ auto work_amount = div_up (ctx.size , 2 );
922
+ if (ctx.outType == ov::element::nf4) {
923
+ parallel_for (work_amount, [&](size_t ib) {
924
+ for (int i = 0 ; i < 2 ; i++) {
925
+ int idx = ib * 2 + i;
926
+ uint8_t val = idx % 2 == 0 ? 0 : dst[idx / 2 ];
927
+ val = insert_half_byte (val, ConvertNF4::quantize (static_cast <float >(src[idx])), idx % 2 );
928
+ dst[idx / 2 ] = val;
929
+ }
930
+ });
931
+ } else {
932
+ OPENVINO_THROW (" cpu_convert doesn't support output data type: " , ctx.outType , " . Not implemented." );
933
+ }
934
+ ctx.converted = true ;
935
+ }
936
+ };
937
+
896
938
// (source, destination) precision pairs for conversions whose source is f8e8m0.
// The list is passed to OV_SWITCH, e.g. by is_supported_convert.
#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \
    INTEL_CPU_CVT(f8e8m0, f32), INTEL_CPU_CVT(f8e8m0, bf16), INTEL_CPU_CVT(f8e8m0, f16)
898
940
@@ -1017,6 +1059,12 @@ void cpu_convert(const void* srcPtr,
1017
1059
if (!ctx.converted ) {
1018
1060
OPENVINO_THROW (" cpu_convert can't convert from: " , srcPrc, " precision to: " , dstPrc);
1019
1061
}
1062
+ } else if (dstPrc.bitwidth () == 4u ) {
1063
+ ConvertTo4BitContext ctx{dstPrc, srcPtr, dstPtr, size, false };
1064
+ OV_SWITCH (intel_cpu, ConvertTo4BitPrecision, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
1065
+ if (!ctx.converted ) {
1066
+ OPENVINO_THROW (" cpu_convert can't convert from: " , srcPrc, " precision to: " , dstPrc);
1067
+ }
1020
1068
} else if (srcPrc == ov::element::f8e8m0) {
1021
1069
ConvertFromByteFPContext ctx{srcPrc, srcPtr, dstPtr, size, false };
1022
1070
OV_SWITCH (intel_cpu,
@@ -1063,6 +1111,7 @@ bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc) {
1063
1111
OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BIN_LIST);
1064
1112
OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_FROM_4BIT_LIST);
1065
1113
OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BYTE_FP_LIST);
1114
+ OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
1066
1115
return ctx.isSupported ;
1067
1116
}
1068
1117
0 commit comments