Skip to content

Commit f76b288

Browse files
dmitry-gorokhov authored and alvoron committed
[CPU][ARM] Fixed JIT Eltwise precision limitations
1 parent df40180 commit f76b288

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

src/plugins/intel_cpu/src/node.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ class Node {
398398
return mergedWith;
399399
}
400400

401-
const std::vector<NodePtr>& getFusedWith() {
401+
const std::vector<NodePtr>& getFusedWith() const {
402402
return fusedWith;
403403
}
404404

src/plugins/intel_cpu/src/nodes/eltwise.cpp

+35-21
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
746746
const float alpha,
747747
const float beta,
748748
const float gamma,
749-
const std::vector<ov::element::Type>& input_precisions = {}) {
749+
const std::vector<ov::element::Type>& input_precisions = {},
750+
const std::vector<ov::element::Type>& output_precisions = {}) {
750751
#if defined(OPENVINO_ARCH_X86_64)
751752
const auto isISASupportedByJIT = mayiuse(dnnl::impl::cpu::x64::sse41);
752753
#elif defined(OPENVINO_ARCH_ARM64)
@@ -788,15 +789,27 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
788789
return false;
789790
}
790791

791-
const std::set<ov::element::Type> supported_precisions =
792+
std::set<ov::element::Type> supported_input_precisions = std::set<ov::element::Type>{ov::element::f16,
793+
ov::element::f32,
794+
ov::element::i32,
795+
ov::element::i8,
796+
ov::element::u8};
797+
798+
std::set<ov::element::Type> supported_output_precisions = supported_input_precisions;
799+
if (one_of(algorithm, Algorithm::EltwiseDivide, Algorithm::EltwiseFloor)) {
800+
supported_input_precisions = std::set<ov::element::Type>{ov::element::f16, ov::element::f32};
801+
}
802+
803+
auto fusedOps = node->getFusedWith();
804+
if (!fusedOps.empty()) {
792805
// Divide and Floor (issue #138629) operations are supported for fp32 and fp16 only.
793-
((algorithm == Algorithm::EltwiseDivide) || (algorithm == Algorithm::EltwiseFloor))
794-
? std::set<ov::element::Type>{ov::element::f16, ov::element::f32}
795-
: std::set<ov::element::Type>{ov::element::f16,
796-
ov::element::f32,
797-
ov::element::i32,
798-
ov::element::i8,
799-
ov::element::u8};
806+
if (one_of(fusedOps.back()->getAlgorithm(), Algorithm::EltwiseDivide, Algorithm::EltwiseFloor)) {
807+
supported_output_precisions = std::set<ov::element::Type>{ov::element::f16, ov::element::f32};
808+
}
809+
} else {
810+
supported_output_precisions = supported_input_precisions;
811+
}
812+
800813
#elif defined(OPENVINO_ARCH_RISCV64)
801814
if (!one_of(algorithm,
802815
Algorithm::EltwiseAdd,
@@ -813,36 +826,37 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
813826
return false;
814827
}
815828

816-
const std::set<ov::element::Type> supported_precisions = {ov::element::f32,
817-
ov::element::i32,
818-
ov::element::i8,
819-
ov::element::u8};
829+
const std::set<ov::element::Type> supported_input_precisions = {ov::element::f32,
830+
ov::element::i32,
831+
ov::element::i8,
832+
ov::element::u8};
833+
auto supported_output_precisions = supported_input_precisions;
820834
#endif
821835

822836
#if defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_RISCV64)
823-
const auto check_precisions = [](const std::vector<ov::element::Type>& input_precisions,
824-
const std::vector<ov::element::Type>& output_precisions,
825-
const std::set<ov::element::Type>& supported_precisions) {
837+
const auto check_precisions = [&](const std::vector<ov::element::Type>& input_precisions,
838+
const std::vector<ov::element::Type>& output_precisions) {
826839
if (std::any_of(input_precisions.begin(),
827840
input_precisions.end(),
828-
[&supported_precisions](const ov::element::Type& precision) {
829-
return supported_precisions.find(precision) == supported_precisions.end();
841+
[&supported_input_precisions](const ov::element::Type& precision) {
842+
return supported_input_precisions.find(precision) == supported_input_precisions.end();
830843
})) {
831844
return false;
832845
}
833846

834847
if (std::any_of(output_precisions.begin(),
835848
output_precisions.end(),
836-
[&supported_precisions](const ov::element::Type& precision) {
837-
return supported_precisions.find(precision) == supported_precisions.end();
849+
[&supported_output_precisions](const ov::element::Type& precision) {
850+
return supported_output_precisions.find(precision) == supported_output_precisions.end();
838851
})) {
839852
return false;
840853
}
841854

842855
return true;
843856
};
844857

845-
return check_precisions(input_precisions, node->getOriginalOutputPrecisions(), supported_precisions);
858+
auto out_precisions = output_precisions.empty() ? node->getOriginalOutputPrecisions() : output_precisions;
859+
return check_precisions(input_precisions, out_precisions);
846860
#endif
847861

848862
// Unsupported architectures should return false:

0 commit comments

Comments
 (0)