Skip to content

Commit 569d3f2

Browse files
author
dmitrygo
committed
[CPU] Enabled float (fp32/fp16/bf16) to nf4 precision conversion
1 parent c27f796 commit 569d3f2

File tree

5 files changed

+100
-22
lines changed

5 files changed

+100
-22
lines changed

src/plugins/intel_cpu/src/node.cpp

+2-20
Original file line numberDiff line numberDiff line change
@@ -1588,24 +1588,6 @@ ov::element::Type Node::getRuntimePrecision() const {
15881588
}
15891589

15901590
Node* Node::NodesFactory::create(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context) {
1591-
// getExceptionDescWithoutStatus removes redundant information from the exception message. For instance, the
1592-
// NotImplemented exception is generated in the form: full_path_to_src_file:line_number [ NOT_IMPLEMENTED ] reason.
1593-
// An example for gather node:
1594-
// /path-to-openVino-root/src/plugins/intel_cpu/nodes/gather.cpp:42 [ NOT_IMPLEMENTED ] Only opset7 Gather operation
1595-
// is supported The most important part of the message is the reason, so the lambda trims everything up to "]" Note
1596-
// that the op type and its friendly name will also be provided if we fail to create the node.
1597-
auto getExceptionDescWithoutStatus = [](const ov::Exception& ex) {
1598-
std::string desc = ex.what();
1599-
size_t pos = desc.find(']');
1600-
if (pos != std::string::npos) {
1601-
if (desc.size() == pos + 1) {
1602-
desc.erase(0, pos + 1);
1603-
} else {
1604-
desc.erase(0, pos + 2);
1605-
}
1606-
}
1607-
return desc;
1608-
};
16091591
Node* newNode = nullptr;
16101592
std::string errorMessage;
16111593
if (newNode == nullptr) {
@@ -1616,7 +1598,7 @@ Node* Node::NodesFactory::create(const std::shared_ptr<ov::Node>& op, const Grap
16161598
}
16171599
} catch (const ov::Exception& ex) {
16181600
if (dynamic_cast<const ov::NotImplemented*>(&ex) != nullptr) {
1619-
errorMessage += getExceptionDescWithoutStatus(ex);
1601+
errorMessage += ex.what();
16201602
} else {
16211603
throw;
16221604
}
@@ -1631,7 +1613,7 @@ Node* Node::NodesFactory::create(const std::shared_ptr<ov::Node>& op, const Grap
16311613
}
16321614
} catch (const ov::Exception& ex) {
16331615
if (dynamic_cast<const ov::NotImplemented*>(&ex) != nullptr) {
1634-
const auto currErrorMess = getExceptionDescWithoutStatus(ex);
1616+
const std::string currErrorMess = ex.what();
16351617
if (!currErrorMess.empty()) {
16361618
errorMessage += errorMessage.empty() ? currErrorMess : "\n" + currErrorMess;
16371619
}

src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp

+49
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,48 @@ struct ConvertFrom4BitPrecision<std::tuple<src_t, dst_t>> {
893893
}
894894
};
895895

896+
#define INTEL_CPU_CVT_TO_4BIT_LIST \
897+
INTEL_CPU_CVT(f32, nf4), INTEL_CPU_CVT(f16, nf4), INTEL_CPU_CVT(bf16, nf4)
898+
899+
struct ConvertTo4BitContext {
900+
ov::element::Type_t outType;
901+
const void* srcPtr;
902+
void* dstPtr;
903+
size_t size;
904+
bool converted;
905+
};
906+
907+
template <typename T>
908+
struct ConvertTo4BitPrecision;
909+
910+
template <typename src_t, typename dst_t>
911+
struct ConvertTo4BitPrecision<std::tuple<src_t, dst_t>> {
912+
void operator()(ConvertTo4BitContext& ctx) {
913+
auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
914+
uint8_t shift = high_half ? 4 : 0;
915+
return dst | (uint8_t) (val << shift);
916+
};
917+
918+
auto src = static_cast<const src_t*>(ctx.srcPtr);
919+
auto dst = static_cast<uint8_t*>(ctx.dstPtr);
920+
// each byte must be fully processed within same thread
921+
auto work_amount = div_up(ctx.size, 2);
922+
if (ctx.outType == ov::element::nf4) {
923+
parallel_for(work_amount, [&](size_t ib) {
924+
for (int i = 0; i < 2; i++) {
925+
int idx = ib * 2 + i;
926+
uint8_t val = idx % 2 == 0 ? 0 : dst[idx / 2];
927+
val = insert_half_byte(val, ConvertNF4::quantize(static_cast<float>(src[idx])), idx % 2);
928+
dst[idx / 2] = val;
929+
}
930+
});
931+
} else {
932+
OPENVINO_THROW("cpu_convert doesn't support output data type: ", ctx.outType, ". Not implemented.");
933+
}
934+
ctx.converted = true;
935+
}
936+
};
937+
896938
#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \
897939
INTEL_CPU_CVT(f8e8m0, f32), INTEL_CPU_CVT(f8e8m0, bf16), INTEL_CPU_CVT(f8e8m0, f16)
898940

@@ -1017,6 +1059,12 @@ void cpu_convert(const void* srcPtr,
10171059
if (!ctx.converted) {
10181060
OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
10191061
}
1062+
} else if (dstPrc.bitwidth() == 4u) {
1063+
ConvertTo4BitContext ctx{dstPrc, srcPtr, dstPtr, size, false};
1064+
OV_SWITCH(intel_cpu, ConvertTo4BitPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
1065+
if (!ctx.converted) {
1066+
OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
1067+
}
10201068
} else if (srcPrc == ov::element::f8e8m0) {
10211069
ConvertFromByteFPContext ctx{srcPrc, srcPtr, dstPtr, size, false};
10221070
OV_SWITCH(intel_cpu,
@@ -1063,6 +1111,7 @@ bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc) {
10631111
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BIN_LIST);
10641112
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_4BIT_LIST);
10651113
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BYTE_FP_LIST);
1114+
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
10661115
return ctx.isSupported;
10671116
}
10681117

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp

+39-2
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,16 @@ void ConvertCPULayerTest::generate_inputs(const std::vector<ov::Shape>& targetIn
151151
const auto& funcInputs = function->inputs();
152152
for (size_t i = 0; i < funcInputs.size(); ++i) {
153153
const auto& funcInput = funcInputs[i];
154-
ov::Tensor tensor =
155-
ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
154+
ov::Tensor tensor;
155+
if (outPrc == ov::element::nf4) {
156+
tensor = ov::test::utils::create_and_fill_tensor_real_distribution(funcInput.get_element_type(),
157+
targetInputStaticShapes[i],
158+
-1.f,
159+
1.f,
160+
1);
161+
} else {
162+
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
163+
}
156164
if (special_value != ov::test::SpecialValue::none) {
157165
if (inPrc == ov::element::f32) {
158166
modify_value<float>(tensor, special_value);
@@ -176,6 +184,35 @@ void ConvertCPULayerTest::validate_out_prc() const {
176184
FAIL() << "ConvertCPULayerTest supports only non boolean output prc";
177185
}
178186

187+
void ConvertCPULayerTest::validate() {
188+
if (outPrc == ov::element::nf4) {
189+
// Use custom bit-exact validation, because common tests infra doesn't support 4bits tensors comparision
190+
auto div_up = [&](auto a, auto b) {
191+
assert(b);
192+
return (a + b - 1) / b;
193+
};
194+
195+
auto actualOutputs = get_plugin_outputs();
196+
auto expectedOutputs = calculate_refs();
197+
ASSERT_EQ(expectedOutputs.size(), actualOutputs.size());
198+
ASSERT_EQ(expectedOutputs.size(), 1);
199+
ASSERT_EQ(expectedOutputs[0].get_shape(), actualOutputs[0].get_shape());
200+
201+
auto expected_data = reinterpret_cast<const uint8_t*>(expectedOutputs[0].data());
202+
auto actual_data = reinterpret_cast<const uint8_t*>(actualOutputs[0].data());
203+
size_t shape_size_cnt = div_up(shape_size(expectedOutputs[0].get_shape()), 2);
204+
for (size_t i = 0; i < shape_size_cnt; ++i) {
205+
uint8_t expected_value = expected_data[i];
206+
uint8_t actual_value = actual_data[i];
207+
ASSERT_EQ(expected_value, actual_value);
208+
}
209+
210+
return;
211+
}
212+
213+
SubgraphBaseTest::validate();
214+
}
215+
179216
void ConvertToBooleanCPULayerTest::validate_out_prc() const {
180217
if (outPrc != ov::element::boolean)
181218
FAIL() << "ConvertToBooleanCPULayerTest supports only boolean output prc";

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ConvertCPULayerTest : public testing::WithParamInterface<convertLayerTestP
2929
protected:
3030
void SetUp() override;
3131
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
32+
void validate() override;
3233
virtual void validate_out_prc() const;
3334

3435
ov::element::Type inPrc, outPrc;

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/common/conversion.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ const std::vector<ov::element::Type> float_precisions = {
6464
ov::element::bf16,
6565
};
6666

67+
// Coverage for the float (fp32/fp16/bf16) -> nf4 conversion on the CPU plugin
// reference ("ref") implementation over dynamic 4D shapes; nf4 outputs are
// validated bit-exactly by ConvertCPULayerTest::validate.
INSTANTIATE_TEST_SUITE_P(smoke_ConvertCPULayerTest_float_to_nf4, ConvertCPULayerTest,
                        ::testing::Combine(
                                ::testing::ValuesIn(inShapes_4D_dynamic()),
                                ::testing::ValuesIn(float_precisions),
                                ::testing::Values(ov::element::nf4),
                                ::testing::Values(ov::test::SpecialValue::none),
                                ::testing::Values(CPUSpecificParams({nchw}, {nchw}, {}, {"ref"}))),
                        ConvertCPULayerTest::getTestCaseName);
75+
6776
const std::vector<ov::element::Type> f8_precisions = {
6877
ov::element::f8e4m3,
6978
ov::element::f8e5m2,

0 commit comments

Comments (0)